GetQuantitativeValues_not_parallel.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import os
  2. import glob
  3. import pandas as pd
  4. import nibabel as nib
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from tqdm import tqdm
  8. import argparse
  9. from scipy.ndimage import binary_dilation
  10. def create_output_dir(output_dir):
  11. """Create the output directory if it does not exist."""
  12. if not os.path.exists(output_dir):
  13. os.makedirs(output_dir)
  14. def save_mask_as_nifti(mask_data, output_file):
  15. """Save mask data as NIfTI file."""
  16. # Convert boolean mask to integer (0s and 1s)
  17. mask_data_int = mask_data.astype(np.uint8)
  18. # Create NIfTI image object
  19. mask_img = nib.Nifti1Image(mask_data_int, affine=None)
  20. # Save NIfTI image
  21. nib.save(mask_img, output_file)
  22. def process_files(input_path, output_path, create_figures, save_nifti_mask):
  23. """Process DWI files and generate figures."""
  24. # Define a comprehensive mapping for all expected timepoints
  25. session_mapping = {
  26. "ses-Baseline": 0, "ses-Baseline1": 0, "ses-pre": 0,
  27. "ses-P1": 3, "ses-P2": 3, "ses-P3": 3, "ses-P": 3,
  28. "ses-P4": 3, "ses-P5": 3, "ses-P6": 7, "ses-P7": 7,
  29. "ses-P8": 7, "ses-P9": 7, "ses-P10": 7,
  30. "ses-P11": 14, "ses-P12": 14, "ses-P13": 14, "ses-P14": 14,
  31. "ses-P15": 14, "ses-P16": 14, "ses-P17": 14, "ses-P18": 14,
  32. "ses-P19": 21, "ses-D21": 21, "ses-P20": 21, "ses-P21": 21,
  33. "ses-P22": 21, "ses-P23": 21, "ses-P24": 21, "ses-P25": 21,
  34. "ses-P26": 28, "ses-P27b": 28, "ses-P151": 28, "ses-P152": 28,
  35. "ses-P27": 28, "ses-P28": 28, "ses-P29": 28, "ses-P30": 28,
  36. "ses-P42": 42, "ses-P43": 42,
  37. "ses-P56": 56, "ses-P222": 56, "ses-P57": 56, "ses-P58": 56
  38. }
  39. # Initialize a list to store extracted information
  40. csvdata = []
  41. # Get all relevant file paths with a progress bar
  42. file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
  43. for file_path in tqdm(file_paths, desc="Processing files"):
  44. # Extract information from the file path
  45. subject_id = file_path.split(os.sep)[-5]
  46. time_point = file_path.split(os.sep)[-4]
  47. # Map the extracted time_point directly using session_mapping
  48. merged_time_point = session_mapping.get(time_point, None)
  49. # If mapping is not found, skip this file
  50. if merged_time_point is None:
  51. print(f"Warning: No mapping found for {time_point}. Skipping file: {file_path}")
  52. continue
  53. # Extract the q_type from the filename
  54. q_type = os.path.basename(file_path).split("_flipped")[0]
  55. # Attempt to find the stroke mask path
  56. try:
  57. search_stroke = os.path.join(os.path.dirname(file_path), "*StrokeMask_scaled.nii")
  58. stroke_path = glob.glob(search_stroke)[0]
  59. except IndexError:
  60. stroke_path = None
  61. # Create the temp path to search for dwi_masks
  62. temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
  63. dwi_masks = glob.glob(os.path.join(temp_path, "*dwi_flipped.nii.gz"))
  64. white_matter_index = next((i for i, path in enumerate(dwi_masks) if "White_matter" in path), None)
  65. white_matter_path = dwi_masks.pop(white_matter_index)
  66. # Load white matter mask
  67. white_matter_img = nib.load(white_matter_path)
  68. white_matter_data = white_matter_img.get_fdata()
  69. # Load DWI data using nibabel
  70. dwi_img = nib.load(file_path)
  71. dwi_data = dwi_img.get_fdata()
  72. # Loop through masks and calculate average values
  73. max_pixels = -1
  74. selected_slice = None
  75. pix_dialation = [0,1,2,3,4]
  76. for mask_path in tqdm(dwi_masks, desc="Processing masks", leave=False):
  77. mask_name = os.path.basename(mask_path).replace("registered", "").replace("flipped", "").replace(".nii.gz", "").replace("__dwi_","").replace("_dwi_","")
  78. mask_img = nib.load(mask_path)
  79. mask_data_pre_dilation = mask_img.get_fdata()
  80. # Append the extracted information to the list
  81. if "AMBA" in mask_name:
  82. AMBA_flag = True
  83. else:
  84. AMBA_flag = False
  85. # Find the slice with the highest number of pixels
  86. for i in range(mask_data_pre_dilation.shape[2]):
  87. num_pixels = np.count_nonzero(mask_data_pre_dilation[..., i])
  88. if num_pixels > max_pixels:
  89. max_pixels = num_pixels
  90. selected_slice = i
  91. # Assuming you want to calculate the average value within the mask region
  92. for pp in pix_dialation:
  93. if pp != 0:
  94. mask_data0 = binary_dilation(mask_data_pre_dilation, iterations=pp)
  95. else:
  96. mask_data0 = mask_data_pre_dilation > 0
  97. mask_data = mask_data0 & (white_matter_data > 0)
  98. if stroke_path:
  99. stroke_image = nib.load(stroke_path)
  100. stroke_data = stroke_image.get_fdata()
  101. # Subtracting stroke region from mask except in baseline
  102. if merged_time_point != 0:
  103. mask_data = mask_data & (stroke_data < 1)
  104. else:
  105. mask_data = mask_data
  106. masked_data = dwi_data[mask_data > 0]
  107. non_zero_masked_data = masked_data[masked_data != 0] # Exclude zero values
  108. average_value = np.nanmean(non_zero_masked_data) # Calculate mean only on non-zero elements
  109. strokeFlag = "Stroke"
  110. else:
  111. masked_data = dwi_data[mask_data > 0]
  112. non_zero_masked_data = masked_data[masked_data != 0] # Exclude zero values
  113. average_value = np.nanmean(non_zero_masked_data) # Calculate mean only on non-zero elements
  114. strokeFlag = "Sham"
  115. # Create and save the figure if the flag is True
  116. if create_figures:
  117. # Create a plot of the selected slice with overlay
  118. fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=50)
  119. ax.imshow(dwi_data[..., selected_slice], cmap='gray')
  120. ax.imshow(mask_data[..., selected_slice], cmap='jet', alpha=0.5) # Overlay mask_data
  121. ax.set_title(f"{subject_id}_tp_{merged_time_point}_{q_type}_{average_value:.2f}_{mask_name}_{pp}")
  122. plt.axis('off')
  123. # Save the plot to the output path
  124. output_file = os.path.join(output_path,
  125. f"{subject_id}_tp_{merged_time_point}_{q_type}_{average_value:.2f}_{mask_name}_{pp}.png")
  126. plt.savefig(output_file, bbox_inches='tight', pad_inches=0)
  127. # Close the plot to release memory
  128. plt.close(fig)
  129. if save_nifti_mask and pp > 0 and merged_time_point in [3] and q_type == "fa":
  130. # Save the mask as NIfTI file
  131. output_nifti_file = os.path.join(output_path,
  132. f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
  133. save_mask_as_nifti(mask_data, output_nifti_file)
  134. csvdata.append([file_path, subject_id, time_point, merged_time_point, q_type, mask_name, pp ,average_value,strokeFlag,AMBA_flag])
  135. # Create a DataFrame
  136. df = pd.DataFrame(csvdata,
  137. columns=["fullpath", "subjectID", "timePoint", "merged_timepoint", "Qtype", "mask_name",
  138. "dialation_amount","Value","Group","is_it_AMBA"])
  139. # Return DataFrame
  140. return df
  141. def main():
  142. # Parse command-line arguments
  143. parser = argparse.ArgumentParser(description="Process DWI files and generate figures.")
  144. parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files. e.g: proc_data (flipped fa, md etc file must be available. registering of masks must happen before this step")
  145. parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save output results (csv and figures).")
  146. parser.add_argument("-f", "--figures", action="store_true", default=False, help="set if you want figures to be saved as pngs")
  147. parser.add_argument("-n", "--nifti-mask", action="store_true", default=False, help="set if you want masks to be saved as NIfTI files.")
  148. args = parser.parse_args()
  149. # Create the output directory if it does not exist
  150. create_output_dir(args.output)
  151. # Process files
  152. df = process_files(args.input, args.output, args.figures, args.nifti_mask)
  153. # Save the DataFrame as CSV
  154. csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
  155. df.to_csv(csv_file, index=False)
  156. print("CSV file created at:", csv_file)
  157. # Print the DataFrame
  158. print(df)
  159. if __name__ == "__main__":
  160. main()