# heat_maps_correlation.py
  1. import os
  2. import pandas as pd
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. # Set up the directory paths
  6. code_dir = os.path.dirname(os.path.abspath(__file__))
  7. parent_dir = os.path.dirname(code_dir)
  8. csv_file = os.path.join(parent_dir, 'output', "Correlation_with_behavior", 'correlation_dti_with_behavior.csv')
  9. output_dir = os.path.dirname(code_dir)
  10. # Read the CSV file into a DataFrame
  11. df = pd.read_csv(csv_file)
  12. df = df[df["dialation_amount"] == 1]
  13. # Get unique Qtypes
  14. qtypes = df["Qtype"].unique()
  15. qtypes = ["fa"]
  16. # Iterate over unique values of 'Qtype'
  17. for qtype in qtypes:
  18. # Filter the DataFrame for the current 'Qtype'
  19. filtered_df = df[df["Qtype"] == qtype]
  20. # Get unique groups within the filtered DataFrame
  21. groups = filtered_df["Group"].unique()
  22. # Set up figure for the current Qtype
  23. num_groups = len(groups)
  24. fig, axes = plt.subplots(num_groups, 2, figsize=(40, num_groups * 15))
  25. # Iterate over unique values of 'Group' within the current Qtype
  26. for i, group in enumerate(groups):
  27. # Filter the DataFrame for the current 'Group'
  28. group_df = filtered_df[filtered_df["Group"] == group]
  29. # Pivot the filtered DataFrame to prepare for heatmap
  30. for j, value in enumerate(['R', 'Pval']):
  31. heatmap_df = group_df.pivot(index='mask_name', columns='merged_timepoint', values=value)
  32. # Plot heatmap
  33. if value == "R":
  34. sns.heatmap(heatmap_df, ax=axes[i, j], cmap='coolwarm', annot=True, fmt=".2f", cbar=True)
  35. else:
  36. sns.heatmap(heatmap_df, ax=axes[i, j], cmap='coolwarm_r', annot=True, fmt=".2f", cbar=True)
  37. axes[i, j].set_title(f'Group: {group}, Qtype: {qtype}, Value: {value}')
  38. axes[i, j].set_xlabel('Time Point', fontsize=8)
  39. axes[i, j].set_ylabel('Tract Name', fontsize=8)
  40. axes[i, j].tick_params(axis='both', which='major', labelsize=8)
  41. plt.tight_layout()
  42. save_path = os.path.join(output_dir,"output","Figures","pythonFigs", f'heatmap_{qtype}_2dialation_corrected_spearman.svg')
  43. plt.savefig(save_path, dpi=300,format="svg")
  44. plt.show()