visualization_predictive_correlation.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Oct 17 14:51:34 2024
  4. @author: arefks
  5. """
  6. import os
  7. import pandas as pd
  8. import numpy as np
  9. import seaborn as sns
  10. import matplotlib.pyplot as plt
  # ---------------------------------------------------------------------------
# Visualization settings
# ---------------------------------------------------------------------------
# Candidate region names.
# NOTE(review): this list is not referenced anywhere else in this script; the
# plotted data is selected by `selected_mask_name` below.
selected_regions = [
    "R_SSp-tr_2361_AMBA",
    "CST_MOp-int-py_ipsilesional_selfdrawnROA+CCcut",
    "CReT_MOp-TRN_ipsilesional_CCcut",
    "L_STRd_485_AMBA",
    "L_FRP_184_AMBA",
    "L_SSp-n_353_AMBA",
    "OT_och-lgn_lgncut",
]

# Values kept from the 'Group' column; groups are drawn in this order.
selected_groups = ["Stroke", "Sham"]

# Value matched against the 'Qtype' column.
selected_qtype = "fa"

# Value matched against the 'mask_name' column.
selected_mask_name = "CC_MOs-MOs"

# Value matched against the 'dialation_amount' column (the spelling follows
# that column's name).
selected_dialation_amount = 2
  # Resolve paths relative to this script: the project root is taken to be the
# parent of the directory that contains this file.
code_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(code_dir)

# Input: the quantitative DWI-processing results merged with behavior data.
input_file_path = os.path.join(
    parent_dir,
    'output',
    "Quantitative_outputs",
    'Quantitative_results_from_dwi_processing_merged_with_behavior_data.csv',
)
# low_memory=False parses the file in one pass so each column gets a single
# inferred dtype (avoids pandas' mixed-type DtypeWarning).
df = pd.read_csv(input_file_path, low_memory=False)

# Drop every row in which any cell, rendered as text, contains 'NoBehavior'.
# DataFrame.apply runs column by column here; .any(axis=1) then reduces the
# resulting boolean frame to one flag per row.
has_no_behavior = df.apply(
    lambda col: col.astype(str).str.contains('NoBehavior', regex=False)
).any(axis=1)
# .copy() makes the filtered frame independent, so adding the
# 'chronic_timepoint' column below does not write into a slice of the
# unfiltered data (SettingWithCopyWarning).
df = df[~has_no_behavior].copy()

# Flag each subject's chronic rows: the rows at that subject's latest
# 'merged_timepoint', provided that timepoint is greater than
# CHRONIC_MIN_TIMEPOINT. Subjects whose latest timepoint is at or below the
# threshold get no chronic rows. Computed for all subjects at once with
# groupby().transform('max') instead of re-filtering the frame per subject.
CHRONIC_MIN_TIMEPOINT = 27
subject_max_tp = df.groupby('subjectID')['merged_timepoint'].transform('max')
df['chronic_timepoint'] = (
    (df['merged_timepoint'] == subject_max_tp)
    & (subject_max_tp > CHRONIC_MIN_TIMEPOINT)
)
  #% Visualization
# Columns that must agree for two rows from different timepoints to be paired
# as the same subject / measurement.
PAIR_KEYS = ['subjectID', 'Group', 'Qtype', 'mask_name', 'dialation_amount']

# Pair each timepoint-3 row with the matching timepoint-7 row. Overlapping
# non-key columns get '_acute' (timepoint 3) and '_7' (timepoint 7) suffixes.
merged_df_3_7 = pd.merge(
    df[df['merged_timepoint'] == 3],
    df[df['merged_timepoint'] == 7],
    on=PAIR_KEYS,
    suffixes=('_acute', '_7'),
)

# Pair each timepoint-3 row with the matching row flagged in
# 'chronic_timepoint'. Overlapping non-key columns get '_acute' / '_chronic'.
merged_df_3_chronic = pd.merge(
    df[df['merged_timepoint'] == 3],
    df[df['chronic_timepoint']],
    on=PAIR_KEYS,
    suffixes=('_acute', '_chronic'),
)


def _select_for_visualization(merged):
    """Return the rows of a paired frame that match the configured selection.

    Keeps rows whose 'Group' is in ``selected_groups`` and whose 'Qtype',
    'mask_name' and 'dialation_amount' equal ``selected_qtype``,
    ``selected_mask_name`` and ``selected_dialation_amount`` respectively.
    """
    keep = (
        merged['Group'].isin(selected_groups)
        & (merged['Qtype'] == selected_qtype)
        & (merged['mask_name'] == selected_mask_name)
        & (merged['dialation_amount'] == selected_dialation_amount)
    )
    return merged[keep]


df_filtered_3_7_vis = _select_for_visualization(merged_df_3_7)
df_filtered_3_chronic_vis = _select_for_visualization(merged_df_3_chronic)
  # Build the per-subject table of early imaging change vs. later behavior
# change, but only when both filtered inputs contain rows.
merged_filtered_vis = None
if not df_filtered_3_7_vis.empty and not df_filtered_3_chronic_vis.empty:
    # Match the two selections by subject; 'Group' is taken from the
    # timepoint-3/7 pairing.
    merged_filtered_vis = pd.merge(
        df_filtered_3_7_vis[['subjectID', 'Value_acute', 'Value_7', 'Group']],
        df_filtered_3_chronic_vis[['subjectID', 'DeficitScore_acute', 'DeficitScore_chronic']],
        on='subjectID',
    )
    # Imaging change from timepoint 3 to 7, and deficit-score change from
    # timepoint 3 to the chronic timepoint.
    merged_filtered_vis['fa_change'] = (
        merged_filtered_vis['Value_7'] - merged_filtered_vis['Value_acute']
    )
    merged_filtered_vis['behavior_change'] = (
        merged_filtered_vis['DeficitScore_chronic'] - merged_filtered_vis['DeficitScore_acute']
    )
    # Treat +/-inf as missing, then drop every row with any missing value.
    merged_filtered_vis = merged_filtered_vis.replace([np.inf, -np.inf], np.nan).dropna()

# The inner merge and dropna can leave nothing to plot even when both inputs
# had rows; report that instead of drawing and saving an empty figure.
if merged_filtered_vis is None or merged_filtered_vis.empty:
    print("No data available for the specified visualization settings.")
else:
    # Figure size is 9 cm x 7.5 cm, converted to inches.
    fig, ax = plt.subplots(figsize=(9 / 2.54, 7.5 / 2.54))
    colors = {'Stroke': 'red', 'Sham': 'gray'}
    for group in selected_groups:
        group_data = merged_filtered_vis[merged_filtered_vis['Group'] == group]
        if group_data.empty:
            # Nothing to scatter or fit for this group; skipping also keeps
            # it out of the legend.
            continue
        # A group with no entry in `colors` gets seaborn's default colour
        # instead of raising KeyError.
        sns.regplot(
            x='fa_change',
            y='behavior_change',
            data=group_data,
            ci=None,
            label=group,
            color=colors.get(group),
            scatter_kws={'alpha': 0.6},
            ax=ax,
        )
    ax.set_xlabel(
        f'Acute change in {selected_qtype} value (day 7 - day 3)[a.u]',
        fontsize=10, fontname='Calibri',
    )
    # NOTE(review): the chronic timepoint is each subject's latest timepoint
    # above 27, which need not be exactly day 28 — confirm this label.
    ax.set_ylabel(
        'Change in Deficit Score (day 28 - day 3)[a.u]',
        fontsize=10, fontname='Calibri',
    )
    ax.set_title(f'{selected_mask_name}', fontsize=10, fontweight='bold', fontname='Calibri')
    ax.grid(False)
    # Remove the top and right plot borders.
    sns.despine(ax=ax)
    # Legend distinguishing the groups: 9 pt font, no frame.
    ax.legend(title='Group', fontsize=9, title_fontsize=9, frameon=False)

    # Save under <project root>/output/Figures/pythonFigs, creating it if needed.
    figures_folder = os.path.join(parent_dir, 'output', 'Figures')
    pythonFigs_folder = os.path.join(figures_folder, 'pythonFigs')
    os.makedirs(pythonFigs_folder, exist_ok=True)
    # The file name records the quantity type and mask that were plotted.
    fig_filename = f'predictive_{selected_qtype}_behavior_correlation_{selected_mask_name}.svg'
    fig.savefig(
        os.path.join(pythonFigs_folder, fig_filename),
        dpi=300, bbox_inches='tight', format="svg",
    )
    plt.show()