# GetQuantitativeValues_only_in_stroke_slices.py
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 11 12:49:37 2024
@author: arefks
"""
  11. import os
  12. import glob
  13. import pandas as pd
  14. import nibabel as nib
  15. import numpy as np
  16. from tqdm import tqdm
  17. from concurrent.futures import ProcessPoolExecutor
  18. from scipy.ndimage import binary_dilation
  19. def create_output_dir(output_dir):
  20. """Create the output directory if it does not exist."""
  21. if not os.path.exists(output_dir):
  22. os.makedirs(output_dir)
  23. def save_mask_as_nifti(mask_data, output_file):
  24. """Save mask data as NIfTI file."""
  25. mask_data_int = mask_data.astype(np.uint8)
  26. mask_img = nib.Nifti1Image(mask_data_int, affine=None)
  27. nib.save(mask_img, output_file)
  28. def process_single_file(file_info):
  29. """Process a single file path."""
  30. file_path, session_mapping, output_path, save_nifti_mask = file_info
  31. csvdata = [] # Local CSV data collector
  32. # Extract information from the file path
  33. subject_id = file_path.split(os.sep)[-5]
  34. time_point = file_path.split(os.sep)[-4]
  35. # Map the extracted time_point directly using session_mapping
  36. merged_time_point = session_mapping.get(time_point, None)
  37. if merged_time_point is None:
  38. print(f"Warning: No mapping found for {time_point}. Skipping file: {file_path}")
  39. return []
  40. # Extract the q_type from the filename
  41. q_type = os.path.basename(file_path).split("_flipped")[0]
  42. # Attempt to find the stroke mask path
  43. try:
  44. search_stroke = os.path.join(os.path.dirname(file_path), "*StrokeMask_scaled.nii")
  45. stroke_path = glob.glob(search_stroke)[0]
  46. except IndexError:
  47. stroke_path = None
  48. # Create the temp path to search for dwi_masks
  49. temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
  50. dwi_masks = glob.glob(os.path.join(temp_path, "*dwi_flipped.nii.gz"))
  51. white_matter_index = next((i for i, path in enumerate(dwi_masks) if "White_matter" in path), None)
  52. if white_matter_index is None:
  53. print(f"Warning: White matter mask not found in {temp_path}. Skipping file: {file_path}")
  54. return []
  55. white_matter_path = dwi_masks.pop(white_matter_index)
  56. # Load white matter mask
  57. white_matter_img = nib.load(white_matter_path)
  58. white_matter_data = white_matter_img.get_fdata()
  59. # Load DWI data using nibabel
  60. dwi_img = nib.load(file_path)
  61. dwi_data = dwi_img.get_fdata()
  62. # Loop through masks and calculate average values
  63. pix_dialation = [0, 1, 2, 3, 4]
  64. for mask_path in dwi_masks:
  65. mask_name = os.path.basename(mask_path).replace("registered", "").replace("flipped", "").replace(".nii.gz", "").replace("__dwi_", "").replace("_dwi_", "")
  66. mask_img = nib.load(mask_path)
  67. mask_data_pre_dilation = mask_img.get_fdata()
  68. # Determine if it's an AMBA mask
  69. AMBA_flag = "AMBA" in mask_name
  70. # Calculate the average value within the mask region for each dilation
  71. for pp in pix_dialation:
  72. if pp != 0:
  73. mask_data0 = binary_dilation(mask_data_pre_dilation, iterations=pp)
  74. else:
  75. mask_data0 = mask_data_pre_dilation > 0
  76. mask_data = mask_data0 #& (white_matter_data > 0)
  77. if stroke_path:
  78. stroke_image = nib.load(stroke_path)
  79. stroke_data = stroke_image.get_fdata()
  80. # Subtracting stroke region from mask except in baseline
  81. if merged_time_point != 0:
  82. mask_data = mask_data & (stroke_data < 1)
  83. # Create an empty array to hold the updated mask data
  84. new_mask_data = np.zeros_like(mask_data)
  85. # Iterate through each slice along the third axis and only retain slices where there is data in stroke_data
  86. for z in range(stroke_data.shape[2]):
  87. if np.any(stroke_data[:, :, z]):
  88. new_mask_data[:, :, z] = mask_data[:, :, z]
  89. # Replace mask_data with the new filtered mask
  90. mask_data = new_mask_data
  91. strokeFlag = "Stroke"
  92. else:
  93. strokeFlag = "Sham"
  94. # Calculate average DWI value within the mask
  95. masked_data = dwi_data[mask_data > 0]
  96. non_zero_masked_data = masked_data[masked_data != 0] # Exclude zero values
  97. average_value = np.nanmean(non_zero_masked_data) # Calculate mean only on non-zero elements
  98. # Calculate the number of voxels in the mask
  99. num_voxels = non_zero_masked_data.size
  100. # Save the mask as NIfTI file if the flag is set
  101. if save_nifti_mask and pp > 0 and merged_time_point in [3] and q_type == "fa":
  102. output_nifti_file = os.path.join(output_path,
  103. f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
  104. save_mask_as_nifti(mask_data, output_nifti_file)
  105. # Append data to local CSV data
  106. csvdata.append([file_path, subject_id, time_point, merged_time_point, q_type, mask_name, pp, average_value, strokeFlag, AMBA_flag, num_voxels])
  107. return csvdata
  108. def process_files(input_path, output_path, save_nifti_mask):
  109. """Process DWI files and generate data using parallel processing."""
  110. session_mapping = {
  111. "ses-Baseline": 0, "ses-Baseline1": 0, "ses-pre": 0,
  112. "ses-P1": 3, "ses-P2": 3, "ses-P3": 3, "ses-P4": 3, "ses-P5": 3,
  113. "ses-P6": 7, "ses-P7": 7, "ses-P8": 7, "ses-P9": 7, "ses-P10": 7,
  114. "ses-P11": 14, "ses-P12": 14, "ses-P13": 14, "ses-P14": 14,
  115. "ses-P15": 14, "ses-P16": 14, "ses-P17": 14, "ses-P18": 14,
  116. "ses-P19": 21, "ses-D21": 21, "ses-P20": 21, "ses-P21": 21,
  117. "ses-P22": 21, "ses-P23": 21, "ses-P24": 21, "ses-P25": 21,
  118. "ses-P26": 28, "ses-P27": 28, "ses-P28": 28, "ses-P29": 28,
  119. "ses-P30": 28, "ses-P42": 42, "ses-P43": 42,
  120. "ses-P56": 56, "ses-P57": 56, "ses-P58": 56,
  121. "ses-P151": 28
  122. }
  123. # Get all relevant file paths
  124. file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
  125. # Prepare arguments for parallel processing
  126. arguments = [(file_path, session_mapping, output_path, save_nifti_mask) for file_path in file_paths]
  127. # Initialize a list to store results
  128. csv_data = []
  129. # Use ProcessPoolExecutor to parallelize the file path processing
  130. with ProcessPoolExecutor(max_workers=6) as executor:
  131. # Process files in parallel
  132. results = list(tqdm(executor.map(process_single_file, arguments), total=len(file_paths), desc="Processing files"))
  133. # Aggregate the results
  134. for result in results:
  135. csv_data.extend(result)
  136. # Create a DataFrame
  137. df = pd.DataFrame(csv_data,
  138. columns=["fullpath", "subjectID", "timePoint", "merged_timepoint", "Qtype", "mask_name",
  139. "dialation_amount", "Value", "Group", "is_it_AMBA", "num_voxels"])
  140. return df
  141. def main():
  142. # Use argparse to get command-line inputs
  143. import argparse
  144. parser = argparse.ArgumentParser(description="Process DWI files and generate data.")
  145. parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files.")
  146. parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save results.")
  147. parser.add_argument("-n", "--nifti-mask", action="store_true", default=False, help="Set to save masks as NIfTI files.")
  148. args = parser.parse_args()
  149. # Create the output directory if it does not exist
  150. create_output_dir(args.output)
  151. # Process files
  152. df = process_files(args.input, args.output, args.nifti_mask)
  153. # Save the DataFrame as CSV
  154. csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing_only_in_stroke_slices.csv")
  155. df.to_csv(csv_file, index=False)
  156. print("CSV file created at:", csv_file)
  157. if __name__ == "__main__":
  158. main()