"""Quantify DWI-derived maps inside registered tract masks.

For every ``*_flipped.nii.gz`` map found below the input directory, the mean
of the non-zero voxels inside each tract mask (restricted to the white-matter
mask, at several dilation levels) is written to a CSV file. Overlay figures
and the final masks can optionally be written as well.
"""

import argparse
import glob
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from scipy.ndimage import binary_dilation
from tqdm import tqdm

# Maps the day number parsed from a session folder name to the merged
# time point under which the measurement is reported.
SESSION_MAPPING = {
    0: 0,
    **dict.fromkeys(range(1, 6), 3),
    **dict.fromkeys(range(6, 11), 7),
    **dict.fromkeys(range(11, 19), 14),
    **dict.fromkeys(range(19, 26), 21),
    **dict.fromkeys(range(26, 31), 28),
    42: 42,
    43: 42,
    56: 56,
    57: 56,
}

# Number of binary-dilation iterations applied to every tract mask
# (0 means the mask is used undilated).
DILATION_ITERATIONS = (0, 1, 2, 3, 4)

# Merged time points at which the stroke region is removed from the masks.
STROKE_EXCLUSION_TIMEPOINTS = (3, 28)

CSV_COLUMNS = [
    "fullpath",
    "subjectID",
    "timePoint",
    "int_timepoint",
    "merged_timepoint",
    "Qtype",
    "mask_name",
    "dialation_amount",
    "Value",
    "Group",
]


def create_output_dir(output_dir):
    """Create the output directory (including parents) if it does not exist."""
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)


def save_mask_as_nifti(mask_data, output_file, affine=None):
    """Save mask data as a NIfTI file.

    Args:
        mask_data: Array-like mask; it is cast to ``uint8`` (0s and 1s for a
            boolean mask) before saving.
        output_file: Destination path handed to ``nib.save``.
        affine: Optional affine stored with the image. The default ``None``
            keeps the previous behaviour of passing no affine to nibabel;
            pass the affine of the source image to keep the mask aligned
            with it.
    """
    mask_img = nib.Nifti1Image(mask_data.astype(np.uint8), affine=affine)
    nib.save(mask_img, output_file)


def _parse_file_metadata(file_path):
    """Derive subject, session and map type from the location of *file_path*.

    Returns:
        Tuple ``(subject_id, time_point, int_time_point, merged_time_point,
        q_type)``.

    Raises:
        ValueError: If the path is too shallow, the session folder name does
            not contain a parseable number, or the number has no entry in
            ``SESSION_MAPPING``.
    """
    parts = file_path.split(os.sep)
    if len(parts) < 5:
        raise ValueError("path is too shallow to contain subject/session folders")
    subject_id = parts[-5]
    time_point = parts[-4]

    if time_point == "ses-Baseline":
        int_time_point = 0
    else:
        try:
            int_time_point = int(time_point.split("-P")[1])
        except (IndexError, ValueError):
            raise ValueError(
                f"cannot parse a time point from session folder {time_point!r}"
            ) from None

    try:
        merged_time_point = SESSION_MAPPING[int_time_point]
    except KeyError:
        raise ValueError(
            f"time point {int_time_point} has no entry in the session mapping"
        ) from None

    q_type = os.path.basename(file_path).split("_flipped")[0]
    return subject_id, time_point, int_time_point, merged_time_point, q_type


def _largest_slice_index(mask_volume):
    """Return the index (third axis) of the slice with the most non-zero voxels.

    Ties resolve to the lowest index. Returns ``None`` when the volume has no
    slices along the third axis.
    """
    n_slices = mask_volume.shape[2]
    if n_slices == 0:
        return None
    counts = [np.count_nonzero(mask_volume[..., i]) for i in range(n_slices)]
    return int(np.argmax(counts))


def _mean_of_nonzero(values):
    """Return the NaN-ignoring mean of the non-zero entries of *values*.

    Returns NaN (without emitting a RuntimeWarning) when nothing is left
    after the zeros are dropped.
    """
    non_zero = values[values != 0]
    if non_zero.size == 0:
        return np.nan
    return np.nanmean(non_zero)


