# GetQuantitativeValues.py
  1. import os
  2. import glob
  3. import pandas as pd
  4. import nibabel as nib
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from tqdm import tqdm
  8. import argparse
  9. from scipy.ndimage import binary_dilation
  def create_output_dir(output_dir):
      """Create ``output_dir`` (and any missing parents) if it does not exist.

      ``exist_ok=True`` avoids the race between checking for the directory and
      creating it, e.g. when several runs share one output folder. If the path
      already exists but is not a directory, ``FileExistsError`` is raised
      instead of silently continuing and failing later when files are written.
      """
      os.makedirs(output_dir, exist_ok=True)
  def save_mask_as_nifti(mask_data, output_file, affine=None):
      """Save a mask array as a uint8 NIfTI file.

      Parameters
      ----------
      mask_data : array-like
          Mask to store. It is converted with ``astype(np.uint8)``, so a boolean
          mask becomes 0s and 1s.
      output_file : str
          Destination path (e.g. ``*.nii.gz``), passed to ``nib.save``.
      affine : array-like of shape (4, 4), optional
          Voxel-to-world transform to store in the file. The default ``None``
          reproduces the previous behaviour: the image is created without an
          affine, so the file does not carry the reference image's spatial
          orientation. Pass the reference image's ``.affine`` so the mask
          overlays correctly on the image it was derived from.
      """
      # Convert the (boolean) mask to integer 0/1 values for storage.
      mask_data_int = np.asarray(mask_data).astype(np.uint8)
      mask_img = nib.Nifti1Image(mask_data_int, affine=affine)
      nib.save(mask_img, output_file)
  # Raw session day (parsed from the session folder name, Baseline -> 0) mapped
  # to the merged timepoint used for grouping in the output table.
  _SESSION_MAPPING = {
      0: 0,
      1: 3, 2: 3, 3: 3, 4: 3, 5: 3,
      6: 7, 7: 7, 8: 7, 9: 7, 10: 7,
      11: 14, 12: 14, 13: 14, 14: 14, 15: 14, 16: 14, 17: 14, 18: 14,
      19: 21, 20: 21, 21: 21, 22: 21, 23: 21, 24: 21, 25: 21,
      26: 28, 27: 28, 28: 28, 29: 28, 30: 28,
      42: 42, 43: 42,
      56: 56, 57: 56,
  }

  # Number of binary-dilation iterations applied to each tract mask (0 = none).
  _DILATION_STEPS = (0, 1, 2, 3, 4)


  def _session_day(time_point):
      """Convert a session label ("ses-Baseline" or "ses-P<day>") to an int day."""
      if time_point == "ses-Baseline":
          return 0
      return int(time_point.split("-P")[1])


  def _densest_slice(mask):
      """Return the index (along the third axis) of the slice with most non-zero voxels.

      Ties resolve to the lowest index.
      """
      per_slice = [np.count_nonzero(mask[..., i]) for i in range(mask.shape[2])]
      return int(np.argmax(per_slice))


  def _nonzero_mean(dwi_data, mask):
      """Mean of the non-zero values of ``dwi_data`` inside boolean ``mask``.

      NaNs are ignored; an empty selection yields NaN without a RuntimeWarning.
      """
      values = dwi_data[mask]
      values = values[values != 0]  # zero voxels are treated as "no data"
      if values.size == 0:
          return np.nan
      return np.nanmean(values)


  def _save_overlay_figure(dwi_data, mask, slice_index, title, output_file):
      """Save a PNG of one slice of ``dwi_data`` with ``mask`` overlaid."""
      fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=50)
      try:
          ax.imshow(dwi_data[..., slice_index], cmap='gray')
          ax.imshow(mask[..., slice_index], cmap='jet', alpha=0.5)
          ax.set_title(title)
          ax.axis('off')
          fig.savefig(output_file, bbox_inches='tight', pad_inches=0)
      finally:
          # Always release the figure, even if drawing or saving fails.
          plt.close(fig)


  def process_files(input_path, output_path, create_figures, save_nifti_mask):
      """Extract mean DWI-metric values inside registered tract masks.

      Searches ``input_path`` recursively for metric maps matching
      ``<subject>/<session>/dwi/DSI_studio/*_flipped.nii.gz``. For each map,
      every mask in the sibling ``RegisteredTractMasks_adjusted`` folder (except
      the White_matter mask) is dilated by 0-4 iterations and intersected with
      the white-matter mask. At merged timepoints 3 and 28, when a
      ``*StrokeMask_scaled.nii`` file is present, voxels where the stroke mask
      is >= 1 are also removed. The mean of the non-zero metric values inside
      the final mask is recorded.

      Files without a White_matter mask are skipped with a warning.

      Parameters
      ----------
      input_path : str
          Root directory to search.
      output_path : str
          Directory for optional PNG figures and NIfTI masks; must already
          exist (``main`` creates it).
      create_figures : bool
          Save one overlay PNG per (file, mask, dilation step).
      save_nifti_mask : bool
          Save the final mask as NIfTI for dilated (> 0) FA masks at merged
          timepoints 3 and 28.

      Returns
      -------
      pandas.DataFrame
          One row per (file, mask, dilation step). The columns are fullpath,
          subjectID, timePoint, int_timepoint, merged_timepoint, Qtype,
          mask_name, dialation_amount, Value and Group. Group is "Stroke" if a
          stroke mask file was found next to the metric map, otherwise "Sham".

      Raises
      ------
      KeyError
          If a session day has no entry in the timepoint mapping.
      """
      rows = []
      pattern = os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz")
      file_paths = glob.glob(pattern, recursive=True)
      for file_path in tqdm(file_paths, desc="Processing files"):
          # Path layout: .../<subject>/<session>/dwi/DSI_studio/<file>
          path_parts = file_path.split(os.sep)
          subject_id = path_parts[-5]
          time_point = path_parts[-4]
          int_time_point = _session_day(time_point)
          merged_time_point = _SESSION_MAPPING[int_time_point]
          q_type = os.path.basename(file_path).split("_flipped")[0]
          file_dir = os.path.dirname(file_path)

          # The white-matter mask constrains every other mask; without it the
          # file cannot be processed (list.pop(None) used to raise TypeError).
          mask_dir = os.path.join(file_dir, "RegisteredTractMasks_adjusted")
          dwi_masks = glob.glob(os.path.join(mask_dir, "*dwi_flipped.nii.gz"))
          white_matter_path = next((p for p in dwi_masks if "White_matter" in p), None)
          if white_matter_path is None:
              tqdm.write(f"Skipping {file_path}: no White_matter mask found in {mask_dir}")
              continue
          dwi_masks.remove(white_matter_path)

          stroke_matches = glob.glob(os.path.join(file_dir, "*StrokeMask_scaled.nii"))
          stroke_path = stroke_matches[0] if stroke_matches else None
          stroke_flag = "Stroke" if stroke_path else "Sham"
          # The stroke region is only subtracted at merged timepoints 3 and 28.
          # Load it once per file rather than once per mask and dilation step.
          outside_stroke = None
          if stroke_path and merged_time_point in (3, 28):
              outside_stroke = nib.load(stroke_path).get_fdata() < 1

          white_matter = nib.load(white_matter_path).get_fdata() > 0
          dwi_data = nib.load(file_path).get_fdata()

          for mask_path in tqdm(dwi_masks, desc="Processing masks", leave=False):
              mask_name = (
                  os.path.basename(mask_path)
                  .replace("registered", "")
                  .replace("flipped", "")
                  .replace(".nii.gz", "")
                  .replace("__dwi_", "")
              )
              mask_pre_dilation = nib.load(mask_path).get_fdata()
              # Chosen per mask: the old running maximum carried over between
              # masks, so later masks could be drawn on a slice they never touch.
              selected_slice = _densest_slice(mask_pre_dilation)

              for pp in _DILATION_STEPS:
                  if pp:
                      tract = binary_dilation(mask_pre_dilation, iterations=pp)
                  else:
                      tract = mask_pre_dilation > 0
                  mask_data = tract & white_matter
                  if outside_stroke is not None:
                      mask_data = mask_data & outside_stroke

                  average_value = _nonzero_mean(dwi_data, mask_data)

                  if create_figures:
                      label = f"{subject_id}_tp_{merged_time_point}_{q_type}_{average_value:.2f}_{mask_name}_{pp}"
                      _save_overlay_figure(
                          dwi_data, mask_data, selected_slice, label,
                          os.path.join(output_path, f"{label}.png"),
                      )

                  if save_nifti_mask and pp > 0 and merged_time_point in (3, 28) and q_type == "fa":
                      nifti_file = os.path.join(
                          output_path,
                          f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz",
                      )
                      # NOTE(review): the mask is written without the source image's
                      # affine (mask_img.affine), so it may not align with the DWI in
                      # viewers.
                      save_mask_as_nifti(mask_data, nifti_file)

                  rows.append([file_path, subject_id, time_point, int_time_point,
                               merged_time_point, q_type, mask_name, pp,
                               average_value, stroke_flag])

      return pd.DataFrame(
          rows,
          columns=["fullpath", "subjectID", "timePoint", "int_timepoint", "merged_timepoint", "Qtype", "mask_name",
                   "dialation_amount", "Value", "Group"],
      )
  def main():
      """Command-line entry point.

      Parses the input/output directories and option flags, makes sure the
      output directory exists, runs ``process_files`` and writes the results to
      ``Quantitative_results_from_dwi_processing.csv`` in the output directory.
      It then prints the CSV location and the resulting table.
      """
      cli = argparse.ArgumentParser(description="Process DWI files and generate figures.")
      cli.add_argument(
          "-i", "--input", type=str, required=True,
          help="Input directory containing DWI files. e.g: proc_data (flipped fa, md etc file must be available. registering of masks must happen before this step",
      )
      cli.add_argument(
          "-o", "--output", type=str, required=True,
          help="Output directory to save output results (csv and figures).",
      )
      cli.add_argument(
          "-f", "--figures", action="store_true", default=False,
          help="set if you want figures to be saved as pngs",
      )
      cli.add_argument(
          "-n", "--nifti-mask", action="store_true", default=False,
          help="set if you want masks to be saved as NIfTI files.",
      )
      opts = cli.parse_args()

      out_dir = opts.output
      create_output_dir(out_dir)

      results = process_files(opts.input, out_dir, opts.figures, opts.nifti_mask)

      csv_path = os.path.join(out_dir, "Quantitative_results_from_dwi_processing.csv")
      results.to_csv(csv_path, index=False)
      print("CSV file created at:", csv_path)
      print(results)


  if __name__ == "__main__":
      # Run the CLI only when executed as a script, not when imported.
      main()