GetQuantitativeValues.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import os
  2. import glob
  3. import pandas as pd
  4. import nibabel as nib
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from tqdm import tqdm
  8. import argparse
  9. from scipy.ndimage import binary_dilation
  10. def create_output_dir(output_dir):
  11. """Create the output directory if it does not exist."""
  12. if not os.path.exists(output_dir):
  13. os.makedirs(output_dir)
  14. def process_files(input_path, output_path, create_figures):
  15. """Process DWI files and generate figures."""
  16. # Define a dictionary to map Experiment values to the desired timepoints
  17. session_mapping = {
  18. 0: 0,
  19. 1: 3, 2: 3, 3: 3,
  20. 4: 3, 5: 3, 6: 7, 7: 7,
  21. 8: 7, 9: 7, 10: 7, 11: 14, 12: 14,
  22. 13: 14, 14: 14, 15: 14, 16: 14, 17: 14, 18: 14, 19: 21,
  23. 20: 21, 21: 21, 22: 21, 23: 21, 24: 21, 25: 21, 26: 28,
  24. 27: 28, 28: 28, 29: 28, 30: 28 , 42:42, 43:42, 56:56, 57:56
  25. }
  26. # Initialize a list to store extracted information
  27. csvdata = []
  28. # Iterate over files with progress bar
  29. file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
  30. for file_path in tqdm(file_paths, desc="Processing files"):
  31. # Extract information from the file path
  32. subject_id = file_path.split(os.sep)[-5]
  33. time_point = file_path.split(os.sep)[-4]
  34. print(file_path)
  35. int_time_point = 0 if time_point == "ses-Baseline" else int(time_point.split("-P")[1])
  36. merged_time_point = session_mapping[int_time_point]
  37. q_type = os.path.basename(file_path).split("_flipped")[0]
  38. try:
  39. searchStroke = os.path.join(os.path.dirname(file_path),"*StrokeMask_scaled.nii")
  40. #print(searchStroke)
  41. stroke_path = glob.glob(searchStroke)[0]
  42. print(stroke_path)
  43. except IndexError:
  44. stroke_path = False
  45. # Create the temp path to search for dwi_masks
  46. temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
  47. dwi_masks = glob.glob(os.path.join(temp_path, "*dwi_flipped.nii.gz"))
  48. # Load DWI data using nibabel
  49. dwi_img = nib.load(file_path)
  50. dwi_data = dwi_img.get_fdata()
  51. # Loop through masks and calculate average values
  52. max_pixels = -1
  53. selected_slice = None
  54. pix_dialation = [0,1,2]
  55. for mask_path in tqdm(dwi_masks, desc="Processing masks", leave=False):
  56. mask_name = os.path.basename(mask_path).replace("registered", "").replace("flipped", "").replace(".nii.gz", "").replace("__dwi_","")
  57. mask_img = nib.load(mask_path)
  58. mask_data_pre_dilation = mask_img.get_fdata()
  59. # Find the slice with the highest number of pixels
  60. for i in range(mask_data_pre_dilation.shape[2]):
  61. num_pixels = np.count_nonzero(mask_data_pre_dilation[..., i])
  62. if num_pixels > max_pixels:
  63. max_pixels = num_pixels
  64. selected_slice = i
  65. # Assuming you want to calculate the average value within the mask region
  66. for pp in pix_dialation:
  67. if pp != 0:
  68. mask_data = binary_dilation(mask_data_pre_dilation, iterations=pp)
  69. else:
  70. mask_data = mask_data_pre_dilation > 0
  71. if stroke_path:
  72. stroke_image = nib.load(stroke_path)
  73. stroke_data = stroke_image.get_fdata()
  74. # Subtracting stroke region from mask
  75. mask_data = mask_data & (stroke_data < 1)
  76. masked_data = dwi_data[mask_data > 0]
  77. average_value = np.nanmean(masked_data) # Handle case when masked_data is empty
  78. strokeFlag = "Stroke"
  79. else:
  80. masked_data = dwi_data[mask_data > 0]
  81. average_value = np.nanmean(masked_data) # Handle case when masked_data is empty
  82. strokeFlag = "Sham"
  83. # Create and save the figure if the flag is True
  84. if create_figures:
  85. # Create a plot of the selected slice with overlay
  86. fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=50)
  87. ax.imshow(dwi_data[..., selected_slice], cmap='gray')
  88. ax.imshow(mask_data[..., selected_slice], cmap='jet', alpha=0.5) # Overlay mask_data
  89. ax.set_title(f"{subject_id}_tp_{merged_time_point}_{q_type}_{average_value:.2f}_{mask_name}_{pp}")
  90. plt.axis('off')
  91. # Save the plot to the output path
  92. output_file = os.path.join(output_path,
  93. f"{subject_id}_tp_{merged_time_point}_{q_type}_{average_value:.2f}_{mask_name}_{pp}.png")
  94. plt.savefig(output_file, bbox_inches='tight', pad_inches=0)
  95. # Close the plot to release memory
  96. plt.close(fig)
  97. # Append the extracted information to the list
  98. csvdata.append([file_path, subject_id, time_point, int_time_point, merged_time_point, q_type, mask_name, pp ,average_value,strokeFlag])
  99. # Create a DataFrame
  100. df = pd.DataFrame(csvdata,
  101. columns=["fullpath", "subjectID", "timePoint", "int_timepoint", "merged_timepoint", "Qtype", "mask_name",
  102. "dialation_amount","Value","Group"])
  103. # Return DataFrame
  104. return df
  105. def main():
  106. # Parse command-line arguments
  107. parser = argparse.ArgumentParser(description="Process DWI files and generate figures.")
  108. parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files. e.g: proc_data (flipped fa, md etc file must be available. registering of masks must happen before this step")
  109. parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save output results (csv and figures).")
  110. parser.add_argument("-f", "--figures", action="store_true", default=False, help="set if you want figures to be saved as pngs")
  111. args = parser.parse_args()
  112. # Create the output directory if it does not exist
  113. create_output_dir(args.output)
  114. # Process files
  115. df = process_files(args.input, args.output,args.figures)
  116. # Save the DataFrame as CSV
  117. csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
  118. df.to_csv(csv_file, index=False)
  119. print("CSV file created at:", csv_file)
  120. # Print the DataFrame
  121. print(df)
  122. if __name__ == "__main__":
  123. main()