def _save_overlay_figure(dwi_data, mask_data, slice_index, title, output_file):
    """Save one slice of *dwi_data* with *mask_data* overlaid as a PNG."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=50)
    try:
        ax.imshow(dwi_data[..., slice_index], cmap='gray')
        ax.imshow(mask_data[..., slice_index], cmap='jet', alpha=0.5)
        ax.set_title(title)
        ax.axis('off')
        fig.savefig(output_file, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def process_files(input_path, output_path, create_figures, save_nifti_mask):
    """Process DWI files and generate figures.

    Searches ``<input_path>/**/dwi/DSI_studio/*_flipped.nii.gz``. For each
    file, every ``*dwi_flipped.nii.gz`` mask in the sibling
    ``RegisteredTractMasks_adjusted`` folder (except the white-matter mask)
    is thresholded at ``> 0``, dilated by each amount in
    ``DILATION_ITERATIONS``, intersected with the white-matter mask and, at
    the merged time points in ``STROKE_EXCLUSION_TIMEPOINTS``, stripped of
    the stroke region when a stroke mask is present. The mean of the non-zero
    map values inside the resulting mask is recorded.

    Files whose path cannot be parsed, or that have no white-matter mask,
    are skipped with a message instead of aborting the whole run.

    Args:
        input_path: Root directory that is searched recursively.
        output_path: Directory that receives figures and NIfTI masks.
        create_figures: If true, save a PNG overlay per mask and dilation.
        save_nifti_mask: If true, save the final mask as NIfTI for dilated
            masks of "fa" maps at merged time points 3 and 28.

    Returns:
        pandas.DataFrame with one row per (file, mask, dilation amount) and
        the columns listed in ``CSV_COLUMNS``.
    """
    csvdata = []

    search_pattern = os.path.join(
        input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"
    )
    file_paths = glob.glob(search_pattern, recursive=True)

    for file_path in tqdm(file_paths, desc="Processing files"):
        try:
            (subject_id, time_point, int_time_point,
             merged_time_point, q_type) = _parse_file_metadata(file_path)
        except ValueError as err:
            tqdm.write(f"Skipping {file_path}: {err}")
            continue

        file_dir = os.path.dirname(file_path)

        # A file is labelled "Stroke" when a stroke mask sits next to it.
        stroke_candidates = glob.glob(
            os.path.join(file_dir, "*StrokeMask_scaled.nii")
        )
        stroke_path = stroke_candidates[0] if stroke_candidates else None
        stroke_flag = "Stroke" if stroke_path else "Sham"

        masks_dir = os.path.join(file_dir, "RegisteredTractMasks_adjusted")
        dwi_masks = glob.glob(os.path.join(masks_dir, "*dwi_flipped.nii.gz"))
        white_matter_index = next(
            (i for i, path in enumerate(dwi_masks) if "White_matter" in path),
            None,
        )
        if white_matter_index is None:
            tqdm.write(
                f"Skipping {file_path}: no White_matter mask in {masks_dir}"
            )
            continue
        white_matter_path = dwi_masks.pop(white_matter_index)

        white_matter_mask = nib.load(white_matter_path).get_fdata() > 0
        dwi_data = nib.load(file_path).get_fdata()

        # Load the stroke mask once per file, and only when it is needed,
        # instead of once per mask and dilation amount.
        outside_stroke = None
        if stroke_path and merged_time_point in STROKE_EXCLUSION_TIMEPOINTS:
            outside_stroke = nib.load(stroke_path).get_fdata() < 1

        for mask_path in tqdm(dwi_masks, desc="Processing masks", leave=False):
            mask_name = (
                os.path.basename(mask_path)
                .replace("registered", "")
                .replace("flipped", "")
                .replace(".nii.gz", "")
                .replace("__dwi_", "")
            )
            mask_img = nib.load(mask_path)
            mask_data_pre_dilation = mask_img.get_fdata()

            # Chosen per mask, so each figure shows the slice where this
            # particular mask is largest.
            selected_slice = _largest_slice_index(mask_data_pre_dilation)

            # Same threshold for every dilation amount.
            base_mask = mask_data_pre_dilation > 0

            for pp in DILATION_ITERATIONS:
                if pp:
                    dilated_mask = binary_dilation(base_mask, iterations=pp)
                else:
                    dilated_mask = base_mask

                mask_data = dilated_mask & white_matter_mask
                if outside_stroke is not None:
                    mask_data = mask_data & outside_stroke

                average_value = _mean_of_nonzero(dwi_data[mask_data])

                if create_figures and selected_slice is not None:
                    label = (
                        f"{subject_id}_tp_{merged_time_point}_{q_type}"
                        f"_{average_value:.2f}_{mask_name}_{pp}"
                    )
                    _save_overlay_figure(
                        dwi_data,
                        mask_data,
                        selected_slice,
                        label,
                        os.path.join(output_path, f"{label}.png"),
                    )

                if (save_nifti_mask and pp > 0
                        and merged_time_point in [3, 28] and q_type == "fa"):
                    output_nifti_file = os.path.join(
                        output_path,
                        f"{subject_id}_tp_{merged_time_point}_{q_type}"
                        f"_{mask_name}_{pp}.nii.gz",
                    )
                    # Keep the saved mask aligned with the mask it came from.
                    save_mask_as_nifti(
                        mask_data, output_nifti_file, affine=mask_img.affine
                    )

                csvdata.append([
                    file_path,
                    subject_id,
                    time_point,
                    int_time_point,
                    merged_time_point,
                    q_type,
                    mask_name,
                    pp,
                    average_value,
                    stroke_flag,
                ])

    return pd.DataFrame(csvdata, columns=CSV_COLUMNS)


def main():
    """Parse the command line, run the processing and write the CSV file."""
    parser = argparse.ArgumentParser(
        description="Process DWI files and generate figures."
    )
    parser.add_argument(
        "-i", "--input", type=str, required=True,
        help="Input directory containing DWI files. e.g: proc_data "
             "(flipped fa, md etc file must be available. "
             "registering of masks must happen before this step",
    )
    parser.add_argument(
        "-o", "--output", type=str, required=True,
        help="Output directory to save output results (csv and figures).",
    )
    parser.add_argument(
        "-f", "--figures", action="store_true", default=False,
        help="set if you want figures to be saved as pngs",
    )
    parser.add_argument(
        "-n", "--nifti-mask", action="store_true", default=False,
        help="set if you want masks to be saved as NIfTI files.",
    )
    args = parser.parse_args()

    create_output_dir(args.output)

    df = process_files(args.input, args.output, args.figures, args.nifti_mask)

    csv_file = os.path.join(
        args.output, "Quantitative_results_from_dwi_processing.csv"
    )
    df.to_csv(csv_file, index=False)
    print("CSV file created at:", csv_file)

    print(df)


if __name__ == "__main__":
    main()