plotting_quantitative_dti_values.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import os
  2. import pandas as pd
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. import statsmodels.formula.api as smf
  6. from statsmodels.stats.multicomp import MultiComparison
  7. from itertools import combinations
  8. import statsmodels
  9. import scipy
  10. from scipy.stats import dunnett
  11. # Set up warning suppression for convergence
  12. import warnings
  13. warnings.simplefilter(action='ignore')
  14. # Define the PlotBoxplot function
  15. def PlotBoxplot(filtered_df, qq, mm,dd, output_folder,sig):
  16. """
  17. Plot boxplots for filtered dataframe and save as images.
  18. Parameters:
  19. filtered_df (DataFrame): Filtered dataframe based on "Qtype" and "mask_name".
  20. qq (str): Qtype value.
  21. mm (str): Mask_name value.
  22. output_folder (str): Path to the output folder to save images.
  23. """
  24. # Set the color palette to Set2
  25. sns.set_palette("grey")
  26. # Set the figure size to 18 cm in width
  27. plt.figure(figsize=(18/2.54, 4)) # Convert 18 cm to inches
  28. # Create a boxplot
  29. sns.boxplot(data=filtered_df, x="Group", y="Value", hue="merged_timepoint",
  30. fliersize=3 ,flierprops={"marker": "o"})
  31. # Set title and labels
  32. plt.title(f'{qq}_{mm}_{dd}')
  33. plt.xlabel('Timepoint')
  34. plt.ylabel(qq+"(n.a)")
  35. # Show legend without title and frame
  36. plt.legend(title=None, frameon=False,bbox_to_anchor=(1.05, 1), loc='best')
  37. plt.tight_layout()
  38. # Save the plot as an image
  39. if sig:
  40. output_path = os.path.join(output_folder, f'{qq}_{mm}_{dd}_mixedlm_significant.png')
  41. else:
  42. output_path = os.path.join(output_folder, f'{qq}_{mm}_{dd}_mixedlm.png')
  43. plt.savefig(output_path, dpi=100)
  44. # Close the plot to avoid displaying it
  45. plt.close()
  46. # Get the directory where the code file is located
  47. code_dir = os.path.dirname(os.path.abspath(__file__))
  48. # Get the parent directory of the code directory
  49. parent_dir = os.path.dirname(code_dir)
  50. # Create a new folder for mixed model analysis
  51. mixed_model_dir = os.path.join(parent_dir, 'output', "Final_Quantitative_output", 'mixed_model_analysis_dunnett')
  52. os.makedirs(mixed_model_dir, exist_ok=True)
  53. # Step 4: Save the resulting dataframe to a CSV file
  54. input_file_path = os.path.join(parent_dir, 'output', "Final_Quantitative_output", 'Quantitative_results_from_dwi_processing.csv')
  55. df = pd.read_csv(input_file_path)
  56. qc_csv = os.path.join(parent_dir,"input","AIDAqc_ouptut_for_data","Voting_remapped.csv")
  57. df_qc = pd.read_csv(qc_csv)
  58. df_qc_5 = df_qc[(df_qc["Voting outliers (from 5)"]>3) & (df_qc["sequence_type"]=="diff")]
  59. # Filtering the dataframe based on conditions
  60. df_f0 = df[~df["merged_timepoint"].isin([42, 56])]
  61. # Merge df_f1 and df_qc_5 on "subjectID" and "merged_timepoint"
  62. merged_df = pd.merge(df_f0, df_qc_5, on=["subjectID", "merged_timepoint"], how="left", indicator=True)
  63. # Drop the entries that are present in both df_f1 and df_qc_5
  64. df_f1 = merged_df[merged_df["_merge"] == "left_only"].drop(columns=["_merge"])
  65. # Drop unnecessary columns with NaN values
  66. df_f1 = df_f1.dropna(axis=1, how="all")
  67. # Create lists to store results and terminal output
  68. results = []
  69. # Iterate over unique values of "Qtype" and "mask_name"
  70. for dd in df["dialation_amount"].unique():
  71. for qq in df["Qtype"].unique():
  72. for mm in df["mask_name"].unique():
  73. # Filter the dataframe based on current "Qtype" and "mask_name"
  74. filtered_df = df_f1[(df_f1["Qtype"] == qq) & (df_f1["mask_name"] == mm) & (df_f1["dialation_amount"] == dd)]
  75. # Fit the mixed-effects model
  76. model_name = f"{qq}_{mm}_{dd}"
  77. md = smf.mixedlm("Value ~ merged_timepoint", filtered_df, groups=filtered_df["Group"])
  78. mdf = md.fit(method=["lbfgs"])
  79. # Log file for each combination of qq and mm
  80. # Check if the p-value is significant
  81. if mdf.pvalues['merged_timepoint'] < 0.05:
  82. log_file = os.path.join(mixed_model_dir, f"{model_name}_log_significant.txt")
  83. sig = True
  84. else:
  85. log_file = os.path.join(mixed_model_dir, f"{model_name}_log.txt")
  86. sig = False
  87. with open(log_file, 'w') as log:
  88. # Print model name to log file
  89. log.write(f"Model Name: {model_name}\n")
  90. # Print the model summary to log file
  91. log.write(str(mdf.summary()) + '\n')
  92. # Plot and save boxplots for significant models
  93. PlotBoxplot(filtered_df, qq, mm,dd ,mixed_model_dir,sig)
  94. # Perform Dunnet test for post hoc analysis
  95. log.write("Dunnett Multiple Comparisons for time effect:\n")
  96. log.write("Group\tTP1\tTP2\tPValue\tSignificant\n")
  97. for gg in filtered_df["Group"].unique():
  98. control_data = filtered_df[(filtered_df["Group"]==gg) & (filtered_df["merged_timepoint"]==0)]["Value"]
  99. for tt in filtered_df["merged_timepoint"].unique():
  100. treatment_data = filtered_df[(filtered_df["Group"]==gg) & (filtered_df["merged_timepoint"]==tt)]["Value"]
  101. result = dunnett(treatment_data, control=control_data)
  102. if result.pvalue < 0.05: # You can adjust the significance level as needed
  103. significant = "Yes"
  104. else:
  105. significant = "No"
  106. # Write the results to the log file
  107. log.write(f"{gg}\t0\t{tt}\t{result.pvalue}\t{significant}\n")
  108. # Perform the Sidak multiple comparisons for the group effect
  109. group_combinations = combinations(filtered_df["Group"].unique(), 2)
  110. log.write("Sidak Multiple Comparisons for group effect:\n")
  111. log.write("Group1\tGroup2\tTimePoint\tPValue\tSignificant\n")
  112. for group1, group2 in group_combinations:
  113. for time_point in filtered_df["merged_timepoint"].unique():
  114. group1_data = filtered_df[(filtered_df["Group"] == group1) & (filtered_df["merged_timepoint"] == time_point)]["Value"]
  115. group2_data = filtered_df[(filtered_df["Group"] == group2) & (filtered_df["merged_timepoint"] == time_point)]["Value"]
  116. sidak_result = statsmodels.stats.multitest.multipletests(scipy.stats.ttest_ind(group1_data, group2_data)[1], method='sidak')
  117. if sidak_result[0]:
  118. log.write(f"{group1}\t{group2}\t{time_point}\t{sidak_result[1]}\tYes\n")
  119. else:
  120. log.write(f"{group1}\t{group2}\t{time_point}\t{sidak_result[1]}\tNo\n")
  121. # Append the results to the list
  122. results.append({'Qtype': qq, 'mask_name': mm,'dialation': dd,'p_value': mdf.pvalues['merged_timepoint']})
  123. # Create a DataFrame from results
  124. results_df = pd.DataFrame(results)
  125. # Save results_df as CSV in mixed_model_analysis folder
  126. output_file_path = os.path.join(mixed_model_dir, 'mixed_model_results.csv')
  127. results_df.to_csv(output_file_path, index=False)
  128. # Print the table
  129. print(results_df)