vizualization_fa28subtractedbyBaseline_behavior28subtractedbyBaseline.py

# -*- coding: utf-8 -*-
"""
Created on Tue Nov 5 17:28:49 2024
@author: arefks
"""
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr, shapiro
from statsmodels.stats.multitest import multipletests
from tqdm import tqdm
# Get the directory where the code file is located
code_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(code_dir)

# Load the CSV data
input_file_path = os.path.join(parent_dir, 'output', "Quantitative_outputs", 'Quantitative_results_from_dwi_processing_merged_with_behavior_data.csv')
df = pd.read_csv(input_file_path, low_memory=False)

# Remove duplicate rows
df = df.drop_duplicates()

# Remove specified columns
df = df.drop(columns=['is_it_AMBA', '2 Cluster', '3 Cluster', '4 Cluster', '5 Cluster', '6 Cluster', 'Exclude_Aref', 'Voting outliers (from 5)', 'fullpath'])

# Convert columns that should be numeric to numeric, coercing errors to NaN
numeric_cols = ['Value', 'DeficitScore']
for col in numeric_cols:
    if col in df.columns:
        df[col] = pd.to_numeric(df[col], errors='coerce')

# Merge data for timepoint 0 with timepoint 28 for the DeficitScore change calculation
merged_df_0_28 = pd.merge(df[df['merged_timepoint'] == 0],
                          df[df['merged_timepoint'] == 28],
                          on=['subjectID', 'Group', 'Qtype', 'mask_name', 'dialation_amount'],
                          suffixes=('_0', '_28'))
# Store all p-values for multiple test correction
all_p_values = []

# Define the settings for selected masks, groups, qtype, and timepoints
selected_groups = ["Sham"]
selected_qtype = "fa"
selected_mask_names = [
    "CRuT_MOp-RN_ipsilesional_CCcut",
    "CRuT_MOp-RN_contralesional_mirrored_CCcut",
    "CReT_MOp-TRN_ipsilesional_CCcut",
    "CReT_MOp-TRN_contralesional_CCcut",
    "CST_MOp-int-py_ipsilesional_selfdrawnROA+CCcut",
    "CST_MOp-int-py_contralesional_selfdrawnROA+CCcut",
    "TC_DORsm-SSp_ll+SSp_ul_ipsilesional_END+higherstepsize_",
    "TC_DORsm-SSp_ll+SSp_ul_contralesional_END+higherstepsize",
    "CC_MOp-MOp_cut",
    "OT_och-lgn_lgncut"
]
selected_dialation_amount = 0

# Define simple anatomical names for the masks
mask_name_mapping = {
    "CC_MOp-MOp_cut": "Corpus Callosum",
    "CRuT_MOp-RN_ipsilesional_CCcut": "Rubrospinal (Ipsilesional)",
    "CRuT_MOp-RN_contralesional_mirrored_CCcut": "Rubrospinal (Contralesional)",
    "TC_DORsm-SSp_ll+SSp_ul_ipsilesional_END+higherstepsize_": "Thalamocortical (Ipsilesional)",
    "TC_DORsm-SSp_ll+SSp_ul_contralesional_END+higherstepsize": "Thalamocortical (Contralesional)",
    "CReT_MOp-TRN_contralesional_CCcut": "Reticulospinal (Contralesional)",
    "CReT_MOp-TRN_ipsilesional_CCcut": "Reticulospinal (Ipsilesional)",
    "CST_MOp-int-py_contralesional_selfdrawnROA+CCcut": "Corticospinal (Contralesional)",
    "CST_MOp-int-py_ipsilesional_selfdrawnROA+CCcut": "Corticospinal (Ipsilesional)",
    "OT_och-lgn_lgncut": "Optic"
}
# Create the figure for subplots
fig, axes = plt.subplots(5, 2, figsize=(14 / 2.54, 25 / 2.54))  # 5 rows, 2 columns (size given in cm, converted to inches)
axes = axes.flatten()  # Flatten the axes array for easy iteration

# Iterate through each mask name
for idx, selected_mask_name in enumerate(selected_mask_names):
    # Filter the merged DataFrame for the current settings
    df_filtered_28 = merged_df_0_28[(merged_df_0_28["Group"].isin(selected_groups)) &
                                    (merged_df_0_28["Qtype"] == selected_qtype) &
                                    (merged_df_0_28["mask_name"] == selected_mask_name) &
                                    (merged_df_0_28["dialation_amount"] == selected_dialation_amount)]
    df_filtered_28 = df_filtered_28.drop_duplicates()

    # Ensure there are enough data points for correlation
    if len(df_filtered_28) >= 3:
        # Calculate the change in DeficitScore and Qtype (Value) between timepoint 0 and timepoint 28
        fa_value_change_28 = df_filtered_28['Value_28'] - df_filtered_28['Value_0']
        deficit_score_change_28 = df_filtered_28['DeficitScore_28'] - df_filtered_28['DeficitScore_0']

        # Remove rows with NaN or infinite values
        valid_idx = ~(fa_value_change_28.isin([np.nan, np.inf, -np.inf]) | deficit_score_change_28.isin([np.nan, np.inf, -np.inf]))
        fa_value_change_28 = fa_value_change_28[valid_idx]
        deficit_score_change_28 = deficit_score_change_28[valid_idx]

        # Perform Shapiro-Wilk test for normality
        if len(fa_value_change_28) >= 3 and len(deficit_score_change_28) >= 3:
            shapiro_statValue, shapiro_pvalueQValue = shapiro(fa_value_change_28)
            shapiro_statScore, shapiro_pvalueBehavior = shapiro(deficit_score_change_28)

            # Use Pearson or Spearman correlation based on the normality test
            if shapiro_pvalueQValue < 0.05 or shapiro_pvalueBehavior < 0.05:
                # Use Spearman correlation if the data are not normally distributed
                correlation_coefficient, p_value = spearmanr(fa_value_change_28, deficit_score_change_28)
            else:
                # Use Pearson correlation if the data are normally distributed
                correlation_coefficient, p_value = pearsonr(fa_value_change_28, deficit_score_change_28)

            # Store p-value for multiple testing correction
            all_p_values.append(p_value)

            # Plot the regression
            ax = axes[idx]  # Select the appropriate subplot axis
            sns.regplot(
                x=fa_value_change_28, y=deficit_score_change_28, ci=95, ax=ax,
                line_kws={'color': 'lightcoral', 'alpha': 0.6},  # Regression line in light red with higher transparency
                scatter_kws={'color': 'red', 'alpha': 0.6, 's': 10},  # Scatter points in red
                label=f'{"-".join(selected_groups)} (R={correlation_coefficient:.2f})'
            )

            # Determine significance level for the R value
            pval_text = ''
            if p_value < 0.001:
                pval_text = '***'
            elif p_value < 0.01:
                pval_text = '**'
            elif p_value < 0.05:
                pval_text = '*'

            # Add text annotation for R and significance level
            ax.text(0.05, 0.95, f'R={correlation_coefficient:.2f}{pval_text}', transform=ax.transAxes,
                    fontsize=10, verticalalignment='top',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0))

            # Set labels and customize font settings
            anatomical_name = mask_name_mapping.get(selected_mask_name, selected_mask_name)
            ax.set_xlabel(f'Δ{selected_qtype}-{anatomical_name}', fontsize=10)
            ax.set_ylabel('ΔDS[28-0]', fontsize=10)

            # Remove gridlines and the top/right spines for a cleaner look
            ax.grid(False)
            sns.despine(ax=ax)
    else:
        print(f"Not enough data for combination: Group={'-'.join(selected_groups)}, Qtype={selected_qtype}, Mask={selected_mask_name}, Dilation={selected_dialation_amount}")
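
# Optional sketch: multipletests is imported above but never applied to the collected p-values.
# One way to correct across the tested tracts is Benjamini-Hochberg FDR; the method ('fdr_bh')
# and alpha (0.05) below are assumptions, not settings taken from this script.
if all_p_values:
    reject, p_values_corrected, _, _ = multipletests(all_p_values, alpha=0.05, method='fdr_bh')
    print("Raw p-values:          ", np.round(all_p_values, 4))
    print("FDR-corrected p-values:", np.round(p_values_corrected, 4))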
# Adjust layout
plt.tight_layout()

# Save the complete figure with all subplots
figures_folder = os.path.join(parent_dir, 'output', 'Figures')
pythonFigs_folder = os.path.join(figures_folder, 'pythonFigs')
os.makedirs(pythonFigs_folder, exist_ok=True)
file_suffix = f'timepoint_28-0_dilation_{selected_dialation_amount}_groups_{"-".join(selected_groups)}'
fig_filename = f'qtype_{selected_qtype}_behavior_change_correlation_all_masks_{file_suffix}_groups_comparison.svg'
plt.savefig(os.path.join(pythonFigs_folder, fig_filename), dpi=300, bbox_inches='tight', format="svg", transparent=True)

# Show the figure
plt.show()