@@ -1,197 +1,217 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Mon Nov 11 09:35:42 2024
-
-@author: arefks
-"""
-
-import os
-import glob
-import pandas as pd
-import nibabel as nib
-import numpy as np
-from tqdm import tqdm
-from concurrent.futures import ProcessPoolExecutor
-from scipy.ndimage import binary_dilation
-
-def create_output_dir(output_dir):
-    """Create the output directory if it does not exist."""
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
-
-def save_mask_as_nifti(mask_data, output_file):
-    """Save mask data as NIfTI file."""
-    mask_data_int = mask_data.astype(np.uint8)
-    mask_img = nib.Nifti1Image(mask_data_int, affine=None)
-    nib.save(mask_img, output_file)
-
-def process_single_file(file_info):
-    """Process a single file path."""
-    file_path, session_mapping, output_path, save_nifti_mask = file_info
-
-    csvdata = [] # Local CSV data collector
-
-    # Extract information from the file path
-    subject_id = file_path.split(os.sep)[-5]
-    time_point = file_path.split(os.sep)[-4]
-
-    # Map the extracted time_point directly using session_mapping
-    merged_time_point = session_mapping.get(time_point, None)
-
-    if merged_time_point is None:
-        print(f"Warning: No mapping found for {time_point}. Skipping file: {file_path}")
-        return []
-
-    # Extract the q_type from the filename
-    q_type = os.path.basename(file_path).split("_flipped")[0]
-
-    # Attempt to find the stroke mask path
-    try:
-        search_stroke = os.path.join(os.path.dirname(file_path), "*StrokeMask_scaled.nii")
-        stroke_path = glob.glob(search_stroke)[0]
-    except IndexError:
-        stroke_path = None
-
-    # Create the temp path to search for dwi_masks
-    temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
-    dwi_masks = glob.glob(os.path.join(temp_path, "CST*dwi_flipped.nii.gz")) + glob.glob(os.path.join(temp_path, "White_matter*dwi_flipped.nii.gz"))
-    white_matter_index = next((i for i, path in enumerate(dwi_masks) if "White_matter" in path), None)
-
-    if white_matter_index is None:
-        print(f"Warning: White matter mask not found in {temp_path}. Skipping file: {file_path}")
-        return []
-
-    white_matter_path = dwi_masks.pop(white_matter_index)
-
-    # Load white matter mask
-    white_matter_img = nib.load(white_matter_path)
-    white_matter_data = white_matter_img.get_fdata()
-
-    # Load DWI data using nibabel
-    dwi_img = nib.load(file_path)
-    dwi_data = dwi_img.get_fdata()
-
-    # Loop through masks and calculate average values
-    pix_dialation = [0, 1, 2, 3, 4]
-    for mask_path in dwi_masks:
-        mask_name = os.path.basename(mask_path).replace("registered", "").replace("flipped", "").replace(".nii.gz", "").replace("__dwi_", "").replace("_dwi_", "")
-        mask_img = nib.load(mask_path)
-        mask_data_pre_dilation = mask_img.get_fdata()
-
-        # Determine if it's an AMBA mask
-        AMBA_flag = "AMBA" in mask_name
-
-        # Calculate the average value within the mask region for each dilation
-        for pp in pix_dialation:
-            if pp != 0:
-                mask_data0 = binary_dilation(mask_data_pre_dilation, iterations=pp)
-            else:
-                mask_data0 = mask_data_pre_dilation > 0
-
-            mask_data = mask_data0 #& (white_matter_data > 0)
-            if stroke_path:
-                stroke_image = nib.load(stroke_path)
-                stroke_data = stroke_image.get_fdata()
-
-                # Subtracting stroke region from mask except in baseline
-                if merged_time_point != 0:
-                    mask_data = mask_data & (stroke_data < 1)
-
-                    # Create an empty array to hold the updated mask data
-                    new_mask_data = np.zeros_like(mask_data)
-
-                    # Iterate through each slice along the third axis and only retain slices where there is data in stroke_data
-                    for z in range(stroke_data.shape[2]):
-                        if np.any(stroke_data[:, :, z]):
-                            new_mask_data[:, :, z] = mask_data[:, :, z]
-
-                    # Replace mask_data with the new filtered mask
-                    mask_data = new_mask_data
-
-                strokeFlag = "Stroke"
-            else:
-                strokeFlag = "Sham"
-
-
-            # Calculate average DWI value within the mask
-            masked_data = dwi_data[mask_data > 0]
-            non_zero_masked_data = masked_data[masked_data != 0] # Exclude zero values
-            average_value = np.nanmean(non_zero_masked_data) # Calculate mean only on non-zero elements
-
-            # Calculate the number of voxels in the mask
-            num_voxels = non_zero_masked_data.size
-
-            # Save the mask as NIfTI file if the flag is set
-            if save_nifti_mask and pp > 0 and merged_time_point in [3] and q_type == "fa":
-                output_nifti_file = os.path.join(output_path,
-                                                 f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
-                save_mask_as_nifti(mask_data, output_nifti_file)
-
-            # Append data to local CSV data
-            csvdata.append([file_path, subject_id, time_point, merged_time_point, q_type, mask_name, pp, average_value, strokeFlag, AMBA_flag, num_voxels])
-
-    return csvdata
-
-def process_files(input_path, output_path, save_nifti_mask):
-    """Process DWI files and generate data using parallel processing."""
-    session_mapping = {
-        "ses-Baseline": 0, "ses-Baseline1": 0, "ses-pre": 0,
-        "ses-P1": 3, "ses-P2": 3, "ses-P3": 3, "ses-P4": 3, "ses-P5": 3,
-        "ses-P6": 7, "ses-P7": 7, "ses-P8": 7, "ses-P9": 7, "ses-P10": 7,
-        "ses-P11": 14, "ses-P12": 14, "ses-P13": 14, "ses-P14": 14,
-        "ses-P15": 14, "ses-P16": 14, "ses-P17": 14, "ses-P18": 14,
-        "ses-P19": 21, "ses-D21": 21, "ses-P20": 21, "ses-P21": 21,
-        "ses-P22": 21, "ses-P23": 21, "ses-P24": 21, "ses-P25": 21,
-        "ses-P26": 28, "ses-P27": 28, "ses-P28": 28, "ses-P29": 28,
-        "ses-P30": 28, "ses-P42": 42, "ses-P43": 42,
-        "ses-P56": 56, "ses-P57": 56, "ses-P58": 56,
-        "ses-P151": 28
-    }
-
-    # Get all relevant file paths
-    file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
-
-    # Prepare arguments for parallel processing
-    arguments = [(file_path, session_mapping, output_path, save_nifti_mask) for file_path in file_paths]
-
-    # Initialize a list to store results
-    csv_data = []
-
-    # Use ProcessPoolExecutor to parallelize the file path processing
-    with ProcessPoolExecutor(max_workers=6) as executor:
-        # Process files in parallel
-        results = list(tqdm(executor.map(process_single_file, arguments), total=len(file_paths), desc="Processing files"))
-
-    # Aggregate the results
-    for result in results:
-        csv_data.extend(result)
-
-    # Create a DataFrame
-    df = pd.DataFrame(csv_data,
-                      columns=["fullpath", "subjectID", "timePoint", "merged_timepoint", "Qtype", "mask_name",
-                               "dialation_amount", "Value", "Group", "is_it_AMBA", "num_voxels"])
-
-    return df
-
-def main():
-    # Use argparse to get command-line inputs
-    import argparse
-    parser = argparse.ArgumentParser(description="Process DWI files and generate data.")
-    parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files.")
-    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save results.")
-    parser.add_argument("-n", "--nifti-mask", action="store_true", default=False, help="Set to save masks as NIfTI files.")
-    args = parser.parse_args()
-
-    # Create the output directory if it does not exist
-    create_output_dir(args.output)
-
-    # Process files
-    df = process_files(args.input, args.output, args.nifti_mask)
-
-    # Save the DataFrame as CSV
-    csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
-    df.to_csv(csv_file, index=False)
-    print("CSV file created at:", csv_file)
-
-if __name__ == "__main__":
-    main()
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Nov 11 09:35:42 2024
+
+@author: arefks
+"""
+
+import os
+import glob
+import argparse
+import pandas as pd
+import nibabel as nib
+import numpy as np
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor
+from scipy.ndimage import binary_dilation
+
+def create_output_dir(output_dir):
+    """Create the output directory if it does not exist."""
+    os.makedirs(output_dir, exist_ok=True)
+
+def save_mask_as_nifti(mask_data, output_file, affine=None):
+    """Save mask data as a NIfTI file, preserving the given affine."""
+    mask_data_int = mask_data.astype(np.uint8)
+    # Passing the source affine keeps the saved mask aligned with the DWI volume
+    mask_img = nib.Nifti1Image(mask_data_int, affine=affine)
+    nib.save(mask_img, output_file)
+
+def process_single_file(file_info):
+    """Process a single file path."""
+    file_path, session_mapping, output_path, save_nifti_mask = file_info
+
+    csvdata = []  # Local CSV data collector
+
+    # Extract information from the file path
+    subject_id = file_path.split(os.sep)[-5]
+    time_point = file_path.split(os.sep)[-4]
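+    # Expected layout: .../<subject>/<session>/dwi/DSI_studio/<file>, so the
+    # path components at -5 and -4 are the subject ID and the session name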
+
+    # Map the extracted time_point directly using session_mapping
+    merged_time_point = session_mapping.get(time_point, None)
+
+    if merged_time_point is None:
+        print(f"Warning: No mapping found for {time_point}. Skipping file: {file_path}")
+        return []
+
+    # Extract the q_type from the filename
+    q_type = os.path.basename(file_path).split("_flipped")[0]
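+    # e.g. "fa_flipped.nii.gz" gives q_type "fa"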
+
+    # Attempt to find the stroke mask path
+    try:
+        search_stroke = os.path.join(os.path.dirname(file_path), "*StrokeMask_scaled.nii")
+        stroke_path = glob.glob(search_stroke)[0]
+    except IndexError:
+        stroke_path = None
+
+    # Create the temp path to search for dwi_masks
+    temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
+    dwi_masks = glob.glob(os.path.join(temp_path, "CST*dwi_flipped.nii.gz")) + glob.glob(os.path.join(temp_path, "White_matter*dwi_flipped.nii.gz"))
+    white_matter_index = next((i for i, path in enumerate(dwi_masks) if "White_matter" in path), None)
+
+    if white_matter_index is None:
+        print(f"Warning: White matter mask not found in {temp_path}. Skipping file: {file_path}")
+        return []
+
+    white_matter_path = dwi_masks.pop(white_matter_index)
+
+    # Load white matter mask
+    white_matter_img = nib.load(white_matter_path)
+    white_matter_data = white_matter_img.get_fdata()
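+    # Note: white_matter_data is currently unused because the
+    # '& (white_matter_data > 0)' intersection below is commented out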
+
+    # Load DWI data using nibabel
+    dwi_img = nib.load(file_path)
+    dwi_data = dwi_img.get_fdata()
+
+    # Loop through masks and calculate average values
+    pix_dilation = [0, 1, 2, 3, 4]
+    for mask_path in dwi_masks:
+        mask_name = os.path.basename(mask_path).replace("registered", "").replace("flipped", "").replace(".nii.gz", "").replace("__dwi_", "").replace("_dwi_", "")
+        mask_img = nib.load(mask_path)
+        mask_data_pre_dilation = mask_img.get_fdata()
+
+        # Determine if it's an AMBA mask
+        AMBA_flag = "AMBA" in mask_name
+
+        # Calculate the average value within the mask region for each dilation
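+        # (binary_dilation treats any non-zero voxel as foreground and returns a boolean array)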
+        for pp in pix_dilation:
+            if pp != 0:
+                mask_data0 = binary_dilation(mask_data_pre_dilation, iterations=pp)
+            else:
+                mask_data0 = mask_data_pre_dilation > 0
+
+            mask_data = mask_data0  # & (white_matter_data > 0)
+            if stroke_path:
+                stroke_image = nib.load(stroke_path)
+                stroke_data = stroke_image.get_fdata()
+
+                # Subtract the stroke region from the mask, except at baseline
+                if merged_time_point != 0:
+                    mask_data = mask_data & (stroke_data < 1)
+
+                strokeFlag = "Stroke"
+            else:
+                strokeFlag = "Sham"
+
+            # Find the first and last non-zero slices in the mask
+            first_non_zero_slice = None
+            last_non_zero_slice = None
+            for z in range(mask_data.shape[2]):
+                if np.any(mask_data[:, :, z]):
+                    if first_non_zero_slice is None:
+                        first_non_zero_slice = z
+                    last_non_zero_slice = z
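+            # (equivalently: np.flatnonzero(mask_data.any(axis=(0, 1))) lists every slice with mask voxels)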
+
+            if first_non_zero_slice is not None and last_non_zero_slice is not None:
+                # Split the mask along the z-axis into partitions A and B,
+                # spanning the first to the last non-zero slice
+                partition_A_mask = np.zeros_like(mask_data, dtype=bool)
+                partition_B_mask = np.zeros_like(mask_data, dtype=bool)
+
+                GAP = (last_non_zero_slice - first_non_zero_slice)
+                first_part = int(np.ceil(GAP * 0.5))
+                # Split at the midpoint: partition A covers the first half of the
+                # non-zero extent, partition B the rest up to the last non-zero slice
+                partition_boundary = first_non_zero_slice + first_part
+                partition_A_mask[:, :, first_non_zero_slice:partition_boundary] = mask_data[:, :, first_non_zero_slice:partition_boundary]
+                partition_B_mask[:, :, partition_boundary:last_non_zero_slice + 1] = mask_data[:, :, partition_boundary:last_non_zero_slice + 1]
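+                # Note: for a single-slice mask GAP is 0, so partition A is empty
+                # and the empty-partition guard below records NaN for it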
+
+                # Calculate the average DWI value within the mask for each partition (A and B)
+                for partition, partition_mask in zip(['A', 'B'], [partition_A_mask, partition_B_mask]):
+                    masked_data = dwi_data[partition_mask > 0]
+                    non_zero_masked_data = masked_data[masked_data != 0]  # Exclude zero values
+                    # Mean over non-zero voxels; guard against empty partitions
+                    average_value = np.nanmean(non_zero_masked_data) if non_zero_masked_data.size > 0 else np.nan
+
+                    # Count the non-zero voxels in the partition
+                    num_voxels = non_zero_masked_data.size
+
+                    # Determine the starting and ending slice (inclusive) of each partition
+                    if partition == 'A':
+                        partition_start_slice = first_non_zero_slice
+                        partition_end_slice = min(partition_boundary - 1, last_non_zero_slice)
+                    else:
+                        partition_start_slice = partition_boundary
+                        partition_end_slice = last_non_zero_slice
+
+                    # Save the full (un-partitioned) mask as a NIfTI file if the flag is set;
+                    # note this writes the same file on both partition iterations
+                    if save_nifti_mask and pp > 0 and merged_time_point == 3 and q_type == "fa":
+                        output_nifti_file = os.path.join(output_path,
+                                                         f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
+                        save_mask_as_nifti(mask_data, output_nifti_file, affine=mask_img.affine)
+
+                    # Append data to the local CSV collector
+                    csvdata.append([file_path, subject_id, time_point, merged_time_point, q_type, mask_name, pp, average_value, strokeFlag, AMBA_flag, num_voxels, partition, partition_start_slice, partition_end_slice])
+
+    return csvdata
+
+def process_files(input_path, output_path, save_nifti_mask):
+    """Process DWI files and generate data using parallel processing."""
+    session_mapping = {
+        "ses-Baseline": 0, "ses-Baseline1": 0, "ses-pre": 0,
+        "ses-P1": 3, "ses-P2": 3, "ses-P3": 3, "ses-P4": 3, "ses-P5": 3,
+        "ses-P6": 7, "ses-P7": 7, "ses-P8": 7, "ses-P9": 7, "ses-P10": 7,
+        "ses-P11": 14, "ses-P12": 14, "ses-P13": 14, "ses-P14": 14,
+        "ses-P15": 14, "ses-P16": 14, "ses-P17": 14, "ses-P18": 14,
+        "ses-P19": 21, "ses-D21": 21, "ses-P20": 21, "ses-P21": 21,
+        "ses-P22": 21, "ses-P23": 21, "ses-P24": 21, "ses-P25": 21,
+        "ses-P26": 28, "ses-P27": 28, "ses-P28": 28, "ses-P29": 28,
+        "ses-P30": 28, "ses-P42": 42, "ses-P43": 42,
+        "ses-P56": 56, "ses-P57": 56, "ses-P58": 56,
+        "ses-P151": 28
+    }
+
+    # Get all relevant file paths
+    file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
+
+    # Prepare arguments for parallel processing
+    arguments = [(file_path, session_mapping, output_path, save_nifti_mask) for file_path in file_paths]
+
+    # Initialize a list to store results
+    csv_data = []
+
+    # Use ProcessPoolExecutor to parallelize the file path processing
+    with ProcessPoolExecutor(max_workers=6) as executor:
+        # Process files in parallel
+        results = list(tqdm(executor.map(process_single_file, arguments), total=len(file_paths), desc="Processing files"))
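+        # executor.map preserves input order, so results line up with file_paths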
+
+    # Aggregate the results
+    for result in results:
+        csv_data.extend(result)
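+    # Column order below must match the fields appended in process_single_file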
+
+    # Create a DataFrame
+    df = pd.DataFrame(csv_data,
+                      columns=["fullpath", "subjectID", "timePoint", "merged_timepoint", "Qtype", "mask_name",
+                               "dilation_amount", "Value", "Group", "is_it_AMBA", "num_voxels", "partition", "start_slice", "end_slice"])
+
+    return df
+
+def main():
+    # Parse command-line inputs (argparse is imported at the top)
+    parser = argparse.ArgumentParser(description="Process DWI files and generate data.")
+    parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files.")
+    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save results.")
+    parser.add_argument("-n", "--nifti-mask", action="store_true", default=False, help="Set to save masks as NIfTI files.")
+    args = parser.parse_args()
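+    # Example invocation (hypothetical script name and paths):
+    #   python process_dwi.py -i /data/proc -o /data/results -n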
+
+    # Create the output directory if it does not exist
+    create_output_dir(args.output)
+
+    # Process files
+    df = process_files(args.input, args.output, args.nifti_mask)
+
+    # Save the DataFrame as CSV
+    csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
+    df.to_csv(csv_file, index=False)
+    print("CSV file created at:", csv_file)
+
+if __name__ == "__main__":
+    main()