ttest_group_differences.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Mon Oct 28 12:21:08 2024
  4. @author: arefk
  5. """
  6. import os
  7. import pandas as pd
  8. import numpy as np
  9. from scipy.stats import shapiro, ttest_ind, mannwhitneyu
  10. from tqdm import tqdm
  11. from concurrent.futures import ThreadPoolExecutor, as_completed
  12. # Get the directory where the code file is located
  13. code_dir = os.path.dirname(os.path.abspath(__file__))
  14. parent_dir = os.path.dirname(code_dir)
  15. # Load the CSV data
  16. input_file_path = os.path.join(parent_dir, 'output', "Quantitative_outputs",'Quantitative_results_from_dwi_processing_only_in_stroke_slices.csv')
  17. df = pd.read_csv(input_file_path, low_memory=False)
  18. # Initialize an empty list to store results
  19. results = []
  20. # Get unique values of masks, qtypes, timepoints, and dilation amounts
  21. unique_masks = df['mask_name'].unique()
  22. unique_qtypes = df['Qtype'].unique()
  23. unique_timepoints = df['merged_timepoint'].unique()
  24. unique_dilations = df['dialation_amount'].unique()
  25. # Prepare the combinations for parallel processing
  26. combinations = [(mask, qtype, timepoint, dilation) for mask in unique_masks for qtype in unique_qtypes for timepoint in unique_timepoints for dilation in unique_dilations]
  27. # Function to process each combination
  28. def process_combination(mask, qtype, timepoint, dilation):
  29. result = None
  30. # Filter data for Stroke and Sham groups separately
  31. df_stroke = df[(df['Group'] == 'Stroke') &
  32. (df['mask_name'] == mask) &
  33. (df['Qtype'] == qtype) &
  34. (df['merged_timepoint'] == timepoint) &
  35. (df['dialation_amount'] == dilation)]
  36. df_sham = df[(df['Group'] == 'Sham') &
  37. (df['mask_name'] == mask) &
  38. (df['Qtype'] == qtype) &
  39. (df['merged_timepoint'] == timepoint) &
  40. (df['dialation_amount'] == dilation)]
  41. # Drop NaN values for the 'Value' column
  42. stroke_values = df_stroke['Value'].dropna()
  43. sham_values = df_sham['Value'].dropna()
  44. # Filter data after dropping NaN values to get subjects with non-null values
  45. df_stroke_filtered = df_stroke[df_stroke['Value'].notna()]
  46. df_sham_filtered = df_sham[df_sham['Value'].notna()]
  47. # Only proceed if there are more than 8 subjects in either group after dropping NaNs
  48. if len(df_stroke_filtered['subjectID'].unique()) > 8 and len(df_sham_filtered['subjectID'].unique()) > 8:
  49. # Check if we have enough values to perform statistical tests
  50. if len(stroke_values) > 0 and len(sham_values) > 0:
  51. # Perform Shapiro-Wilk normality test
  52. shapiro_stroke_p = shapiro(stroke_values)[1]
  53. shapiro_sham_p = shapiro(sham_values)[1]
  54. # Check if data is normally distributed
  55. if shapiro_stroke_p < 0.05 or shapiro_sham_p < 0.05:
  56. # Use Mann-Whitney U test if data is not normally distributed
  57. stat, p_value = mannwhitneyu(stroke_values, sham_values, alternative='two-sided')
  58. else:
  59. # Use Welch's t-test if data is normally distributed
  60. stat, p_value = ttest_ind(stroke_values, sham_values, equal_var=False)
  61. # Store the result
  62. result = {
  63. 'mask_name': mask,
  64. 'Qtype': qtype,
  65. 'merged_timepoint': timepoint,
  66. 'dialation_amount': dilation,
  67. 'Pvalue': p_value
  68. }
  69. return result
  70. # Parallel processing using ThreadPoolExecutor with 4 workers
  71. with ThreadPoolExecutor(max_workers=6) as executor:
  72. futures = {executor.submit(process_combination, mask, qtype, timepoint, dilation): (mask, qtype, timepoint, dilation) for mask, qtype, timepoint, dilation in combinations}
  73. # Iterate through completed tasks with progress bar
  74. with tqdm(total=len(futures), desc="Processing combinations in parallel") as pbar:
  75. for future in as_completed(futures):
  76. combination = futures[future]
  77. try:
  78. result = future.result()
  79. if result:
  80. results.append(result)
  81. except Exception as e:
  82. print(f"Error processing combination {combination}: {e}")
  83. finally:
  84. pbar.update(1)
  85. # Convert results to a DataFrame
  86. results_df = pd.DataFrame(results)
  87. # Define output path for the new CSV
  88. output_file_path = os.path.join(parent_dir, 'output', "Quantitative_outputs", 'Significance_stroke_vs_sham_difference_withoutWMmask_only_in_stroke_slices.csv')
  89. # Save results to CSV
  90. results_df.to_csv(output_file_path, index=False)
  91. print(f"Significance analysis results saved to {output_file_path}")