GetQuantitativeValues_onlyCST.py

# -*- coding: utf-8 -*-
"""
Created on Mon Nov 11 09:35:42 2024
@author: arefks
"""
import os
import glob
import pandas as pd
import nibabel as nib
import numpy as np
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from scipy.ndimage import binary_dilation
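# Overview: this script walks a processed-data tree for *_flipped.nii.gz
# quantitative DWI maps (e.g. FA), computes the mean value inside CST and
# white-matter tract masks at several dilation levels (excluding the stroke
# region at post-stroke time points), and writes the results to a CSV file.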

def create_output_dir(output_dir):
    """Create the output directory if it does not exist."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

def save_mask_as_nifti(mask_data, output_file):
    """Save mask data as a NIfTI file."""
    mask_data_int = mask_data.astype(np.uint8)
    # No affine is supplied here, so the saved mask carries no spatial
    # orientation from the source image (voxel indices only).
    mask_img = nib.Nifti1Image(mask_data_int, affine=None)
    nib.save(mask_img, output_file)

def process_single_file(file_info):
    """Process a single file path and return its rows of CSV data."""
    file_path, session_mapping, output_path, save_nifti_mask = file_info
    csvdata = []  # Local CSV data collector
    # Extract subject and session information from the file path
    subject_id = file_path.split(os.sep)[-5]
    time_point = file_path.split(os.sep)[-4]
    # Map the extracted time_point directly using session_mapping
    merged_time_point = session_mapping.get(time_point, None)
    if merged_time_point is None:
        print(f"Warning: No mapping found for {time_point}. Skipping file: {file_path}")
        return []
    # Extract the q_type from the filename
    q_type = os.path.basename(file_path).split("_flipped")[0]
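    # Path layout assumed by the indices above (inferred from the glob pattern
    # in process_files):
    #   <input>/<subjectID>/<session>/dwi/DSI_studio/<qtype>_flipped.nii.gz
    # e.g. a hypothetical ".../sub-01/ses-P1/dwi/DSI_studio/fa_flipped.nii.gz"
    # yields subject_id="sub-01", time_point="ses-P1", q_type="fa".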
    # Attempt to find the stroke mask path
    try:
        search_stroke = os.path.join(os.path.dirname(file_path), "*StrokeMask_scaled.nii")
        stroke_path = glob.glob(search_stroke)[0]
    except IndexError:
        stroke_path = None
    # Create the temp path to search for dwi_masks
    temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
    dwi_masks = glob.glob(os.path.join(temp_path, "CST*dwi_flipped.nii.gz")) + \
                glob.glob(os.path.join(temp_path, "White_matter*dwi_flipped.nii.gz"))
    white_matter_index = next((i for i, path in enumerate(dwi_masks) if "White_matter" in path), None)
    if white_matter_index is None:
        print(f"Warning: White matter mask not found in {temp_path}. Skipping file: {file_path}")
        return []
    white_matter_path = dwi_masks.pop(white_matter_index)
    # Load the white matter mask
    white_matter_img = nib.load(white_matter_path)
    white_matter_data = white_matter_img.get_fdata()
    # Load DWI data using nibabel
    dwi_img = nib.load(file_path)
    dwi_data = dwi_img.get_fdata()
    # Loop through masks and calculate average values
    pix_dialation = [0, 1, 2, 3, 4]  # dilation amounts in voxels (0 = no dilation)
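    # binary_dilation (used below) grows a mask by `pp` voxels per iteration
    # using scipy's default cross-shaped structuring element; e.g. a single
    # True voxel in a 2-D plane becomes a 5-voxel plus shape after one pass.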
    for mask_path in dwi_masks:
        mask_name = os.path.basename(mask_path).replace("registered", "").replace("flipped", "").replace(".nii.gz", "").replace("__dwi_", "").replace("_dwi_", "")
        mask_img = nib.load(mask_path)
        mask_data_pre_dilation = mask_img.get_fdata()
        # Determine if it's an AMBA mask
        AMBA_flag = "AMBA" in mask_name
        # Calculate the average value within the mask region for each dilation
        for pp in pix_dialation:
            if pp != 0:
                mask_data0 = binary_dilation(mask_data_pre_dilation, iterations=pp)
            else:
                mask_data0 = mask_data_pre_dilation > 0
            # The white-matter intersection is currently disabled (commented
            # out), so white_matter_data is loaded but not applied here.
            mask_data = mask_data0  # & (white_matter_data > 0)
            if stroke_path:
                stroke_image = nib.load(stroke_path)
                stroke_data = stroke_image.get_fdata()
                # Subtract the stroke region from the mask, except at baseline
                if merged_time_point != 0:
                    mask_data = mask_data & (stroke_data < 1)
                # Create an empty array to hold the updated mask data
                new_mask_data = np.zeros_like(mask_data)
                # Retain only slices along the third axis that contain stroke voxels
                for z in range(stroke_data.shape[2]):
                    if np.any(stroke_data[:, :, z]):
                        new_mask_data[:, :, z] = mask_data[:, :, z]
                # Replace mask_data with the slice-filtered mask
                mask_data = new_mask_data
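                # Net effect: when a stroke mask exists, stroke voxels are
                # excluded at post-stroke time points, and the measurement is
                # always restricted to stroke-bearing slices.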
                strokeFlag = "Stroke"
            else:
                strokeFlag = "Sham"
            # Calculate the average DWI value within the mask
            masked_data = dwi_data[mask_data > 0]
            non_zero_masked_data = masked_data[masked_data != 0]  # Exclude zero values
            average_value = np.nanmean(non_zero_masked_data)  # Mean over non-zero elements only
            # Number of non-zero voxels contributing to the average
            num_voxels = non_zero_masked_data.size
            # Save the mask as a NIfTI file if the flag is set (only dilated
            # "fa" masks at merged time point 3 are exported)
            if save_nifti_mask and pp > 0 and merged_time_point in [3] and q_type == "fa":
                output_nifti_file = os.path.join(
                    output_path,
                    f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
                save_mask_as_nifti(mask_data, output_nifti_file)
            # Append this measurement to the local CSV data
            csvdata.append([file_path, subject_id, time_point, merged_time_point, q_type, mask_name,
                            pp, average_value, strokeFlag, AMBA_flag, num_voxels])
    return csvdata
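# A sketch of the argument tuple process_single_file expects (the path below
# is a hypothetical example):
#   process_single_file(("/data/sub-01/ses-P1/dwi/DSI_studio/fa_flipped.nii.gz",
#                        {"ses-P1": 3}, "/results", False))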

def process_files(input_path, output_path, save_nifti_mask):
    """Process DWI files and generate data using parallel processing."""
    # Merge raw session labels into time-point bins (0, 3, 7, 14, 21, 28, 42, 56)
    session_mapping = {
        "ses-Baseline": 0, "ses-Baseline1": 0, "ses-pre": 0,
        "ses-P1": 3, "ses-P2": 3, "ses-P3": 3, "ses-P4": 3, "ses-P5": 3,
        "ses-P6": 7, "ses-P7": 7, "ses-P8": 7, "ses-P9": 7, "ses-P10": 7,
        "ses-P11": 14, "ses-P12": 14, "ses-P13": 14, "ses-P14": 14,
        "ses-P15": 14, "ses-P16": 14, "ses-P17": 14, "ses-P18": 14,
        "ses-P19": 21, "ses-D21": 21, "ses-P20": 21, "ses-P21": 21,
        "ses-P22": 21, "ses-P23": 21, "ses-P24": 21, "ses-P25": 21,
        "ses-P26": 28, "ses-P27": 28, "ses-P28": 28, "ses-P29": 28,
        "ses-P30": 28, "ses-P42": 42, "ses-P43": 42,
        "ses-P56": 56, "ses-P57": 56, "ses-P58": 56,
        "ses-P151": 28
    }
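    # Unmapped session labels fall back to None and are skipped with a warning
    # in process_single_file; e.g. session_mapping.get("ses-P5") returns 3.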
    # Get all relevant file paths
    file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
    # Prepare arguments for parallel processing
    arguments = [(file_path, session_mapping, output_path, save_nifti_mask) for file_path in file_paths]
    # Initialize a list to store results
    csv_data = []
    # Use ProcessPoolExecutor to parallelize the file path processing
    with ProcessPoolExecutor(max_workers=6) as executor:
        # Process files in parallel
        results = list(tqdm(executor.map(process_single_file, arguments),
                            total=len(file_paths), desc="Processing files"))
    # Aggregate the results
    for result in results:
        csv_data.extend(result)
    # Create a DataFrame
    df = pd.DataFrame(csv_data,
                      columns=["fullpath", "subjectID", "timePoint", "merged_timepoint", "Qtype", "mask_name",
                               "dialation_amount", "Value", "Group", "is_it_AMBA", "num_voxels"])
    return df

def main():
    # Use argparse to get command-line inputs
    import argparse
    parser = argparse.ArgumentParser(description="Process DWI files and generate data.")
    parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files.")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save results.")
    parser.add_argument("-n", "--nifti-mask", action="store_true", default=False, help="Set to save masks as NIfTI files.")
    args = parser.parse_args()
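    # Example invocation (the paths are placeholders):
    #   python GetQuantitativeValues_onlyCST.py -i /path/to/proc_data -o /path/to/results -n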
    # Create the output directory if it does not exist
    create_output_dir(args.output)
    # Process files
    df = process_files(args.input, args.output, args.nifti_mask)
    # Save the DataFrame as CSV
    csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
    df.to_csv(csv_file, index=False)
    print("CSV file created at:", csv_file)

if __name__ == "__main__":
    main()