123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
- import os
- import glob
- import pandas as pd
- import nibabel as nib
- import numpy as np
- import matplotlib.pyplot as plt
- from tqdm import tqdm
- import argparse
- from scipy.ndimage import binary_dilation
def create_output_dir(output_dir):
    """Create *output_dir* (including missing parents) unless it already exists.

    ``os.makedirs(..., exist_ok=True)`` does the existence check and the
    creation in one call. This avoids the race between a separate
    ``os.path.exists`` test and ``os.makedirs`` when another process creates
    the directory in between.

    Raises:
        FileExistsError: if *output_dir* exists but is not a directory. This
            fails early instead of later, when outputs are written into it.
    """
    os.makedirs(output_dir, exist_ok=True)
def save_mask_as_nifti(mask_data, output_file, affine=None):
    """Save a mask array as a NIfTI file with ``uint8`` voxels.

    Args:
        mask_data: Mask array (or array-like). It is cast with
            ``astype(np.uint8)``, so a boolean mask is stored as 0s and 1s.
        output_file: Destination path, e.g. ``something.nii.gz``.
        affine: Optional 4x4 voxel-to-world matrix. Normally this is the
            ``affine`` of the image the mask was derived from. With the
            default ``None``, nibabel writes no source orientation, so the
            mask may not overlay correctly on that image in viewers.
    """
    # Store as uint8: compact on disk, and booleans become 0/1.
    mask_data_int = np.asarray(mask_data).astype(np.uint8)
    mask_img = nib.Nifti1Image(mask_data_int, affine=affine)
    nib.save(mask_img, output_file)
def _find_stroke_mask(directory):
    """Return the first ``*StrokeMask_scaled.nii`` file in *directory* (sorted), or None."""
    matches = sorted(glob.glob(os.path.join(directory, "*StrokeMask_scaled.nii")))
    return matches[0] if matches else None


def _mean_of_nonzero(values):
    """Return the NaN-ignoring mean of the non-zero entries of *values*.

    Returns NaN directly when there are no non-zero entries. This avoids
    numpy's "Mean of empty slice" RuntimeWarning.
    """
    nonzero = values[values != 0]
    if nonzero.size == 0:
        return np.nan
    return np.nanmean(nonzero)


def _fullest_slice(mask):
    """Return the axis-2 index of the slice with the most non-zero voxels.

    Ties resolve to the lowest index. Assumes a 3-D mask.
    """
    return int(np.argmax(np.count_nonzero(mask, axis=(0, 1))))


