# heat_maps_correlation.py

import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
  5. # Set up the directory paths
  6. code_dir = os.path.dirname(os.path.abspath(__file__))
  7. parent_dir = os.path.dirname(code_dir)
  8. csv_file = os.path.join(parent_dir, 'output', "Correlation_with_behavior", 'correlation_dti_with_behavior.csv')
  9. output_dir = os.path.dirname(csv_file)
  10. # Read the CSV file into a DataFrame
  11. df = pd.read_csv(csv_file)
  12. df = df[df["dialation_amount"] == 3]
  13. # Get unique Qtypes
  14. qtypes = df["Qtype"].unique()
  15. # Iterate over unique values of 'Qtype'
  16. for qtype in qtypes:
  17. # Filter the DataFrame for the current 'Qtype'
  18. filtered_df = df[df["Qtype"] == qtype]
  19. # Get unique groups within the filtered DataFrame
  20. groups = filtered_df["Group"].unique()
  21. # Set up figure for the current Qtype
  22. num_groups = len(groups)
  23. fig, axes = plt.subplots(num_groups, 2, figsize=(16, num_groups * 6))
  24. # Iterate over unique values of 'Group' within the current Qtype
  25. for i, group in enumerate(groups):
  26. # Filter the DataFrame for the current 'Group'
  27. group_df = filtered_df[filtered_df["Group"] == group]
  28. # Pivot the filtered DataFrame to prepare for heatmap
  29. for j, value in enumerate(['R', 'Pval']):
  30. heatmap_df = group_df.pivot(index='mask_name', columns='merged_timepoint', values=value)
  31. # Plot heatmap
  32. sns.heatmap(heatmap_df, ax=axes[i, j], cmap='coolwarm_r', annot=True, fmt=".2f", cbar=True)
  33. axes[i, j].set_title(f'Group: {group}, Qtype: {qtype}, Value: {value}')
  34. axes[i, j].set_xlabel('Time Point', fontsize=8)
  35. axes[i, j].set_ylabel('Tract Name', fontsize=8)
  36. axes[i, j].tick_params(axis='both', which='major', labelsize=8)
  37. plt.tight_layout()
  38. save_path = os.path.join(output_dir, f'heatmap_{qtype}_2dialation_corrected_spearman.png')
  39. #plt.savefig(save_path, dpi=300)
  40. plt.show()