GetQuantitativeValues.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import os
  2. import glob
  3. import pandas as pd
  4. import nibabel as nib
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from tqdm import tqdm
  8. import argparse
  9. from scipy.ndimage import binary_dilation
  def create_output_dir(output_dir):
      """Create the output directory (and any missing parents) if needed.

      Args:
          output_dir: Path of the directory that must exist afterwards.

      Raises:
          FileExistsError: If ``output_dir`` exists but is not a directory.
          OSError: If the directory cannot be created.
      """
      # exist_ok=True avoids the race between a separate existence check and
      # the creation call when several runs target the same directory.
      os.makedirs(output_dir, exist_ok=True)
  def save_mask_as_nifti(mask_data, output_file, affine=None):
      """Save mask data as a NIfTI file.

      Args:
          mask_data: Array-like mask. It is cast to ``np.uint8`` before
              writing, so a boolean mask is stored as 0s and 1s.
          output_file: Destination path passed to ``nib.save``.
          affine: Optional affine matrix stored with the image. The default
              ``None`` keeps the previous behaviour, in which no affine is
              supplied; pass the affine of the image the mask was derived
              from so the saved mask keeps that image's spatial placement.
      """
      # np.asarray lets plain sequences be passed as well as ndarrays.
      mask_data_int = np.asarray(mask_data).astype(np.uint8)
      mask_img = nib.Nifti1Image(mask_data_int, affine=affine)
      nib.save(mask_img, output_file)
  def _parse_session(time_point, session_mapping):
      """Return ``(int_time_point, merged_time_point)`` for a session folder name."""
      int_time_point = 0 if time_point == "ses-Baseline" else int(time_point.split("-P")[1])
      try:
          merged_time_point = session_mapping[int_time_point]
      except KeyError:
          # Same exception type as a plain lookup, but says which session failed.
          raise KeyError(
              f"No merged time point defined for session '{time_point}' "
              f"(parsed as {int_time_point})"
          ) from None
      return int_time_point, merged_time_point


  def _largest_mask_slice(mask_volume):
      """Return the last-axis index of the slice with the most non-zero voxels."""
      best_index, best_count = None, -1
      for index in range(mask_volume.shape[2]):
          count = np.count_nonzero(mask_volume[..., index])
          if count > best_count:  # strict '>' keeps the first slice on ties
              best_index, best_count = index, count
      return best_index


  def _mean_of_nonzero(values):
      """Return the NaN-ignoring mean of the non-zero entries, or NaN if none."""
      non_zero = values[values != 0]
      if non_zero.size == 0:
          # np.nanmean would also give NaN here, but with a RuntimeWarning.
          return np.nan
      return np.nanmean(non_zero)


  def _save_overlay_figure(dwi_data, mask_data, slice_index, title, output_file):
      """Save one slice of ``dwi_data`` with ``mask_data`` overlaid as an image."""
      fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=50)
      try:
          ax.imshow(dwi_data[..., slice_index], cmap='gray')
          ax.imshow(mask_data[..., slice_index], cmap='jet', alpha=0.5)
          ax.set_title(title)
          ax.axis('off')
          fig.savefig(output_file, bbox_inches='tight', pad_inches=0)
      finally:
          # Always release the figure, even when plotting or saving fails.
          plt.close(fig)


  def process_files(input_path, output_path, create_figures, save_nifti_mask):
      """Process DWI files and generate figures.

      Searches ``input_path`` recursively for
      ``<subject>/<session>/dwi/DSI_studio/*_flipped.nii.gz`` files. For each
      file, every mask matching ``RegisteredTractMasks_adjusted/*dwi_flipped.nii.gz``
      next to it is applied at dilation levels 0-4 and the mean of the non-zero
      image values inside the mask is recorded.

      When a ``*StrokeMask_scaled.nii`` file sits next to the image, the rows
      are labelled "Stroke" (otherwise "Sham"), and for merged time points 3
      and 28 the voxels where the stroke mask is >= 1 are removed from the mask.

      Args:
          input_path: Root directory to search.
          output_path: Directory receiving figures and saved masks.
          create_figures: If true, save a PNG overlay for every
              mask/dilation combination.
          save_nifti_mask: If true, save the dilated masks (dilation > 0) for
              merged time points 3 and 28 of files whose name prefix is "fa".

      Returns:
          pandas.DataFrame with columns fullpath, subjectID, timePoint,
          int_timepoint, merged_timepoint, Qtype, mask_name, dialation_amount,
          Value and Group; one row per file, mask and dilation level.

      Raises:
          KeyError: If a session number has no entry in the session mapping.
          IndexError, ValueError: If a session folder name is neither
              "ses-Baseline" nor of the form "...-P<number>".
      """
      # Define a dictionary to map Experiment values to the desired timepoints
      session_mapping = {
          0: 0,
          1: 3, 2: 3, 3: 3, 4: 3, 5: 3,
          6: 7, 7: 7, 8: 7, 9: 7, 10: 7,
          11: 14, 12: 14, 13: 14, 14: 14, 15: 14, 16: 14, 17: 14, 18: 14,
          19: 21, 20: 21, 21: 21, 22: 21, 23: 21, 24: 21, 25: 21,
          26: 28, 27: 28, 28: 28, 29: 28, 30: 28,
          42: 42, 43: 42,
          56: 56, 57: 56,
      }
      # Merged time points at which the stroke region is removed from the masks.
      stroke_subtraction_time_points = (3, 28)
      pix_dialation = [0, 1, 2, 3, 4]

      csvdata = []

      # Sorted so the row order of the result does not depend on the filesystem.
      file_paths = sorted(glob.glob(
          os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"),
          recursive=True))
      for file_path in tqdm(file_paths, desc="Processing files"):
          # Extract information from the file path
          path_parts = file_path.split(os.sep)
          subject_id = path_parts[-5]
          time_point = path_parts[-4]
          int_time_point, merged_time_point = _parse_session(time_point, session_mapping)
          q_type = os.path.basename(file_path).split("_flipped")[0]
          image_dir = os.path.dirname(file_path)

          # Presence of a stroke mask next to the image decides the group label.
          stroke_candidates = sorted(glob.glob(os.path.join(image_dir, "*StrokeMask_scaled.nii")))
          stroke_path = stroke_candidates[0] if stroke_candidates else None
          strokeFlag = "Stroke" if stroke_path else "Sham"

          dwi_masks = sorted(glob.glob(
              os.path.join(image_dir, "RegisteredTractMasks_adjusted", "*dwi_flipped.nii.gz")))

          # Load DWI data using nibabel
          dwi_data = nib.load(file_path).get_fdata()

          # Load the stroke mask once per file instead of once per mask and
          # dilation level; it is only needed when it will actually be subtracted.
          outside_stroke = None
          if stroke_path and dwi_masks and merged_time_point in stroke_subtraction_time_points:
              outside_stroke = nib.load(stroke_path).get_fdata() < 1

          for mask_path in tqdm(dwi_masks, desc="Processing masks", leave=False):
              mask_name = os.path.basename(mask_path).replace("registered", "").replace("flipped", "").replace(".nii.gz", "").replace("__dwi_","")
              mask_data_pre_dilation = nib.load(mask_path).get_fdata()

              # Chosen per mask, so each figure shows the slice where *this*
              # mask is largest rather than one carried over from a previous mask.
              selected_slice = _largest_mask_slice(mask_data_pre_dilation)

              for pp in pix_dialation:
                  if pp != 0:
                      mask_data = binary_dilation(mask_data_pre_dilation, iterations=pp)
                  else:
                      mask_data = mask_data_pre_dilation > 0

                  # Subtracting stroke region from mask
                  if outside_stroke is not None:
                      mask_data = mask_data & outside_stroke

                  # Mean over the non-zero image values inside the mask
                  average_value = _mean_of_nonzero(dwi_data[mask_data])

                  # Create and save the figure if the flag is True
                  if create_figures:
                      label = f"{subject_id}_tp_{merged_time_point}_{q_type}_{average_value:.2f}_{mask_name}_{pp}"
                      _save_overlay_figure(
                          dwi_data, mask_data, selected_slice, label,
                          os.path.join(output_path, f"{label}.png"))

                  if save_nifti_mask and pp > 0 and merged_time_point in stroke_subtraction_time_points and q_type == "fa":
                      # Save the mask as NIfTI file
                      # NOTE(review): called with two arguments, so the mask is
                      # written without the source image's affine - confirm
                      # whether downstream tools need it.
                      output_nifti_file = os.path.join(
                          output_path,
                          f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
                      save_mask_as_nifti(mask_data, output_nifti_file)

                  # Append the extracted information to the list
                  csvdata.append([file_path, subject_id, time_point, int_time_point,
                                  merged_time_point, q_type, mask_name, pp,
                                  average_value, strokeFlag])

      # Create and return the DataFrame
      return pd.DataFrame(
          csvdata,
          columns=["fullpath", "subjectID", "timePoint", "int_timepoint", "merged_timepoint", "Qtype", "mask_name",
                   "dialation_amount","Value","Group"])
  def main():
      """Command-line entry point.

      Parses the input/output directories and the two optional switches, makes
      sure the output directory exists, runs the processing, writes the
      resulting table to ``Quantitative_results_from_dwi_processing.csv`` in
      the output directory and prints both the CSV path and the table.
      """
      cli = argparse.ArgumentParser(description="Process DWI files and generate figures.")

      # Required directory arguments, registered in the original order.
      directory_arguments = (
          ("-i", "--input", "Input directory containing DWI files. e.g: proc_data (flipped fa, md etc file must be available. registering of masks must happen before this step"),
          ("-o", "--output", "Output directory to save output results (csv and figures)."),
      )
      for short_flag, long_flag, help_text in directory_arguments:
          cli.add_argument(short_flag, long_flag, type=str, required=True, help=help_text)

      # Optional on/off switches.
      switch_arguments = (
          ("-f", "--figures", "set if you want figures to be saved as pngs"),
          ("-n", "--nifti-mask", "set if you want masks to be saved as NIfTI files."),
      )
      for short_flag, long_flag, help_text in switch_arguments:
          cli.add_argument(short_flag, long_flag, action="store_true", default=False, help=help_text)

      options = cli.parse_args()

      # The output directory must exist before anything is written into it.
      create_output_dir(options.output)

      results = process_files(options.input, options.output, options.figures, options.nifti_mask)

      # Persist the table, then report where it went and what it contains.
      csv_destination = os.path.join(options.output, "Quantitative_results_from_dwi_processing.csv")
      results.to_csv(csv_destination, index=False)
      print("CSV file created at:", csv_destination)
      print(results)
  # Run the command-line workflow only when this file is executed as a script.
  if __name__ == "__main__":
      main()