volcano_plot.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import os
  2. import pandas as pd
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from tqdm import tqdm
  6. # Define cm for converting cm to inches
  7. cm = 1 / 2.54
  8. # Get the directory where the code file is located
  9. code_dir = os.path.dirname(os.path.abspath(__file__))
  10. # Get the parent directory of the code directory
  11. parent_dir = os.path.dirname(code_dir)
  12. # Define the path for the input CSV files
  13. original_file_path = os.path.join(parent_dir, 'output', 'Quantitative_outputs', 'old', 'Quantitative_results_from_dwi_processing.csv')
  14. results_file_path = os.path.join(parent_dir, 'output', 'Quantitative_outputs', 'Significance_stroke_vs_sham_difference.csv')
  15. # Define the path for the output folders to save plots
  16. plots_output_dir = os.path.join(parent_dir, 'output', 'Figures', 'pythonFigs')
  17. fa_over_time_plots_dir = os.path.join(parent_dir, 'output', 'Figures', 'fa_over_time_plots')
  18. os.makedirs(fa_over_time_plots_dir, exist_ok=True)
  19. os.makedirs(plots_output_dir, exist_ok=True)
  20. # Load the original dataset for analysis
  21. df = pd.read_csv(original_file_path, low_memory=False)
  22. # Load the results CSV
  23. results_df = pd.read_csv(results_file_path)
  24. # Filter results to exclude those with "AMBA" in the mask name
  25. results_df = results_df[~results_df['mask_name'].str.contains("AMBA")]
  26. # Define functions to map abbreviations and locations
  27. def map_abbreviation(mask_name):
  28. if mask_name.startswith("CC"):
  29. return "CC"
  30. elif mask_name.startswith("CRuT"):
  31. return "RS"
  32. elif mask_name.startswith("CReT"):
  33. return "RetS"
  34. elif mask_name.startswith("CST"):
  35. return "CST"
  36. elif mask_name.startswith("TC"):
  37. return "TC"
  38. elif mask_name.startswith("OT"):
  39. return "OT"
  40. else:
  41. return "Unknown"
  42. def map_location(mask_name):
  43. if "ipsi" in mask_name:
  44. return "Ips"
  45. elif "contra" in mask_name:
  46. return "Con"
  47. else:
  48. return "None"
  49. # Add new columns to the dataframe for abbreviation and location
  50. results_df['abbreviation'] = results_df['mask_name'].apply(map_abbreviation)
  51. results_df['location'] = results_df['mask_name'].apply(map_location)
  52. # Get unique time points and qtypes
  53. timepoints = results_df['merged_timepoint'].unique()
  54. qtypes = results_df['Qtype'].unique()
  55. # Define different marker shapes for each unique abbreviation
  56. unique_abbreviations = results_df['abbreviation'].unique()
  57. markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'X', 'h']
  58. marker_mapping = {abbr: markers[i % len(markers)] for i, abbr in enumerate(unique_abbreviations)}
  59. # Iterate over each time point and Qtype to create individual volcano plots
  60. for timepoint in timepoints:
  61. for qtype in qtypes:
  62. subset_df = results_df[(results_df['merged_timepoint'] == timepoint) & (results_df['Qtype'] == qtype)]
  63. # Skip if there is no data for the specific subset
  64. if subset_df.empty:
  65. continue
  66. # Calculate mean difference for the current subset
  67. mean_diff = []
  68. with tqdm(total=len(subset_df), desc=f"Calculating mean differences for {timepoint}, Qtype: {qtype}") as pbar:
  69. for _, row in subset_df.iterrows():
  70. mask = row['mask_name']
  71. # Filter original data for Stroke and Sham
  72. stroke_values = df[(df['Group'] == 'Stroke') &
  73. (df['mask_name'] == mask) &
  74. (df['merged_timepoint'] == timepoint) &
  75. (df['dialation_amount'] == row['dialation_amount']) &
  76. (df['Qtype'] == qtype)]['Value'].dropna()
  77. sham_values = df[(df['Group'] == 'Sham') &
  78. (df['mask_name'] == mask) &
  79. (df['merged_timepoint'] == timepoint) &
  80. (df['dialation_amount'] == row['dialation_amount']) &
  81. (df['Qtype'] == qtype)]['Value'].dropna()
  82. # Calculate mean difference
  83. if len(stroke_values) > 0 and len(sham_values) > 0:
  84. mean_diff.append(stroke_values.mean() - sham_values.mean())
  85. else:
  86. mean_diff.append(np.nan)
  87. # Update progress bar
  88. pbar.update(1)
  89. subset_df['Mean_Difference'] = mean_diff
  90. subset_df['-log10(Pvalue)'] = -np.log10(subset_df['Pvalue'])
  91. # Plot the volcano plot for the current time point and Qtype
  92. plt.figure(figsize=(8 * cm, 8 * cm), dpi=300) # 8 cm by 8 cm in inches, with high DPI for better quality
  93. # Plot each mask using its corresponding marker shape and location suffix
  94. for abbr in unique_abbreviations:
  95. abbr_subset = subset_df[subset_df['abbreviation'] == abbr]
  96. for location in abbr_subset['location'].unique():
  97. loc_subset = abbr_subset[abbr_subset['location'] == location]
  98. label = f"{abbr} ({location})" if location != "None" else abbr
  99. plt.scatter(loc_subset['Mean_Difference'], loc_subset['-log10(Pvalue)'],
  100. alpha=0.7, s=10, marker=marker_mapping[abbr], label=label)
  101. # Labels and title for each plot
  102. plt.axhline(y=-np.log10(0.05), color='blue', linestyle='--')
  103. plt.xlabel('Mean Difference (Stroke - Sham)', fontsize=12, fontname='Calibri')
  104. plt.ylabel('-log10(Pvalue)', fontsize=12, fontname='Calibri')
  105. plt.title(f'Volcano Plot: {qtype} for {timepoint}', fontsize=12, fontname='Calibri')
  106. plt.grid(False)
  107. # Create the legend with marker shapes
  108. plt.legend(loc='best', fontsize=6, frameon=False)
  109. # Save the plot as an SVG file
  110. plot_file_name = f'volcano_plot_{timepoint}_{qtype}.svg'
  111. plot_file_path = os.path.join(plots_output_dir, plot_file_name)
  112. plt.savefig(plot_file_path, format='svg', bbox_inches='tight')
  113. plt.show()