123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 |
- # -*- coding: utf-8 -*-
- """
- Created on Thu Oct 17 14:51:34 2024
- @author: arefks
- """
- import os
- import pandas as pd
- import numpy as np
- import seaborn as sns
- import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Visualization settings
# ---------------------------------------------------------------------------
# Region / tract labels of interest. NOTE(review): this list is not referenced
# anywhere else in the visible part of the script.
selected_regions = [
    "R_SSp-tr_2361_AMBA",
    "CST_MOp-int-py_ipsilesional_selfdrawnROA+CCcut",
    "CReT_MOp-TRN_ipsilesional_CCcut",
    "L_STRd_485_AMBA",
    "L_FRP_184_AMBA",
    "L_SSp-n_353_AMBA",
    "OT_och-lgn_lgncut",
]
# Values of the 'Group' column that are kept and plotted.
selected_groups = ["Stroke", "Sham"]
# Value of the 'Qtype' column that is analysed.
selected_qtype = "fa"
# Value of the 'mask_name' column that is analysed.
selected_mask_name = "CC_MOs-MOs"
# Value of the 'dialation_amount' column that is analysed (the spelling follows
# the column name used in the input table).
selected_dialation_amount = 2
# Resolve folders relative to this script: the folder containing the script
# and the folder one level above it, under which the 'output' tree lives.
script_path = os.path.abspath(__file__)
code_dir = os.path.dirname(script_path)
parent_dir = os.path.dirname(code_dir)

# Path of the INPUT table read below (nothing is written at this point).
input_file_path = os.path.join(
    parent_dir,
    'output',
    "Quantitative_outputs",
    'Quantitative_results_from_dwi_processing_merged_with_behavior_data.csv',
)

# low_memory=False lets pandas infer each column's dtype from the whole file,
# which avoids the mixed-type warning.
df = pd.read_csv(input_file_path, low_memory=False)
# Drop every row in which any column contains the marker text 'NoBehavior'.
# DataFrame.apply passes one COLUMN at a time; each column is cast to str so
# numeric and missing cells can be searched as text. regex=False because the
# marker is a plain literal, not a pattern.
has_no_behavior = df.apply(
    lambda column: column.astype(str).str.contains('NoBehavior', regex=False)
).any(axis=1)

# .copy() makes the result an independent frame, so adding columns to it
# afterwards does not trigger pandas' SettingWithCopyWarning.
df = df.loc[~has_no_behavior].copy()
# Flag, per subject, the row(s) at that subject's latest timepoint as the
# "chronic" measurement, but only when that latest timepoint lies beyond the
# threshold below (timepoints appear to be days, judging by the axis labels
# used for the figure; confirm).
CHRONIC_MIN_TIMEPOINT = 27  # a subject's last timepoint must exceed this

# Latest timepoint of each subject, mapped back onto every row of that
# subject. Rows without a subjectID map to NaN and thus compare False below.
max_timepoint_by_subject = df.groupby('subjectID')['merged_timepoint'].max()
subject_max_timepoint = df['subjectID'].map(max_timepoint_by_subject)

# True only where the row sits at its subject's latest timepoint AND that
# timepoint is late enough; every other row is False.
df['chronic_timepoint'] = (
    (subject_max_timepoint > CHRONIC_MIN_TIMEPOINT)
    & (df['merged_timepoint'] == subject_max_timepoint)
)
# %% Visualization
# Columns that identify one measurement series; rows of different timepoints
# are paired on these.
merge_keys = ['subjectID', 'Group', 'Qtype', 'mask_name', 'dialation_amount']

# Restrict the table to the configured groups / metric / mask / dilation level
# BEFORE pairing timepoints. All four filter columns are merge keys, so this
# yields the same rows as filtering the merged result, while the merges only
# have to handle the selected subset instead of the whole table.
df_selected = df[
    df['Group'].isin(selected_groups)
    & (df['Qtype'] == selected_qtype)
    & (df['mask_name'] == selected_mask_name)
    & (df['dialation_amount'] == selected_dialation_amount)
]

# Rows at timepoint 3 form the "acute" side of both pairings.
acute_rows = df_selected[df_selected['merged_timepoint'] == 3]

