"""Plot heatmaps of DTI-behaviour correlation results.

Reads ``output/Correlation_with_behavior/correlation_dti_with_behavior.csv``
(located under the parent of this script's directory) and keeps the rows for
one dilation amount. For each selected ``Qtype`` it draws a grid with one row
per ``Group`` and two columns:

* the ``R`` values;
* the ``Pval`` values.

Each panel is pivoted to tract (``mask_name``) x time point
(``merged_timepoint``).
"""
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Dilation amount kept from the ``dialation_amount`` column. The column name's
# spelling comes from the CSV and must match it exactly.
DILATION_AMOUNT = 1

# Qtypes to plot. Set to None to plot every Qtype present in the CSV.
QTYPES = ["fa"]

# When True, each figure is also written next to the CSV at 300 dpi.
# Off by default: previously the figure was only shown.
SAVE_FIGURES = False

# Heatmap settings per plotted column, in left-to-right panel order.
_VALUE_STYLES = {
    # Correlation coefficients lie in [-1, 1]. Fixing the limits keeps 0 at the
    # neutral colour and makes colours comparable across groups.
    "R": {"cmap": "coolwarm", "vmin": -1, "vmax": 1},
    # P-values lie in [0, 1]. The reversed map draws small p-values in red.
    "Pval": {"cmap": "coolwarm_r", "vmin": 0, "vmax": 1},
}


def _default_csv_path():
    """Return the correlation CSV path under the parent of this script's dir."""
    code_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(code_dir)
    return os.path.join(
        parent_dir,
        'output',
        "Correlation_with_behavior",
        'correlation_dti_with_behavior.csv',
    )


def load_correlations(csv_file, dilation_amount=DILATION_AMOUNT):
    """Read ``csv_file`` and keep only rows whose dilation equals ``dilation_amount``."""
    df = pd.read_csv(csv_file)
    return df[df["dialation_amount"] == dilation_amount]


def plot_qtype_heatmaps(df, qtype, save_path=None):
    """Draw R and p-value heatmaps for every Group of one Qtype.

    Args:
        df: Correlation rows. Must contain the columns ``Qtype``, ``Group``,
            ``mask_name`` and ``merged_timepoint``, plus the keys of
            ``_VALUE_STYLES``.
        qtype: Value of the ``Qtype`` column to plot.
        save_path: If given, the figure is also saved there at 300 dpi.

    Returns:
        The created figure, or None when ``df`` has no rows for ``qtype``.

    Raises:
        ValueError: From ``DataFrame.pivot`` if a group has duplicate
            (mask_name, merged_timepoint) pairs.
    """
    filtered_df = df[df["Qtype"] == qtype]
    groups = filtered_df["Group"].unique()
    num_groups = len(groups)
    if num_groups == 0:
        # plt.subplots(0, 2) would raise, so skip this Qtype instead.
        print(f"No rows for Qtype {qtype!r}; skipping.")
        return None

    # squeeze=False keeps ``axes`` 2-D even for a single group, so the
    # ``axes[i, j]`` indexing below works for any number of rows.
    fig, axes = plt.subplots(
        num_groups, 2, figsize=(16, num_groups * 6), squeeze=False
    )

    for i, group in enumerate(groups):
        group_df = filtered_df[filtered_df["Group"] == group]
        for j, (value, style) in enumerate(_VALUE_STYLES.items()):
            ax = axes[i, j]
            heatmap_df = group_df.pivot(
                index='mask_name', columns='merged_timepoint', values=value
            )
            sns.heatmap(
                heatmap_df, ax=ax, annot=True, fmt=".2f", cbar=True, **style
            )
            ax.set_title(f'Group: {group}, Qtype: {qtype}, Value: {value}')
            ax.set_xlabel('Time Point', fontsize=8)
            ax.set_ylabel('Tract Name', fontsize=8)
            ax.tick_params(axis='both', which='major', labelsize=8)

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=300)
    return fig


def main():
    """Load the correlation table and show one heatmap figure per Qtype."""
    csv_file = _default_csv_path()
    output_dir = os.path.dirname(csv_file)
    df = load_correlations(csv_file)

    qtypes = df["Qtype"].unique() if QTYPES is None else QTYPES
    for qtype in qtypes:
        save_path = None
        if SAVE_FIGURES:
            # Derive the name from the dilation actually plotted, so the
            # filename cannot disagree with the data.
            save_path = os.path.join(
                output_dir,
                f'heatmap_{qtype}_{DILATION_AMOUNT}dialation_corrected_spearman.png',
            )
        fig = plot_qtype_heatmaps(df, qtype, save_path=save_path)
        if fig is None:
            continue
        plt.show()
        # Release the figure so repeated iterations do not accumulate them.
        plt.close(fig)


if __name__ == "__main__":
    main()