# vizualization_fa28_behavior28subtractedbyBaseline.py
  1. import os
  2. import pandas as pd
  3. import numpy as np
  4. import seaborn as sns
  5. import matplotlib.pyplot as plt
  6. from scipy.stats import pearsonr, spearmanr, shapiro
  7. from statsmodels.stats.multitest import multipletests
  8. from tqdm import tqdm
  9. # Get the directory where the code file is located
  10. code_dir = os.path.dirname(os.path.abspath(__file__))
  11. parent_dir = os.path.dirname(code_dir)
  12. # Load the CSV data
  13. input_file_path = os.path.join(parent_dir, 'output', "Quantitative_outputs", 'Quantitative_results_from_dwi_processing_merged_with_behavior_data.csv')
  14. df = pd.read_csv(input_file_path, low_memory=False)
  15. # Remove duplicate rows
  16. df = df.drop_duplicates()
  17. # Remove specified columns
  18. df = df.drop(columns=['is_it_AMBA', '2 Cluster', '3 Cluster', '4 Cluster', '5 Cluster', '6 Cluster', 'Exclude_Aref', 'Voting outliers (from 5)', 'fullpath'])
  19. # Convert columns that should be numeric to numeric, coercing errors to NaN
  20. numeric_cols = ['Value', 'DeficitScore']
  21. for col in numeric_cols:
  22. if col in df.columns:
  23. df[col] = pd.to_numeric(df[col], errors='coerce')
  24. # Merge data for timepoint 0 with timepoint 28 for DeficitScore change calculation
  25. merged_df_0_28 = pd.merge(df[df['merged_timepoint'] == 0],
  26. df[df['merged_timepoint'] == 28],
  27. on=['subjectID', 'Group', 'Qtype', 'mask_name', 'dialation_amount'],
  28. suffixes=('_0', '_28'))
  29. # Store all p-values for multiple test correction
  30. all_p_values = []
  31. # Define the settings for selected masks, groups, qtype, and timepoints
  32. selected_groups = ["Sham"]
  33. selected_qtype = "fa"
  34. selected_mask_names = [
  35. "CRuT_MOp-RN_ipsilesional_CCcut",
  36. "CRuT_MOp-RN_contralesional_mirrored_CCcut",
  37. "CReT_MOp-TRN_ipsilesional_CCcut",
  38. "CReT_MOp-TRN_contralesional_CCcut",
  39. "CST_MOp-int-py_ipsilesional_selfdrawnROA+CCcut",
  40. "CST_MOp-int-py_contralesional_selfdrawnROA+CCcut",
  41. "TC_DORsm-SSp_ll+SSp_ul_ipsilesional_END+higherstepsize_",
  42. "TC_DORsm-SSp_ll+SSp_ul_contralesional_END+higherstepsize",
  43. "CC_MOp-MOp_cut",
  44. "OT_och-lgn_lgncut"
  45. ]
  46. #selected_timepoints = [28]
  47. selected_dialation_amount = 0
  48. # Define simple anatomical names for the masks
  49. mask_name_mapping = {
  50. "CC_MOp-MOp_cut": "Corpus Callosum",
  51. "CRuT_MOp-RN_ipsilesional_CCcut": "Rubropsinal (Ipsilesional)",
  52. "CRuT_MOp-RN_contralesional_mirrored_CCcut": "Rubropsinal (Contralesional)",
  53. "TC_DORsm-SSp_ll+SSp_ul_ipsilesional_END+higherstepsize_": "Thalamocortical (Ipsilesional)",
  54. "TC_DORsm-SSp_ll+SSp_ul_contralesional_END+higherstepsize": "Thalamocortical (Contralesional)",
  55. "CReT_MOp-TRN_contralesional_CCcut": "Reticulospinal (Contralesional)",
  56. "CReT_MOp-TRN_ipsilesional_CCcut": "Reticulospinal (Ipsilesional)",
  57. "CST_MOp-int-py_contralesional_selfdrawnROA+CCcut": "Corticospinal (Contralesional)",
  58. "CST_MOp-int-py_ipsilesional_selfdrawnROA+CCcut": "Corticospinal (Ipsilesional)",
  59. "OT_och-lgn_lgncut": "Optic"
  60. }
  61. # Create the figure for subplots
  62. fig, axes = plt.subplots(5, 2, figsize=(14 / 2.54, 25 / 2.54)) # 2 columns, 5 rows
  63. axes = axes.flatten() # Flatten the axes array for easy iteration
  64. # Iterate through each mask name
  65. for idx, selected_mask_name in enumerate(selected_mask_names):
  66. # Filter the merged DataFrame for the current settings
  67. df_filtered_28 = merged_df_0_28[(merged_df_0_28["Group"].isin(selected_groups)) &
  68. (merged_df_0_28["Qtype"] == selected_qtype) &
  69. (merged_df_0_28["mask_name"] == selected_mask_name) &
  70. (merged_df_0_28["dialation_amount"] == selected_dialation_amount)]
  71. df_filtered_28 = df_filtered_28.drop_duplicates()
  72. # Ensure there are enough data points for correlation
  73. if len(df_filtered_28) >= 3:
  74. # Calculate the change in DeficitScore between timepoint 0 and 28
  75. fa_value_28 = df_filtered_28['Value_28']
  76. deficit_score_change_28 = df_filtered_28['DeficitScore_28'] - df_filtered_28['DeficitScore_0']
  77. # Remove rows with NaN or infinite values
  78. valid_idx = ~(fa_value_28.isin([np.nan, np.inf, -np.inf]) | deficit_score_change_28.isin([np.nan, np.inf, -np.inf]))
  79. fa_value_28 = fa_value_28[valid_idx]
  80. deficit_score_change_28 = deficit_score_change_28[valid_idx]
  81. # Perform Shapiro-Wilk test for normality
  82. if len(fa_value_28) >= 3 and len(deficit_score_change_28) >= 3:
  83. shapiro_statValue, shapiro_pvalueQValue = shapiro(fa_value_28)
  84. shapiro_statScore, shapiro_pvalueBehavior = shapiro(deficit_score_change_28)
  85. # Use Pearson or Spearman correlation based on normality test
  86. if shapiro_pvalueQValue < 0.05 or shapiro_pvalueBehavior < 0.05:
  87. # Use Spearman correlation if data is not normally distributed
  88. correlation_coefficient, p_value = spearmanr(fa_value_28, deficit_score_change_28)
  89. else:
  90. # Use Pearson correlation if data is normally distributed
  91. correlation_coefficient, p_value = pearsonr(fa_value_28, deficit_score_change_28)
  92. # Store p-value for multiple testing correction
  93. all_p_values.append(p_value)
  94. # Plot the regression
  95. ax = axes[idx] # Select the appropriate subplot axis
  96. sns.regplot(
  97. x=fa_value_28, y=deficit_score_change_28, ci=95, ax=ax,
  98. line_kws={'color': 'lightcoral', 'alpha': 0.6}, # Regression line in light red with higher transparency
  99. scatter_kws={'color': 'red', 'alpha': 0.6, 's': 10}, # Scatter points in red
  100. label=f'Stroke (R={correlation_coefficient:.2f})'
  101. )
  102. # Determine significance level for R value
  103. pval_text = ''
  104. if p_value < 0.001:
  105. pval_text = '***'
  106. elif p_value < 0.01:
  107. pval_text = '**'
  108. elif p_value < 0.05:
  109. pval_text = '*'
  110. # Add text annotation for R and significance level
  111. ax.text(0.05, 0.95, f'R={correlation_coefficient:.2f}{pval_text}', transform=ax.transAxes,
  112. fontsize=10, verticalalignment='top', bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0))
  113. # Set labels and customize font settings
  114. anatomical_name = mask_name_mapping.get(selected_mask_name, selected_mask_name)
  115. ax.set_xlabel(f'{selected_qtype} Value (Timepoint 28) - {anatomical_name}', fontsize=10)
  116. ax.set_ylabel('DS Change [28 - 0]', fontsize=10)
  117. anatomical_name = mask_name_mapping.get(selected_mask_name, selected_mask_name)
  118. ax.set_xlabel(f'{selected_qtype}-{anatomical_name}', fontsize=10)
  119. anatomical_name = mask_name_mapping.get(selected_mask_name, selected_mask_name)
  120. # Remove gridlines
  121. ax.grid(False)
  122. sns.despine(ax=ax) # Remove the top and right spines for cleaner look
  123. else:
  124. print(f"Not enough data for combination: Group=Stroke, Qtype={selected_qtype}, Mask={selected_mask_name}, Dilation={selected_dialation_amount}")
  125. # Adjust layout
  126. plt.tight_layout()
  127. # Save the complete figure with all subplots
  128. figures_folder = os.path.join(parent_dir, 'output', 'Figures')
  129. pythonFigs_folder = os.path.join(figures_folder, 'pythonFigs')
  130. os.makedirs(pythonFigs_folder, exist_ok=True)
  131. file_suffix = f'timepoint_{selected_timepoints[0]}_dilation_{selected_dialation_amount}_groups_{"-".join(selected_groups)}'
  132. fig_filename = f'qtype_{selected_qtype}_behavior_change_correlation_all_masks_{file_suffix}_groups_comparison.svg'
  133. plt.savefig(os.path.join(pythonFigs_folder, fig_filename), dpi=300, bbox_inches='tight', format="svg", transparent=True)
  134. # Show the figure
  135. plt.show()