# -*- coding: utf-8 -*-
"""
GetQuantitativeValues_only_in_CST.py

Created on Mon Nov 11 09:35:42 2024
@author: arefks
"""
import os
import glob
import argparse
import pandas as pd
import nibabel as nib
import numpy as np
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from scipy.ndimage import binary_dilation


def create_output_dir(output_dir):
    """Create the output directory if it does not exist."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)


def save_mask_as_nifti(mask_data, output_file, affine=None):
    """Save mask data as a NIfTI file, reusing the source affine if provided."""
    mask_data_int = mask_data.astype(np.uint8)
    mask_img = nib.Nifti1Image(mask_data_int, affine=affine)
    nib.save(mask_img, output_file)


def process_single_file(file_info):
    """Process a single file path."""
    file_path, session_mapping, output_path, save_nifti_mask = file_info
    csvdata = []  # Local CSV data collector
    # Extract information from the file path
    subject_id = file_path.split(os.sep)[-5]
    time_point = file_path.split(os.sep)[-4]
    # Map the extracted time_point directly using session_mapping
    merged_time_point = session_mapping.get(time_point, None)
    if merged_time_point is None:
        print(f"Warning: No mapping found for {time_point}. Skipping file: {file_path}")
        return []
    # Extract the q_type from the filename
    q_type = os.path.basename(file_path).split("_flipped")[0]
    # Attempt to find the stroke mask path
    try:
        search_stroke = os.path.join(os.path.dirname(file_path), "*StrokeMask_scaled.nii")
        stroke_path = glob.glob(search_stroke)[0]
    except IndexError:
        stroke_path = None
    # Create the temp path to search for dwi_masks
    temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
    dwi_masks = (glob.glob(os.path.join(temp_path, "CST*dwi_flipped.nii.gz"))
                 + glob.glob(os.path.join(temp_path, "White_matter*dwi_flipped.nii.gz")))
    white_matter_index = next((i for i, path in enumerate(dwi_masks) if "White_matter" in path), None)
    if white_matter_index is None:
        print(f"Warning: White matter mask not found in {temp_path}. Skipping file: {file_path}")
        return []
    white_matter_path = dwi_masks.pop(white_matter_index)
    # Load white matter mask
    white_matter_img = nib.load(white_matter_path)
    white_matter_data = white_matter_img.get_fdata()
    # Load DWI data using nibabel
    dwi_img = nib.load(file_path)
    dwi_data = dwi_img.get_fdata()
    # Loop through masks and calculate average values
    pix_dilation = [0, 1, 2, 3, 4]
    for mask_path in dwi_masks:
        mask_name = (os.path.basename(mask_path)
                     .replace("registered", "").replace("flipped", "")
                     .replace(".nii.gz", "").replace("__dwi_", "").replace("_dwi_", ""))
        mask_img = nib.load(mask_path)
        mask_data_pre_dilation = mask_img.get_fdata()
        # Determine if it's an AMBA mask
        AMBA_flag = "AMBA" in mask_name
        # Calculate the average value within the mask region for each dilation
        for pp in pix_dilation:
            if pp != 0:
                mask_data0 = binary_dilation(mask_data_pre_dilation, iterations=pp)
            else:
                mask_data0 = mask_data_pre_dilation > 0
            mask_data = mask_data0  # & (white_matter_data > 0)
            if stroke_path:
                stroke_image = nib.load(stroke_path)
                stroke_data = stroke_image.get_fdata()
                # Subtract the stroke region from the mask except at baseline
                if merged_time_point != 0:
                    mask_data = mask_data & (stroke_data < 1)
                strokeFlag = "Stroke"
            else:
                strokeFlag = "Sham"
            # Find the first and last non-zero slices in the mask
            first_non_zero_slice = None
            last_non_zero_slice = None
            for z in range(mask_data.shape[2]):
                if np.any(mask_data[:, :, z]):
                    if first_non_zero_slice is None:
                        first_non_zero_slice = z
                    last_non_zero_slice = z
            if first_non_zero_slice is not None and last_non_zero_slice is not None:
                # Partition the mask along the z-axis between the first and last non-zero slices
                partition_A_mask = np.zeros_like(mask_data, dtype=bool)
                partition_B_mask = np.zeros_like(mask_data, dtype=bool)
                GAP = last_non_zero_slice - first_non_zero_slice
                first_part = int(np.ceil(GAP * 0.5))
                # Partition A covers the first half of the occupied slice range;
                # partition B covers the rest, up to and including the last non-zero slice
                partition_boundary = first_non_zero_slice + first_part
                partition_A_mask[:, :, first_non_zero_slice:partition_boundary] = mask_data[:, :, first_non_zero_slice:partition_boundary]
                partition_B_mask[:, :, partition_boundary:last_non_zero_slice + 1] = mask_data[:, :, partition_boundary:last_non_zero_slice + 1]
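                # Worked example of the split arithmetic (illustrative values only):
                # with first_non_zero_slice = 10 and last_non_zero_slice = 20, GAP = 10,
                # first_part = ceil(5.0) = 5, so partition_boundary = 15; partition A
                # then covers slices 10-14 and partition B covers slices 15-20.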
                # Calculate the average DWI value within the mask for each partition (A and B)
                for partition, partition_mask in zip(['A', 'B'], [partition_A_mask, partition_B_mask]):
                    # Calculate the average DWI value within the partition mask
                    masked_data = dwi_data[partition_mask > 0]
                    non_zero_masked_data = masked_data[masked_data != 0]  # Exclude zero values
                    # Calculate the mean only on non-zero elements; guard against an empty selection
                    average_value = np.nanmean(non_zero_masked_data) if non_zero_masked_data.size else np.nan
                    # Calculate the number of voxels in the mask
                    num_voxels = non_zero_masked_data.size
                    # Determine the starting and ending slice for each partition
                    if partition == 'A':
                        partition_start_slice = first_non_zero_slice
                        partition_end_slice = min(partition_boundary - 1, last_non_zero_slice)
                    else:
                        partition_start_slice = partition_boundary
                        partition_end_slice = last_non_zero_slice
                    # Save the mask as a NIfTI file if the flag is set
                    if save_nifti_mask and pp > 0 and merged_time_point in [3] and q_type == "fa":
                        output_nifti_file = os.path.join(
                            output_path,
                            f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
                        save_mask_as_nifti(mask_data, output_nifti_file, affine=mask_img.affine)
                    # Append data to local CSV data
                    csvdata.append([file_path, subject_id, time_point, merged_time_point, q_type,
                                    mask_name, pp, average_value, strokeFlag, AMBA_flag, num_voxels,
                                    partition, partition_start_slice, partition_end_slice])
    return csvdata


def process_files(input_path, output_path, save_nifti_mask):
    """Process DWI files and generate data using parallel processing."""
    session_mapping = {
        "ses-Baseline": 0, "ses-Baseline1": 0, "ses-pre": 0,
        "ses-P1": 3, "ses-P2": 3, "ses-P3": 3, "ses-P4": 3, "ses-P5": 3,
        "ses-P6": 7, "ses-P7": 7, "ses-P8": 7, "ses-P9": 7, "ses-P10": 7,
        "ses-P11": 14, "ses-P12": 14, "ses-P13": 14, "ses-P14": 14,
        "ses-P15": 14, "ses-P16": 14, "ses-P17": 14, "ses-P18": 14,
        "ses-P19": 21, "ses-D21": 21, "ses-P20": 21, "ses-P21": 21,
        "ses-P22": 21, "ses-P23": 21, "ses-P24": 21, "ses-P25": 21,
        "ses-P26": 28, "ses-P27": 28, "ses-P28": 28, "ses-P29": 28,
        "ses-P30": 28, "ses-P42": 42, "ses-P43": 42,
        "ses-P56": 56, "ses-P57": 56, "ses-P58": 56,
        "ses-P151": 28
    }
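    # The mapping above merges individual scan sessions into pooled time points
    # (0, 3, 7, 14, 21, 28, 42, 56); time point 0 is treated as baseline when
    # the stroke region is subtracted in process_single_file.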
    # Get all relevant file paths
    file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
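    # The recursive glob above implies a layout like the following hypothetical
    # example, from which process_single_file takes path components -5 and -4:
    #   <input>/sub-01/ses-P1/dwi/DSI_studio/fa_flipped.nii.gz
    # Here "sub-01" becomes subjectID and "ses-P1" the session / time point.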
    # Prepare arguments for parallel processing
    arguments = [(file_path, session_mapping, output_path, save_nifti_mask) for file_path in file_paths]
    # Initialize a list to store results
    csv_data = []
    # Use ProcessPoolExecutor to parallelize the file path processing
    with ProcessPoolExecutor(max_workers=6) as executor:
        # Process files in parallel
        results = list(tqdm(executor.map(process_single_file, arguments),
                            total=len(file_paths), desc="Processing files"))
    # Aggregate the results
    for result in results:
        csv_data.extend(result)
    # Create a DataFrame
    df = pd.DataFrame(csv_data,
                      columns=["fullpath", "subjectID", "timePoint", "merged_timepoint", "Qtype",
                               "mask_name", "dilation_amount", "Value", "Group", "is_it_AMBA",
                               "num_voxels", "partition", "start_slice", "end_slice"])
    return df


def main():
    # Use argparse to get command-line inputs
    parser = argparse.ArgumentParser(description="Process DWI files and generate data.")
    parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files.")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save results.")
    parser.add_argument("-n", "--nifti-mask", action="store_true", default=False, help="Set to save masks as NIfTI files.")
    args = parser.parse_args()
    # Create the output directory if it does not exist
    create_output_dir(args.output)
    # Process files
    df = process_files(args.input, args.output, args.nifti_mask)
    # Save the DataFrame as CSV
    csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
    df.to_csv(csv_file, index=False)
    print("CSV file created at:", csv_file)


if __name__ == "__main__":
    main()
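
# Example invocation (hypothetical paths, for illustration only):
#   python GetQuantitativeValues_only_in_CST.py -i /path/to/proc_data -o /path/to/output -n
# This writes Quantitative_results_from_dwi_processing.csv to the output
# directory; with -n, selected dilated masks are also saved as NIfTI files.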