# heat_maps_correlation.py (1.9 KB)
  1. import os
  2. import pandas as pd
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. # Set up the directory paths
  6. code_dir = os.path.dirname(os.path.abspath(__file__))
  7. parent_dir = os.path.dirname(code_dir)
  8. csv_file = os.path.join(parent_dir, 'output', "Correlation_with_behavior", 'correlation_dti_with_behavior5.csv')
  9. output_dir = os.path.dirname(csv_file)
  10. # Read the CSV file into a DataFrame
  11. df = pd.read_csv(csv_file)
  12. # Get unique Qtypes
  13. qtypes = df["Qtype"].unique()
  14. # Iterate over unique values of 'Qtype'
  15. for qtype in qtypes:
  16. # Filter the DataFrame for the current 'Qtype'
  17. filtered_df = df[df["Qtype"] == qtype]
  18. # Get unique groups within the filtered DataFrame
  19. groups = filtered_df["Group"].unique()
  20. # Set up figure for the current Qtype
  21. num_groups = len(groups)
  22. fig, axes = plt.subplots(num_groups, 2, figsize=(16, num_groups * 6))
  23. # Iterate over unique values of 'Group' within the current Qtype
  24. for i, group in enumerate(groups):
  25. # Filter the DataFrame for the current 'Group'
  26. group_df = filtered_df[filtered_df["Group"] == group]
  27. # Pivot the filtered DataFrame to prepare for heatmap
  28. for j, value in enumerate(['R_corr', 'Pval_corr']):
  29. heatmap_df = group_df.pivot(index='mask_name', columns='merged_timepoint', values=value)
  30. # Plot heatmap
  31. sns.heatmap(heatmap_df, ax=axes[i, j], cmap='coolwarm', annot=True, fmt=".2f", cbar=True)
  32. axes[i, j].set_title(f'Group: {group}, Qtype: {qtype}, Value: {value}')
  33. axes[i, j].set_xlabel('Time Point', fontsize=8)
  34. axes[i, j].set_ylabel('Tract Name', fontsize=8)
  35. axes[i, j].tick_params(axis='both', which='major', labelsize=8)
  36. plt.tight_layout()
  37. save_path = os.path.join(output_dir, f'heatmap_{qtype}_2dialation_corrected_spearman.png')
  38. plt.savefig(save_path, dpi=300)
  39. plt.show()