# ROCcurve.py
  1. import pandas as pd
  2. import seaborn as sns
  3. import matplotlib.pyplot as plt
  4. from matplotlib.ticker import MaxNLocator
  5. import os
  6. # Read data from the CSV file
  7. script_dir = os.path.dirname(__file__)
  8. file_path = os.path.join(script_dir, '..', 'input')
  9. out_path = os.path.join(script_dir, '..', 'figures')
  10. # Load the data
  11. result_df = pd.read_csv(os.path.join(file_path, 'Confusion_matrix_metrics.csv'))
  12. cm = 1/2.54
  13. # Calculate Actual_Label
  14. # Specify the font size for the plot
  15. sns.set_style('ticks')
  16. sns.set(font='Times New Roman', style=None) # Set font to Times New Roman and font size to 9
  17. palette = 'Set1'
  18. subset_df = result_df[(result_df['TP']+result_df['FN'] > 0)]
  19. # Function to print mean, std, max, and mean for "sequencetype"
  20. def print_statistics(data, x, y, hue):
  21. mean_values = data.groupby(hue).agg({y: 'mean'}).reset_index()
  22. std_values = data.groupby(hue).agg({y: 'std'}).reset_index()
  23. max_values = data.groupby(hue).agg({y: 'max'}).reset_index()
  24. for i, seq_type in enumerate(mean_values[hue]):
  25. print(f"Sequence Type: {seq_type}")
  26. print(f"Mean {y}: {mean_values[y][i]:.2f}")
  27. print(f"Standard Deviation {y}: {std_values[y][i]:.2f}")
  28. print(f"Maximum {y}: {max_values[y][i]:.2f}")
  29. print("\n")
  30. # Create a 2x3 subplot
  31. fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(19*cm, 8*cm),dpi=300)
  32. # Plot for "Accuracy" vs. "Thresold_Human_Voters"
  33. sns.lineplot(
  34. data=subset_df,
  35. x="Thresold_Human_Voters", y="Accuracy", hue="sequence_name",
  36. dashes=False, markers=True, ci=30,
  37. ax=axes[0, 0]
  38. )
  39. axes[0, 0].set_xlabel("Voting Threshold: Manual-rater", fontsize=8)
  40. axes[0, 0].set_ylabel("Accuracy", fontsize=8)
  41. axes[0, 0].yaxis.set_major_locator(MaxNLocator(nbins=5)) # Set major y ticks
  42. axes[0, 0].get_legend().remove()
  43. print("Statistics for Accuracy vs. Thresold_Human_Voters:")
  44. print_statistics(subset_df, "Thresold_Human_Voters", "Accuracy", "sequence_name")
  45. # Plot for "Specificity" vs. "Thresold_Human_Voters"
  46. sns.lineplot(
  47. data=subset_df,
  48. x="Thresold_Human_Voters", y="Specificity", hue="sequence_name",
  49. dashes=False, markers=True, ci=30,
  50. ax=axes[0, 1]
  51. )
  52. axes[0, 1].set_xlabel("Voting Threshold: Manual-rater", fontsize=8)
  53. axes[0, 1].set_ylabel("Specificity", fontsize=8)
  54. axes[0, 1].get_legend().remove() # Remove legend for subsequent plots
  55. axes[0, 1].yaxis.set_major_locator(MaxNLocator(nbins=5)) # Set major y ticks
  56. print("Statistics for Specificity vs. Thresold_Human_Voters:")
  57. print_statistics(subset_df, "Thresold_Human_Voters", "Specificity", "sequence_name")
  58. # Plot for "Sensitivity-Recall" vs. "Thresold_Human_Voters"
  59. sns.lineplot(
  60. data=subset_df,
  61. x="Thresold_Human_Voters", y="Sensitivity-Recall", hue="sequence_name",
  62. dashes=False, markers=True, ci=30,
  63. ax=axes[0, 2]
  64. )
  65. axes[0, 2].set_xlabel("Voting Threshold: Manual-rater", fontsize=8)
  66. axes[0, 2].set_ylabel("Sensitivity", fontsize=8)
  67. axes[0, 2].get_legend().remove() # Remove legend for subsequent plots
  68. axes[0, 2].yaxis.set_major_locator(MaxNLocator(nbins=5)) # Set major y ticks
  69. print("Statistics for Sensitivity-Recall vs. Thresold_Human_Voters:")
  70. print_statistics(subset_df, "Thresold_Human_Voters", "Sensitivity-Recall", "sequence_name")
  71. # Plot for "Accuracy" vs. "Thresold_ML_Voters"
  72. sns.lineplot(
  73. data=subset_df,
  74. x="Thresold_ML_Voters", y="Accuracy", hue="sequence_name",
  75. dashes=False, markers=True, ci=30,
  76. ax=axes[1, 0]
  77. )
  78. axes[1, 0].set_xlabel("Voting Threshold: AIDAqc", fontsize=8)
  79. axes[1, 0].set_ylabel("Accuracy", fontsize=8)
  80. axes[1, 0].yaxis.set_major_locator(MaxNLocator(nbins=5)) # Set major y ticks
  81. axes[1, 0].get_legend().remove()
  82. print("Statistics for Accuracy vs. Thresold_ML_Voters:")
  83. print_statistics(subset_df, "Thresold_ML_Voters", "Accuracy", "sequence_name")
  84. # Plot for "Specificity" vs. "Thresold_ML_Voters"
  85. sns.lineplot(
  86. data=subset_df,
  87. x="Thresold_ML_Voters", y="Specificity", hue="sequence_name",
  88. dashes=False, markers=True, ci=30,
  89. ax=axes[1, 1]
  90. )
  91. axes[1, 1].set_xlabel("Voting Threshold: AIDAqc", fontsize=8)
  92. axes[1, 1].set_ylabel("Specificity", fontsize=8)
  93. axes[1, 1].get_legend().remove() # Remove legend for subsequent plots
  94. axes[1, 1].yaxis.set_major_locator(MaxNLocator(nbins=5)) # Set major y ticks
  95. print("Statistics for Specificity vs. Thresold_ML_Voters:")
  96. print_statistics(subset_df, "Thresold_ML_Voters", "Specificity", "sequence_name")
  97. # Plot for "Sensitivity-Recall" vs. "Thresold_ML_Voters"
  98. sns.lineplot(
  99. data=subset_df,
  100. x="Thresold_ML_Voters", y="Sensitivity-Recall", hue="sequence_name",
  101. dashes=False, markers=True, ci=30,
  102. ax=axes[1, 2]
  103. )
  104. axes[1, 2].set_xlabel("Voting Threshold: AIDAqc", fontsize=8)
  105. axes[1, 2].set_ylabel("Sensitivity", fontsize=8)
  106. #axes[1, 2].get_legend().remove() # Add legend for the last plot only
  107. axes[1, 2].yaxis.set_major_locator(MaxNLocator(nbins=5)) # Set major y ticks
  108. print("Statistics for Sensitivity-Recall vs. Thresold_ML_Voters:")
  109. print_statistics(subset_df, "Thresold_ML_Voters", "Sensitivity-Recall", "sequence_name")
  110. axes[1, 2].legend(fontsize=8,frameon=False)
  111. # Customize spines and tick parameters
  112. for ax in axes.flatten():
  113. ax.tick_params(axis='both', which='both', labelsize=8)
  114. ax.spines['top'].set_visible(True)
  115. ax.spines['right'].set_visible(True)
  116. ax.spines['bottom'].set_visible(True)
  117. ax.spines['left'].set_visible(True)
  118. ax.spines['top'].set_linewidth(0.5)
  119. ax.spines['right'].set_linewidth(0.5)
  120. ax.spines['bottom'].set_linewidth(0.5)
  121. ax.spines['left'].set_linewidth(0.5)
  122. ax.tick_params(direction='out', length=4, width=1,
  123. grid_alpha=0.5)
  124. # Adjust layout manually
  125. plt.tight_layout()
  126. # Save the figure as SVG and PNG
  127. output_path = out_path
  128. output_filename = "Subplots_Sensitivity_Accuracy_Specificity_2x3"
  129. # Save as SVG
  130. plt.savefig(f"{output_path}/{output_filename}.svg", format="svg")
  131. # Save as PNG
  132. plt.savefig(f"{output_path}/{output_filename}.png", format="png")
  133. plt.show()