"""GetQuantitativeValues.py

Compute the mean value of each DSI Studio ``*_flipped.nii.gz`` map inside
registered tract masks (optionally dilated), and write the results to a CSV.
"""
import glob
import itertools
import os
from concurrent.futures import ProcessPoolExecutor

import nibabel as nib
import numpy as np
import pandas as pd
from scipy.ndimage import binary_dilation
from tqdm import tqdm
  9. def create_output_dir(output_dir):
  10. """Create the output directory if it does not exist."""
  11. if not os.path.exists(output_dir):
  12. os.makedirs(output_dir)
  13. def save_mask_as_nifti(mask_data, output_file):
  14. """Save mask data as NIfTI file."""
  15. mask_data_int = mask_data.astype(np.uint8)
  16. mask_img = nib.Nifti1Image(mask_data_int, affine=None)
  17. nib.save(mask_img, output_file)
  18. def process_single_file(file_info):
  19. """Process a single file path."""
  20. file_path, session_mapping, output_path, save_nifti_mask = file_info
  21. csvdata = [] # Local CSV data collector
  22. # Extract information from the file path
  23. subject_id = file_path.split(os.sep)[-5]
  24. time_point = file_path.split(os.sep)[-4]
  25. # Map the extracted time_point directly using session_mapping
  26. merged_time_point = session_mapping.get(time_point, None)
  27. if merged_time_point is None:
  28. print(f"Warning: No mapping found for {time_point}. Skipping file: {file_path}")
  29. return []
  30. # Extract the q_type from the filename
  31. q_type = os.path.basename(file_path).split("_flipped")[0]
  32. # Attempt to find the stroke mask path
  33. try:
  34. search_stroke = os.path.join(os.path.dirname(file_path), "*StrokeMask_scaled.nii")
  35. stroke_path = glob.glob(search_stroke)[0]
  36. except IndexError:
  37. stroke_path = None
  38. # Create the temp path to search for dwi_masks
  39. temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
  40. dwi_masks = glob.glob(os.path.join(temp_path, "*dwi_flipped.nii.gz"))
  41. white_matter_index = next((i for i, path in enumerate(dwi_masks) if "White_matter" in path), None)
  42. white_matter_path = dwi_masks.pop(white_matter_index)
  43. # Load white matter mask
  44. white_matter_img = nib.load(white_matter_path)
  45. white_matter_data = white_matter_img.get_fdata()
  46. # Load DWI data using nibabel
  47. dwi_img = nib.load(file_path)
  48. dwi_data = dwi_img.get_fdata()
  49. # Loop through masks and calculate average values
  50. pix_dialation = [0, 1, 2, 3, 4]
  51. for mask_path in dwi_masks:
  52. mask_name = os.path.basename(mask_path).replace("registered", "").replace("flipped", "").replace(".nii.gz", "").replace("__dwi_", "").replace("_dwi_", "")
  53. mask_img = nib.load(mask_path)
  54. mask_data_pre_dilation = mask_img.get_fdata()
  55. # Determine if it's an AMBA mask
  56. AMBA_flag = "AMBA" in mask_name
  57. # Calculate the average value within the mask region for each dilation
  58. for pp in pix_dialation:
  59. if pp != 0:
  60. mask_data0 = binary_dilation(mask_data_pre_dilation, iterations=pp)
  61. else:
  62. mask_data0 = mask_data_pre_dilation > 0
  63. mask_data = mask_data0 #& (white_matter_data > 0)
  64. if stroke_path:
  65. stroke_image = nib.load(stroke_path)
  66. stroke_data = stroke_image.get_fdata()
  67. # Subtracting stroke region from mask except in baseline
  68. if merged_time_point != 0:
  69. mask_data = mask_data & (stroke_data < 1)
  70. strokeFlag = "Stroke"
  71. else:
  72. strokeFlag = "Sham"
  73. # Calculate average DWI value within the mask
  74. masked_data = dwi_data[mask_data > 0]
  75. non_zero_masked_data = masked_data[masked_data != 0] # Exclude zero values
  76. average_value = np.nanmean(non_zero_masked_data) # Calculate mean only on non-zero elements
  77. # Save the mask as NIfTI file if the flag is set
  78. if save_nifti_mask and pp > 0 and merged_time_point in [3] and q_type == "fa":
  79. output_nifti_file = os.path.join(output_path,
  80. f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
  81. save_mask_as_nifti(mask_data, output_nifti_file)
  82. # Append data to local CSV data
  83. csvdata.append([file_path, subject_id, time_point, merged_time_point, q_type, mask_name, pp, average_value, strokeFlag, AMBA_flag])
  84. return csvdata
  85. def process_files(input_path, output_path, save_nifti_mask):
  86. """Process DWI files and generate data using parallel processing."""
  87. session_mapping = {
  88. "ses-Baseline": 0, "ses-Baseline1": 0, "ses-pre": 0,
  89. "ses-P1": 3, "ses-P2": 3, "ses-P3": 3, "ses-P4": 3, "ses-P5": 3,
  90. "ses-P6": 7, "ses-P7": 7, "ses-P8": 7, "ses-P9": 7, "ses-P10": 7,
  91. "ses-P11": 14, "ses-P12": 14, "ses-P13": 14, "ses-P14": 14,
  92. "ses-P15": 14, "ses-P16": 14, "ses-P17": 14, "ses-P18": 14,
  93. "ses-P19": 21, "ses-D21": 21, "ses-P20": 21, "ses-P21": 21,
  94. "ses-P22": 21, "ses-P23": 21, "ses-P24": 21, "ses-P25": 21,
  95. "ses-P26": 28, "ses-P27": 28, "ses-P28": 28, "ses-P29": 28,
  96. "ses-P30": 28, "ses-P42": 42, "ses-P43": 42,
  97. "ses-P56": 56, "ses-P57": 56, "ses-P58": 56,
  98. "ses-P151": 28
  99. }
  100. # Get all relevant file paths
  101. file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
  102. # Prepare arguments for parallel processing
  103. arguments = [(file_path, session_mapping, output_path, save_nifti_mask) for file_path in file_paths]
  104. # Initialize a list to store results
  105. csv_data = []
  106. # Use ProcessPoolExecutor to parallelize the file path processing
  107. with ProcessPoolExecutor(max_workers=6) as executor:
  108. # Process files in parallel
  109. results = list(tqdm(executor.map(process_single_file, arguments), total=len(file_paths), desc="Processing files"))
  110. # Aggregate the results
  111. for result in results:
  112. csv_data.extend(result)
  113. # Create a DataFrame
  114. df = pd.DataFrame(csv_data,
  115. columns=["fullpath", "subjectID", "timePoint", "merged_timepoint", "Qtype", "mask_name",
  116. "dialation_amount", "Value", "Group", "is_it_AMBA"])
  117. return df
  118. def main():
  119. # Use argparse to get command-line inputs
  120. import argparse
  121. parser = argparse.ArgumentParser(description="Process DWI files and generate data.")
  122. parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files.")
  123. parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save results.")
  124. parser.add_argument("-n", "--nifti-mask", action="store_true", default=False, help="Set to save masks as NIfTI files.")
  125. args = parser.parse_args()
  126. # Create the output directory if it does not exist
  127. create_output_dir(args.output)
  128. # Process files
  129. df = process_files(args.input, args.output, args.nifti_mask)
  130. # Save the DataFrame as CSV
  131. csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
  132. df.to_csv(csv_file, index=False)
  133. print("CSV file created at:", csv_file)
  134. if __name__ == "__main__":
  135. main()