def _save_overlay_figure(background, overlay, title, output_file):
    """Save a grey-scale slice with a semi-transparent mask overlay as an image file."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=50)
    try:
        ax.imshow(background, cmap='gray')
        ax.imshow(overlay, cmap='jet', alpha=0.5)
        ax.set_title(title)
        ax.axis('off')
        fig.savefig(output_file, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if saving fails, so memory is not leaked.
        plt.close(fig)


def process_files(input_path, output_path, create_figures, save_nifti_mask,
                  dilation_steps=(0, 1, 2, 3, 4)):
    """Compute mean DWI-map values inside registered tract masks.

    The function looks for every ``*_flipped.nii.gz`` map under
    ``<input_path>/**/dwi/DSI_studio/``. For each map, it processes every
    ``*dwi_flipped.nii.gz`` mask in the sibling
    ``RegisteredTractMasks_adjusted`` folder (except the White_matter mask):

    * dilates the mask by each entry of *dilation_steps*;
    * intersects it with the white-matter mask;
    * removes the stroke lesion when a stroke mask exists and the session is
      not baseline.

    It then records the mean of the non-zero map values inside the result.

    Args:
        input_path: Root directory searched recursively. Subject and session
            are read from fixed path positions, so the layout is assumed to be
            ``.../<subject>/<session>/dwi/DSI_studio/<map>``.
        output_path: Directory receiving PNG figures and NIfTI masks.
        create_figures: If true, save one overlay PNG per mask and dilation.
        save_nifti_mask: If true, save the dilated masks (dilation > 0) of the
            ``fa`` map at merged time point 3 as NIfTI files.
        dilation_steps: Numbers of binary-dilation iterations to evaluate.
            0 means the undilated mask.

    Returns:
        pandas.DataFrame with one row per (map, mask, dilation) and columns
        ``fullpath, subjectID, timePoint, merged_timepoint, Qtype, mask_name,
        dialation_amount, Value, Group, is_it_AMBA``. ``Group`` is "Stroke"
        when a stroke mask file exists next to the map, otherwise "Sham".
    """
    # Map raw session folder names onto merged time points (0 = baseline).
    session_mapping = {
        "ses-Baseline": 0, "ses-Baseline1": 0, "ses-pre": 0,
        "ses-P1": 3, "ses-P2": 3, "ses-P3": 3, "ses-P": 3,
        "ses-P4": 3, "ses-P5": 3, "ses-P6": 7, "ses-P7": 7,
        "ses-P8": 7, "ses-P9": 7, "ses-P10": 7,
        "ses-P11": 14, "ses-P12": 14, "ses-P13": 14, "ses-P14": 14,
        "ses-P15": 14, "ses-P16": 14, "ses-P17": 14, "ses-P18": 14,
        "ses-P19": 21, "ses-D21": 21, "ses-P20": 21, "ses-P21": 21,
        "ses-P22": 21, "ses-P23": 21, "ses-P24": 21, "ses-P25": 21,
        "ses-P26": 28, "ses-P27b": 28, "ses-P151": 28, "ses-P152": 28,
        "ses-P27": 28, "ses-P28": 28, "ses-P29": 28, "ses-P30": 28,
        "ses-P42": 42, "ses-P43": 42,
        "ses-P56": 56, "ses-P222": 56, "ses-P57": 56, "ses-P58": 56,
    }
    csvdata = []

    # Sorted so rows (and therefore the output CSV) come out in a reproducible order.
    file_paths = sorted(glob.glob(
        os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"),
        recursive=True))
    for file_path in tqdm(file_paths, desc="Processing files"):
        path_parts = file_path.split(os.sep)
        subject_id = path_parts[-5]
        time_point = path_parts[-4]
        merged_time_point = session_mapping.get(time_point)
        if merged_time_point is None:
            print(f"Warning: No mapping found for {time_point}. Skipping file: {file_path}")
            continue

        q_type = os.path.basename(file_path).split("_flipped")[0]
        file_dir = os.path.dirname(file_path)

        mask_dir = os.path.join(file_dir, "RegisteredTractMasks_adjusted")
        dwi_masks = sorted(glob.glob(os.path.join(mask_dir, "*dwi_flipped.nii.gz")))
        white_matter_path = next((p for p in dwi_masks if "White_matter" in p), None)
        if white_matter_path is None:
            # Every tract mask is restricted to white matter, so this map cannot be
            # evaluated. Skip it instead of crashing the whole run.
            print(f"Warning: No White_matter mask found in {mask_dir}. Skipping file: {file_path}")
            continue
        dwi_masks.remove(white_matter_path)

        dwi_data = nib.load(file_path).get_fdata()

        # Build the voxel set every tract mask is restricted to once per map,
        # rather than reloading the stroke mask for every mask and dilation.
        allowed = nib.load(white_matter_path).get_fdata() > 0
        stroke_path = _find_stroke_mask(file_dir)
        group = "Stroke" if stroke_path else "Sham"
        if stroke_path and merged_time_point != 0:
            # Exclude the lesion at every time point except baseline.
            allowed = allowed & (nib.load(stroke_path).get_fdata() < 1)

        for mask_path in tqdm(dwi_masks, desc="Processing masks", leave=False):
            mask_name = (os.path.basename(mask_path)
                         .replace("registered", "")
                         .replace("flipped", "")
                         .replace(".nii.gz", "")
                         .replace("__dwi_", "")
                         .replace("_dwi_", ""))
            is_amba = "AMBA" in mask_name
            base_mask = nib.load(mask_path).get_fdata() > 0
            # Choose the display slice per mask, so each figure shows its own mask.
            selected_slice = _fullest_slice(base_mask) if create_figures else None

            for pp in dilation_steps:
                # binary_dilation treats iterations < 1 as "repeat until stable",
                # so 0 must mean "no dilation" explicitly.
                dilated = binary_dilation(base_mask, iterations=pp) if pp > 0 else base_mask
                mask_data = dilated & allowed
                average_value = _mean_of_nonzero(dwi_data[mask_data])

                if create_figures:
                    label = (f"{subject_id}_tp_{merged_time_point}_{q_type}_"
                             f"{average_value:.2f}_{mask_name}_{pp}")
                    _save_overlay_figure(dwi_data[..., selected_slice],
                                         mask_data[..., selected_slice],
                                         label,
                                         os.path.join(output_path, f"{label}.png"))

                if save_nifti_mask and pp > 0 and merged_time_point == 3 and q_type == "fa":
                    output_nifti_file = os.path.join(
                        output_path,
                        f"{subject_id}_tp_{merged_time_point}_{q_type}_{mask_name}_{pp}.nii.gz")
                    # NOTE(review): the mask is written without the DWI image's affine,
                    # so it may not line up with the source map in viewers.
                    save_mask_as_nifti(mask_data, output_nifti_file)

                csvdata.append([file_path, subject_id, time_point, merged_time_point, q_type,
                                mask_name, pp, average_value, group, is_amba])

    return pd.DataFrame(csvdata,
                        columns=["fullpath", "subjectID", "timePoint", "merged_timepoint",
                                 "Qtype", "mask_name", "dialation_amount", "Value", "Group",
                                 "is_it_AMBA"])
def main(argv=None):
    """Command-line entry point: process DWI maps and write the results CSV.

    Parses the arguments, creates the output directory, runs
    ``process_files``, writes
    ``Quantitative_results_from_dwi_processing.csv`` into the output
    directory, and prints the resulting table.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``.
            The default ``None`` keeps normal command-line behaviour. Passing a
            list lets the pipeline be driven from other Python code or tests.
    """
    parser = argparse.ArgumentParser(description="Process DWI files and generate figures.")
    parser.add_argument(
        "-i", "--input", type=str, required=True,
        help="Input directory containing DWI files, e.g. proc_data. The flipped fa, md, etc. "
             "files must be available, and the masks must be registered before this step.")
    parser.add_argument(
        "-o", "--output", type=str, required=True,
        help="Output directory to save output results (csv and figures).")
    parser.add_argument(
        "-f", "--figures", action="store_true", default=False,
        help="set if you want figures to be saved as pngs")
    parser.add_argument(
        "-n", "--nifti-mask", action="store_true", default=False,
        help="set if you want masks to be saved as NIfTI files.")
    args = parser.parse_args(argv)

    create_output_dir(args.output)
    df = process_files(args.input, args.output, args.figures, args.nifti_mask)

    csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
    df.to_csv(csv_file, index=False)
    print("CSV file created at:", csv_file)
    print(df)


if __name__ == "__main__":
    main()
|