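"""Compute average DWI metric values within dilated tract masks.

For every *_flipped.nii.gz map found under the input directory, each
registered tract mask is intersected with the white-matter mask (and,
outside baseline, the stroke region is subtracted) before the mean
metric value is computed and written to a CSV file.
"""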
import argparse
import glob
import os

import nibabel as nib
import numpy as np
import pandas as pd
from concurrent.futures import ProcessPoolExecutor
from scipy.ndimage import binary_dilation
from tqdm import tqdm

def create_output_dir(output_dir):
    """Create the output directory if it does not exist."""
    os.makedirs(output_dir, exist_ok=True)

def save_mask_as_nifti(mask_data, output_file, affine):
    """Save a binary mask as a NIfTI file, preserving the source affine."""
    mask_data_int = mask_data.astype(np.uint8)
    mask_img = nib.Nifti1Image(mask_data_int, affine=affine)
    nib.save(mask_img, output_file)

def process_single_file(file_info):
    """Process a single DWI file and return its rows of results."""
    file_path, session_mapping, output_path, save_nifti_mask = file_info

    csvdata = []  # Local CSV data collector

    # Extract subject and session identifiers from the file path
    subject_id = file_path.split(os.sep)[-5]
    time_point = file_path.split(os.sep)[-4]

    # Map the extracted time_point directly using session_mapping
    merged_time_point = session_mapping.get(time_point, None)
    if merged_time_point is None:
        print(f"Warning: No mapping found for {time_point}. Skipping file: {file_path}")
        return []

    # Extract the q_type (e.g. "fa") from the filename
    q_type = os.path.basename(file_path).split("_flipped")[0]
    # Attempt to find the stroke mask path
    try:
        search_stroke = os.path.join(os.path.dirname(file_path), "*StrokeMask_scaled.nii")
        stroke_path = glob.glob(search_stroke)[0]
    except IndexError:
        stroke_path = None

    # Locate the registered tract masks and separate out the white-matter mask
    temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
    dwi_masks = glob.glob(os.path.join(temp_path, "*dwi_flipped.nii.gz"))
    white_matter_index = next((i for i, path in enumerate(dwi_masks) if "White_matter" in path), None)
    if white_matter_index is None:
        print(f"Warning: No white-matter mask found for {file_path}. Skipping file.")
        return []
    white_matter_path = dwi_masks.pop(white_matter_index)
    # Load the white-matter mask and the DWI map
    white_matter_img = nib.load(white_matter_path)
    white_matter_data = white_matter_img.get_fdata()
    dwi_img = nib.load(file_path)
    dwi_data = dwi_img.get_fdata()

    # Load the stroke mask once, rather than once per mask and dilation step
    if stroke_path:
        stroke_data = nib.load(stroke_path).get_fdata()
        strokeFlag = "Stroke"
    else:
        stroke_data = None
        strokeFlag = "Sham"

    # Loop through the tract masks and calculate average values per dilation
    pix_dilation = [0, 1, 2, 3, 4]
    for mask_path in dwi_masks:
        mask_name = (os.path.basename(mask_path)
                     .replace("registered", "")
                     .replace("flipped", "")
                     .replace(".nii.gz", "")
                     .replace("__dwi_", "")
                     .replace("_dwi_", ""))
        mask_img = nib.load(mask_path)
        mask_data_pre_dilation = mask_img.get_fdata()

        # Flag masks derived from the Allen Mouse Brain Atlas (AMBA)
        AMBA_flag = "AMBA" in mask_name

        # Calculate the average value within the mask region for each dilation
        for pp in pix_dilation:
            if pp != 0:
                mask_data0 = binary_dilation(mask_data_pre_dilation, iterations=pp)
            else:
                mask_data0 = mask_data_pre_dilation > 0

            # Restrict the mask to white matter
            mask_data = mask_data0 & (white_matter_data > 0)

            # Subtract the stroke region from the mask, except at baseline
            if stroke_data is not None and merged_time_point != 0:
                mask_data = mask_data & (stroke_data < 1)
            # Calculate the average DWI value within the mask, excluding zeros
            masked_data = dwi_data[mask_data > 0]
            non_zero_masked_data = masked_data[masked_data != 0]
            average_value = np.nanmean(non_zero_masked_data)

            # Save the mask as a NIfTI file if the flag is set
            if save_nifti_mask and pp > 0 and merged_time_point == 3 and q_type == "fa":
                output_nifti_file = os.path.join(
                    output_path,
                    f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
                save_mask_as_nifti(mask_data, output_nifti_file, mask_img.affine)

            # Append one row of results
            csvdata.append([file_path, subject_id, time_point, merged_time_point,
                            q_type, mask_name, pp, average_value, strokeFlag, AMBA_flag])

    return csvdata

def process_files(input_path, output_path, save_nifti_mask):
    """Process DWI files and generate data using parallel processing."""
    # Map each acquisition session onto a merged time point (0 = baseline)
    session_mapping = {
        "ses-Baseline": 0, "ses-Baseline1": 0, "ses-pre": 0,
        "ses-P1": 3, "ses-P2": 3, "ses-P3": 3, "ses-P4": 3, "ses-P5": 3,
        "ses-P6": 7, "ses-P7": 7, "ses-P8": 7, "ses-P9": 7, "ses-P10": 7,
        "ses-P11": 14, "ses-P12": 14, "ses-P13": 14, "ses-P14": 14,
        "ses-P15": 14, "ses-P16": 14, "ses-P17": 14, "ses-P18": 14,
        "ses-P19": 21, "ses-D21": 21, "ses-P20": 21, "ses-P21": 21,
        "ses-P22": 21, "ses-P23": 21, "ses-P24": 21, "ses-P25": 21,
        "ses-P26": 28, "ses-P27": 28, "ses-P28": 28, "ses-P29": 28,
        "ses-P30": 28, "ses-P42": 42, "ses-P43": 42,
        "ses-P56": 56, "ses-P57": 56, "ses-P58": 56,
        "ses-P151": 28
    }
    # Gather all relevant file paths
    file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)

    # Prepare arguments for parallel processing
    arguments = [(file_path, session_mapping, output_path, save_nifti_mask) for file_path in file_paths]

    # Process files in parallel and aggregate the per-file rows
    csv_data = []
    with ProcessPoolExecutor(max_workers=6) as executor:
        results = list(tqdm(executor.map(process_single_file, arguments),
                            total=len(file_paths), desc="Processing files"))
    for result in results:
        csv_data.extend(result)

    df = pd.DataFrame(csv_data,
                      columns=["fullpath", "subjectID", "timePoint", "merged_timepoint",
                               "Qtype", "mask_name", "dilation_amount", "Value",
                               "Group", "is_it_AMBA"])
    return df

def main():
    # Parse command-line inputs
    parser = argparse.ArgumentParser(description="Process DWI files and generate data.")
    parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files.")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save results.")
    parser.add_argument("-n", "--nifti-mask", action="store_true", help="Save masks as NIfTI files.")
    args = parser.parse_args()

    # Create the output directory if it does not exist
    create_output_dir(args.output)

    # Process files and save the results as CSV
    df = process_files(args.input, args.output, args.nifti_mask)
    csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
    df.to_csv(csv_file, index=False)
    print("CSV file created at:", csv_file)


if __name__ == "__main__":
    main()
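
# Example invocation (hypothetical script name and paths; adjust to your layout):
#   python dwi_tract_averages.py -i /data/proc_data -o /data/dwi_results -n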