123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- import os
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- from tqdm import tqdm
# Conversion factor from centimetres to inches (matplotlib sizes figures in inches).
cm = 1 / 2.54

# Directory holding this script, and the project directory one level above it.
code_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(code_dir)

# Input CSVs live under <parent>/output/Quantitative_outputs/onlyCST.
_quantitative_dir = os.path.join(parent_dir, 'output', 'Quantitative_outputs', "onlyCST")
original_file_path = os.path.join(
    _quantitative_dir,
    'Quantitative_results_from_dwi_processing_withOutWM_and_partialCST.csv',
)
results_file_path = os.path.join(
    _quantitative_dir,
    'Significance_stroke_vs_sham_difference_withoutWMmask_partialCST.csv',
)

# Figure output folders under <parent>/output/Figures; created if missing.
_figures_dir = os.path.join(parent_dir, 'output', 'Figures')
plots_output_dir = os.path.join(_figures_dir, 'pythonFigs')
fa_over_time_plots_dir = os.path.join(_figures_dir, 'fa_over_time_plots')
os.makedirs(fa_over_time_plots_dir, exist_ok=True)
os.makedirs(plots_output_dir, exist_ok=True)
# Load the raw measurements. low_memory=False makes pandas read the file in one
# pass so each column gets a single inferred dtype.
df = pd.read_csv(original_file_path, low_memory=False)
# Load the per-mask Stroke-vs-Sham significance results.
results_df = pd.read_csv(results_file_path)
# Keep only rows that have a mask name and whose name does not contain "AMBA".
# - Missing names are dropped explicitly: str.contains returns NaN for them,
#   which `~` cannot negate, and the mask-name mapping functions applied later
#   call str methods on the value.
# - .copy() makes the result an independent frame, so later column
#   assignments do not raise SettingWithCopyWarning.
results_df = results_df[
    results_df['mask_name'].notna()
    & ~results_df['mask_name'].str.contains("AMBA", na=False, regex=False)
].copy()
# Helpers that decode the tract abbreviation and hemisphere from a mask name.

# (prefix, abbreviation) pairs, tried in this order; the first matching prefix wins.
_ABBREVIATION_BY_PREFIX = (
    ("CC", "CC"),
    ("CRuT", "RS"),
    ("CReT", "RetS"),
    ("CST", "CST"),
    ("TC", "TC"),
    ("OT", "OT"),
)


def map_abbreviation(mask_name):
    """Return the plot abbreviation for a mask name, chosen by its prefix.

    Names starting with "CRuT" map to "RS" and with "CReT" to "RetS"; names
    starting with "CC", "CST", "TC" or "OT" map to that prefix. Any other name
    yields "Unknown". Expects a str (its startswith method is used).
    """
    matches = (
        abbreviation
        for prefix, abbreviation in _ABBREVIATION_BY_PREFIX
        if mask_name.startswith(prefix)
    )
    return next(matches, "Unknown")
def map_location(mask_name):
    """Return the hemisphere tag embedded in a mask name.

    Gives "Ips" when the name contains "ipsi", otherwise "Con" when it contains
    "contra", otherwise the string "None" (not the None object). "ipsi" is
    checked first, so a name containing both substrings yields "Ips".
    """
    for token, tag in (("ipsi", "Ips"), ("contra", "Con")):
        if token in mask_name:
            return tag
    return "None"
# Add 'abbreviation' and 'location' columns decoded from each mask name.
# DataFrame.assign returns a new frame instead of writing columns into what may
# be a slice of an earlier frame, which would raise SettingWithCopyWarning.
results_df = results_df.assign(
    abbreviation=results_df['mask_name'].apply(map_abbreviation),
    location=results_df['mask_name'].apply(map_location),
)
# Candidate marker shapes (11 of them); reused cyclically when there are more
# abbreviations than shapes.
markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'X', 'h']

# Distinct values in order of first appearance (Series.unique preserves order).
timepoints = results_df['merged_timepoint'].unique()
qtypes = results_df['Qtype'].unique()
unique_abbreviations = results_df['abbreviation'].unique()

