volcano_plot_BL_P28_change.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Tue Nov 12 15:23:26 2024
  4. @author: arefks
  5. """
  6. import os
  7. import pandas as pd
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. from tqdm import tqdm
  11. # Define cm for converting cm to inches
  12. cm = 1 / 2.54
  13. # Get the directory where the code file is located
  14. code_dir = os.path.dirname(os.path.abspath(__file__))
  15. # Get the parent directory of the code directory
  16. parent_dir = os.path.dirname(code_dir)
  17. # Define the path for the input CSV files
  18. original_file_path = os.path.join(parent_dir, 'output', 'Quantitative_outputs', 'Quantitative_results_from_dwi_processing_only_in_stroke_affected_slices.csv')
  19. results_file_path = os.path.join(parent_dir, 'output', 'Quantitative_outputs', 'Significance_timepoint_0_vs_28_only_in_stroke_slices.csv')
  20. # Define the path for the output folders to save plots
  21. plots_output_dir = os.path.join(parent_dir, 'output', 'Figures', 'pythonFigs')
  22. fa_over_time_plots_dir = os.path.join(parent_dir, 'output', 'Figures', 'fa_over_time_plots')
  23. os.makedirs(fa_over_time_plots_dir, exist_ok=True)
  24. os.makedirs(plots_output_dir, exist_ok=True)
  25. # Load the original dataset for analysis
  26. df = pd.read_csv(original_file_path, low_memory=False)
  27. # Filter out Sham mice
  28. df = df[df['Group'] != 'Sham']
  29. # Load the results CSV
  30. results_df = pd.read_csv(results_file_path)
  31. # Filter out Sham mice from results_df
  32. results_df = results_df[results_df['Group'] != 'Sham']
  33. # Filter results to exclude those with "AMBA" in the mask name
  34. results_df = results_df[~results_df['mask_name'].str.contains("AMBA")]
  35. # Define functions to map abbreviations and locations
  36. def map_abbreviation(mask_name):
  37. if mask_name.startswith("CC"):
  38. return "CC"
  39. elif mask_name.startswith("CRuT"):
  40. return "CRuT"
  41. elif mask_name.startswith("CReT"):
  42. return "CReT"
  43. elif mask_name.startswith("CST"):
  44. return "CST"
  45. elif mask_name.startswith("TC"):
  46. return "TC"
  47. elif mask_name.startswith("OT"):
  48. return "OT"
  49. else:
  50. return "Unknown"
  51. def map_location(mask_name):
  52. if "ipsi" in mask_name:
  53. return "Ips"
  54. elif "contra" in mask_name:
  55. return "Con"
  56. else:
  57. return "None"
  58. # Add new columns to the dataframe for abbreviation and location
  59. results_df['abbreviation'] = results_df['mask_name'].apply(map_abbreviation)
  60. results_df['location'] = results_df['mask_name'].apply(map_location)
  61. # Get unique abbreviations and locations
  62. abbreviations = results_df['abbreviation'].unique()
  63. locations = results_df['location'].unique()
  64. qtypes = results_df['Qtype'].unique()
  65. # Define different marker shapes for each unique Qtype
  66. markers = ['o', 's', '^', 'D']
  67. marker_mapping = {qtype: markers[i % len(markers)] for i, qtype in enumerate(qtypes)}
  68. # Flag to toggle displaying the highest point labels
  69. MaxPlotter = True
  70. # To store the highest points details for printing at the end
  71. highest_points_details = []
  72. # Iterate over each abbreviation and location to create individual plots comparing timepoint 0 vs 28
  73. for abbr in abbreviations:
  74. for location in locations:
  75. subset_df = results_df[(results_df['abbreviation'] == abbr) & (results_df['location'] == location)]
  76. # Skip if there is no data for the specific abbreviation and location
  77. if subset_df.empty:
  78. continue
  79. # Create a figure for the current abbreviation and location
  80. plt.figure(figsize=(6 * cm, 6 * cm), dpi=300) # 8 cm by 8 cm in inches, with high DPI for better quality
  81. mean_diff_list = []
  82. neg_log_pvalue_list = []
  83. mask_list = []
  84. qtype_list = []
  85. dialation_list = []
  86. # Iterate over each unique mask in the subset
  87. for mask in subset_df['mask_name'].unique():
  88. # Filter original data for timepoints 0 and 28 for the given mask and location
  89. timepoint_0_df = df[(df['merged_timepoint'] == 0) & (df['mask_name'] == mask)]
  90. timepoint_28_df = df[(df['merged_timepoint'] == 28) & (df['mask_name'] == mask)]
  91. # Iterate over each Qtype and dialation_amount amount to calculate mean differences and p-values
  92. for qtype in qtypes:
  93. for dialation_amount in subset_df['dialation_amount'].unique():
  94. tp0_values = timepoint_0_df[(timepoint_0_df['Qtype'] == qtype) & (timepoint_0_df['dialation_amount'] == dialation_amount)]['Value'].dropna()
  95. tp28_values = timepoint_28_df[(timepoint_28_df['Qtype'] == qtype) & (timepoint_28_df['dialation_amount'] == dialation_amount)]['Value'].dropna()
  96. # Calculate mean difference
  97. if len(tp0_values) > 0 and len(tp28_values) > 0:
  98. mean_diff = tp28_values.mean() - tp0_values.mean()
  99. else:
  100. mean_diff = np.nan
  101. # Get the corresponding p-value from results_df
  102. pvalue = results_df[(results_df['mask_name'] == mask) &
  103. (results_df['Qtype'] == qtype) &
  104. (results_df['dialation_amount'] == dialation_amount)]['Pvalue'].values
  105. if len(pvalue) > 0:
  106. neg_log_pvalue = -np.log10(pvalue[0])
  107. else:
  108. neg_log_pvalue = np.nan
  109. mean_diff_list.append(mean_diff)
  110. neg_log_pvalue_list.append(neg_log_pvalue)
  111. mask_list.append(mask)
  112. qtype_list.append(qtype)
  113. dialation_list.append(dialation_amount)
  114. # Plot the mean difference vs -log10(Pvalue) with the corresponding marker, combining all qtypes and dilations
  115. for qtype in qtypes:
  116. qtype_mean_diff = [mean_diff_list[i] for i in range(len(mean_diff_list)) if qtype_list[i] == qtype]
  117. qtype_neg_log_pvalue = [neg_log_pvalue_list[i] for i in range(len(neg_log_pvalue_list)) if qtype_list[i] == qtype]
  118. plt.scatter(qtype_mean_diff, qtype_neg_log_pvalue, alpha=0.7, s=10, marker=marker_mapping[qtype], label=qtype)
  119. # Add a vertical line at x = 0
  120. plt.axvline(x=0, color='red', linestyle='--')
  121. # Labels and title for each plot
  122. plt.axhline(y=-np.log10(0.05), color='blue', linestyle='--')
  123. plt.xlabel('Mean Difference (28 - BL)', fontsize=12, fontname='Calibri')
  124. plt.ylabel('-log10(Pvalue)', fontsize=12, fontname='Calibri')
  125. plt.title(f'{abbr},{location}', fontsize=12, fontname='Calibri')
  126. plt.grid(False)
  127. # Create the legend with marker shapes
  128. plt.legend(loc='best', fontsize=6, frameon=False)
  129. # Find and label the highest dot for each Qtype if MaxPlotter is True
  130. if MaxPlotter and len(mean_diff_list) > 0:
  131. for qtype in qtypes:
  132. qtype_indices = [i for i in range(len(mean_diff_list)) if qtype_list[i] == qtype]
  133. if qtype_indices:
  134. max_qtype_index = max(qtype_indices, key=lambda i: neg_log_pvalue_list[i])
  135. plt.text(mean_diff_list[max_qtype_index], neg_log_pvalue_list[max_qtype_index],
  136. f"{mask_list[max_qtype_index]}, d={dialation_list[max_qtype_index]}",
  137. fontsize=6, fontname='Calibri', ha='right', va='bottom')
  138. highest_points_details.append(
  139. f"Highest point for Qtype {qtype} in {abbr}, {location}: {mask_list[max_qtype_index]}, d={dialation_list[max_qtype_index]}, Mean Diff={mean_diff_list[max_qtype_index]}, -log10(P)={neg_log_pvalue_list[max_qtype_index]}"
  140. )
  141. # Save the plot as a PNG file
  142. plot_file_name = f'volcano_plot_{abbr}_{location}.png'
  143. plot_file_path = os.path.join(plots_output_dir, plot_file_name)
  144. plt.savefig(plot_file_path, format='png', bbox_inches='tight')
  145. plt.show()
  146. # Print the details of the highest points
  147. print("\nDetails of the highest points:")
  148. for detail in highest_points_details:
  149. print(detail)