# Per-dataset scatter plots of SNR-Chang vs SNR-Standard with Spearman correlations (3.5 KB listing).

import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
  7. cm = 1/2.54 # centimeters in inches
  8. # Specify the path to your Excel file
  9. # Read data from the CSV file
  10. script_dir = os.path.dirname(__file__)
  11. excel_file_path = os.path.join(script_dir, '..', 'input', 'combined_data_anat.csv')
  12. out_path = os.path.join(script_dir, '..', 'figures')
  13. # Read the data into a pandas DataFrame
  14. df = pd.read_csv(excel_file_path)
  15. # Set seaborn style
  16. sns.set_style('ticks')
  17. # Create a list to store correlation and p-value for each dataset
  18. correlation_list = []
  19. p_value_list = []
  20. plt.figure(figsize=(10*cm,10*cm),dpi=300)
  21. # Get unique values in the 'dataset' column
  22. def extract_number(dataset):
  23. # Extract numeric part from the 'dataset' column
  24. # Assuming the numeric part is always at the beginning of the string
  25. # If it's not, you might need a more sophisticated method
  26. return int(''.join(filter(str.isdigit, dataset)))
  27. df['sorting_key'] = df['dataset'].apply(extract_number)
  28. df['SNR Chang'] = df['SNR Chang']#.apply(lambda x: np.power(10,x/20))
  29. df['SNR Normal'] = df['SNR Normal']#.apply(lambda x: np.power(10,x/20))
  30. df = df.sort_values(by=["sorting_key"],ascending=True)
  31. datasets = df['dataset'].unique()
  32. SS = int(np.ceil(np.sqrt(len(datasets))))
  33. # Create subplots based on the number of datasets
  34. fig, axes = plt.subplots(SS, SS, figsize=(18*cm, 18*cm), dpi=300, constrained_layout=True)
  35. # Flatten the axes array to iterate over it
  36. axes = axes.flatten()
  37. for i, dataset in enumerate(datasets):
  38. # Filter the dataframe for the current dataset
  39. df_subset = df[df['dataset'] == dataset]
  40. # Calculate the correlation and p-value for the current dataset
  41. correlation, p_value = stats.spearmanr(df_subset['SNR Chang'], df_subset['SNR Normal'], nan_policy='omit', alternative='two-sided')
  42. # Append the correlation and p-value to the lists
  43. correlation_list.append(correlation)
  44. p_value_list.append(p_value)
  45. # Create a scatter plot for the current dataset
  46. ax = sns.scatterplot(data=df_subset, x='SNR Chang', y='SNR Normal', s=7, ax=axes[i],color="red")
  47. ax.set_title(f"{dataset}", weight='bold', fontsize=8, fontname='Times New Roman')
  48. # Set title and labels including the correlation and p-value
  49. ax.set_xlabel('SNR-Chang (db)', fontname='Times New Roman',fontsize=8)
  50. ax.set_ylabel('SNR-Standard (db)', fontname='Times New Roman',fontsize=8)
  51. for tick in ax.get_xticklabels():
  52. tick.set_fontname("Times New Roman")
  53. tick.set_fontsize(8)
  54. for tick in ax.get_yticklabels():
  55. tick.set_fontname("Times New Roman")
  56. tick.set_fontsize(8)
  57. # Set xlim and ylim
  58. #ax.set_xlim(20.978242760551243, 88.420371212099)
  59. #ax.set_ylim(3.251536979292914, 43.47414376123412)
  60. # Remove borders
  61. ax.spines['top'].set_linewidth(0.5)
  62. ax.spines['right'].set_linewidth(0.5)
  63. ax.spines['bottom'].set_linewidth(0.5)
  64. ax.spines['left'].set_linewidth(0.5)
  65. # Show the plot
  66. fig_path_png = os.path.join(out_path, 'StandardVSchang_all.png')
  67. fig_path_svg = os.path.join(out_path, 'StandardVSchangall.svg')
  68. plt.savefig(fig_path_png, format='png', bbox_inches='tight')
  69. plt.savefig(fig_path_svg, format='svg', bbox_inches='tight')
  70. # Show the plot
  72. # Print the correlation and p-value for each dataset
  73. for i, dataset in enumerate(datasets):
  74. print(f"{dataset} - Correlation: {correlation_list[i]}, p-value: {p_value_list[i]}")