# Timepoint 3 paired with timepoint 7; non-key columns that exist on both
# sides are suffixed '_acute' and '_7'.
df_filtered_3_7_vis = pd.merge(
    acute_rows,
    df_selected[df_selected['merged_timepoint'] == 7],
    on=merge_keys,
    suffixes=('_acute', '_7'),
)

# Timepoint 3 paired with the rows flagged as chronic; non-key columns that
# exist on both sides are suffixed '_acute' and '_chronic'.
df_filtered_3_chronic_vis = pd.merge(
    acute_rows,
    df_selected[df_selected['chronic_timepoint']],
    on=merge_keys,
    suffixes=('_acute', '_chronic'),
)
# Scatter + regression plot of the early change in the imaging value
# (timepoint 7 minus timepoint 3) against the change in deficit score
# (chronic timepoint minus timepoint 3), one regression per group. The figure
# is saved as SVG under <parent_dir>/output/Figures/pythonFigs and then shown.
no_data_message = "No data available for the specified visualization settings."

if df_filtered_3_7_vis.empty or df_filtered_3_chronic_vis.empty:
    print(no_data_message)
else:
    # Bring imaging values and deficit scores together via subjectID.
    merged_filtered_vis = pd.merge(
        df_filtered_3_7_vis[['subjectID', 'Value_acute', 'Value_7', 'Group']],
        df_filtered_3_chronic_vis[['subjectID', 'DeficitScore_acute', 'DeficitScore_chronic']],
        on='subjectID',
    )

    # Change in the imaging value and in the behavior deficit score.
    merged_filtered_vis['fa_change'] = (
        merged_filtered_vis['Value_7'] - merged_filtered_vis['Value_acute']
    )
    merged_filtered_vis['behavior_change'] = (
        merged_filtered_vis['DeficitScore_chronic'] - merged_filtered_vis['DeficitScore_acute']
    )

    # Treat +/-inf like missing values and drop every incomplete row.
    merged_filtered_vis = merged_filtered_vis.replace([np.inf, -np.inf], np.nan).dropna()

    if merged_filtered_vis.empty:
        # Every row was incomplete; there is nothing to fit or draw.
        print(no_data_message)
    else:
        plt.figure(figsize=(9/2.54, 7.5/2.54))  # 9 cm x 7.5 cm, converted to inches

        colors = {'Stroke': 'red', 'Sham': 'gray'}
        for group in selected_groups:
            group_data = merged_filtered_vis[merged_filtered_vis['Group'] == group]
            if group_data.empty:
                # regplot needs at least one observation; skip groups that
                # have no complete rows instead of failing.
                print(f"No complete data for group '{group}'; skipping it in the plot.")
                continue
            # colors.get() returns None for a group without a configured
            # color, in which case seaborn picks the next default color.
            sns.regplot(x='fa_change', y='behavior_change', data=group_data, ci=None,
                        label=group, color=colors.get(group), scatter_kws={'alpha': 0.6})

        # The metric name in the label follows selected_qtype so the text
        # stays correct when a different metric is configured.
        plt.xlabel(f'Acute change in {selected_qtype} value (day 7 - day 3)[a.u]', fontsize=10, fontname='Calibri')
        plt.ylabel('Change in Deficit Score (day 28 - day 3)[a.u]', fontsize=10, fontname='Calibri')
        plt.title(f'{selected_mask_name}',
                  fontsize=10, fontweight='bold', fontname='Calibri')

        # No gridlines, no top/right borders.
        plt.grid(False)
        sns.despine()

        # Legend distinguishing the groups: 9 pt font, no frame.
        legend = plt.legend(title='Group', fontsize=9, title_fontsize=9, frameon=False)

        # Output folder; created on demand.
        figures_folder = os.path.join(parent_dir, 'output', 'Figures')
        pythonFigs_folder = os.path.join(figures_folder, 'pythonFigs')
        os.makedirs(pythonFigs_folder, exist_ok=True)

        # File name carries the metric and the mask name.
        fig_filename = f'predictive_{selected_qtype}_behavior_correlation_{selected_mask_name}.svg'
        plt.savefig(os.path.join(pythonFigs_folder, fig_filename), dpi=300, bbox_inches='tight', format="svg")

        plt.show()
|