# Abbreviation -> marker shape, assigned in order of first appearance.
marker_mapping = {
    abbreviation: markers[position % len(markers)]
    for position, abbreviation in enumerate(unique_abbreviations)
}
def _stroke_minus_sham(raw_subset, mask_name, dialation_amount):
    """Return mean(Stroke 'Value') - mean(Sham 'Value') for one mask/dilation.

    raw_subset should already be restricted to a single merged_timepoint and
    Qtype. Missing values are dropped before averaging; NaN is returned when
    either group has no remaining values.
    """
    rows = raw_subset[(raw_subset['mask_name'] == mask_name) &
                      (raw_subset['dialation_amount'] == dialation_amount)]
    stroke_values = rows.loc[rows['Group'] == 'Stroke', 'Value'].dropna()
    sham_values = rows.loc[rows['Group'] == 'Sham', 'Value'].dropna()
    if stroke_values.empty or sham_values.empty:
        return np.nan
    return stroke_values.mean() - sham_values.mean()


def _plot_volcano(plot_df, timepoint, qtype):
    """Draw one volcano plot, save it as SVG in plots_output_dir, show it, close it.

    plot_df must carry 'abbreviation', 'location', 'Mean_Difference' and
    '-log10(Pvalue)' columns. pyplot keeps every figure it creates open until
    it is closed, so the figure is closed explicitly; otherwise figures pile up
    across loop iterations (e.g. with a non-interactive backend, where
    plt.show() does not close them).

    Returns the path of the saved SVG file.
    """
    # 8 cm x 8 cm, converted to inches, at high DPI.
    fig, ax = plt.subplots(figsize=(8 * cm, 8 * cm), dpi=300)
    # One scatter series per abbreviation/hemisphere pair, so each pair gets its
    # own legend entry using the abbreviation's marker shape.
    for abbr in unique_abbreviations:
        abbr_subset = plot_df[plot_df['abbreviation'] == abbr]
        for location in abbr_subset['location'].unique():
            loc_subset = abbr_subset[abbr_subset['location'] == location]
            label = f"{abbr} ({location})" if location != "None" else abbr
            ax.scatter(loc_subset['Mean_Difference'], loc_subset['-log10(Pvalue)'],
                       alpha=0.7, s=10, marker=marker_mapping[abbr], label=label)
    # Dashed reference line at the p = 0.05 threshold.
    ax.axhline(y=-np.log10(0.05), color='blue', linestyle='--')
    ax.set_xlabel('Mean Difference (Stroke - Sham)', fontsize=12, fontname='Calibri')
    ax.set_ylabel('-log10(Pvalue)', fontsize=12, fontname='Calibri')
    ax.set_title(f'Volcano Plot: {qtype} for {timepoint}', fontsize=12, fontname='Calibri')
    ax.grid(False)
    ax.legend(loc='best', fontsize=6, frameon=False)
    plot_file_name = f'volcano_plot_{timepoint}_{qtype}.svg'
    plot_file_path = os.path.join(plots_output_dir, plot_file_name)
    fig.savefig(plot_file_path, format='svg', bbox_inches='tight')
    plt.show()
    plt.close(fig)
    return plot_file_path


# One volcano plot per (time point, Qtype) combination present in results_df.
for timepoint in timepoints:
    # Raw rows for this time point; filtered once here so the per-mask lookups
    # below scan only this slice rather than the whole dataset.
    timepoint_df = df[df['merged_timepoint'] == timepoint]
    for qtype in qtypes:
        # .copy(): the columns added below go into an independent frame, not a
        # slice of results_df (avoids SettingWithCopyWarning).
        subset_df = results_df[(results_df['merged_timepoint'] == timepoint) &
                               (results_df['Qtype'] == qtype)].copy()
        if subset_df.empty:
            continue
        raw_subset = timepoint_df[timepoint_df['Qtype'] == qtype]
        mean_diff = [
            _stroke_minus_sham(raw_subset, row['mask_name'], row['dialation_amount'])
            for _, row in tqdm(subset_df.iterrows(), total=len(subset_df),
                               desc=f"Calculating mean differences for {timepoint}, Qtype: {qtype}")
        ]
        subset_df['Mean_Difference'] = mean_diff
        # NOTE: a Pvalue of exactly 0 becomes +inf here (numpy emits a
        # divide-by-zero RuntimeWarning); such points may not be drawn.
        subset_df['-log10(Pvalue)'] = -np.log10(subset_df['Pvalue'])
        plot_file_path = _plot_volcano(subset_df, timepoint, qtype)
|