123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- import os
- import glob
- import pandas as pd
- import nibabel as nib
- import numpy as np
- import matplotlib.pyplot as plt
- from tqdm import tqdm
- import argparse
- from scipy.ndimage import binary_dilation
def create_output_dir(output_dir):
    """Create the output directory (and any missing parents) if it does not exist.

    ``exist_ok=True`` replaces the separate existence check, so there is no
    window in which another process can create the directory between the
    check and the ``makedirs`` call.

    Args:
        output_dir: Path of the directory to create.

    Raises:
        FileExistsError: if ``output_dir`` exists but is not a directory
            (previously this was silently accepted and failed later on write).
    """
    os.makedirs(output_dir, exist_ok=True)
# Map the integer session number (0 for "ses-Baseline", otherwise the number
# after "-P" in the session folder name) onto a merged timepoint, e.g.
# sessions 1-5 are pooled as timepoint 3. Sessions not listed here are an error.
_SESSION_MAPPING = {
    0: 0,
    1: 3, 2: 3, 3: 3,
    4: 3, 5: 3, 6: 7, 7: 7,
    8: 7, 9: 7, 10: 7, 11: 14, 12: 14,
    13: 14, 14: 14, 15: 14, 16: 14, 17: 14, 18: 14, 19: 21,
    20: 21, 21: 21, 22: 21, 23: 21, 24: 21, 25: 21, 26: 28,
    27: 28, 28: 28, 29: 28, 30: 28, 42: 42, 43: 42, 56: 56, 57: 56,
}

_RESULT_COLUMNS = ["fullpath", "subjectID", "timePoint", "int_timepoint", "merged_timepoint", "Qtype",
                   "mask_name", "dialation_amount", "Value", "Group"]


def _parse_session(time_point):
    """Return the integer session number of a session folder name ("ses-Baseline" -> 0)."""
    if time_point == "ses-Baseline":
        return 0
    return int(time_point.split("-P")[1])


def _find_stroke_mask(dsi_dir):
    """Return the first (sorted) ``*StrokeMask_scaled.nii`` in ``dsi_dir``, or None if absent."""
    matches = sorted(glob.glob(os.path.join(dsi_dir, "*StrokeMask_scaled.nii")))
    return matches[0] if matches else None


def _clean_mask_name(mask_path):
    """Strip the registration/flip/extension decorations from a mask file name."""
    return (os.path.basename(mask_path)
            .replace("registered", "")
            .replace("flipped", "")
            .replace(".nii.gz", "")
            .replace("__dwi_", ""))


def _largest_slice(mask):
    """Return the index along the last axis whose slice has the most nonzero voxels.

    Ties resolve to the lowest index (``argmax`` returns the first maximum).
    """
    counts = np.count_nonzero(mask, axis=tuple(range(mask.ndim - 1)))
    return int(np.argmax(counts))


def _save_overlay(dwi_data, mask, slice_idx, title, output_file):
    """Save one slice of ``dwi_data`` with ``mask`` overlaid; the figure is always closed."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=50)
    try:
        ax.imshow(dwi_data[..., slice_idx], cmap='gray')
        ax.imshow(mask[..., slice_idx], cmap='jet', alpha=0.5)  # Overlay mask
        ax.set_title(title)
        ax.axis('off')
        fig.savefig(output_file, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if saving fails, so long runs do not leak memory.
        plt.close(fig)


def process_files(input_path, output_path, create_figures, dilation_steps=(0, 1, 2)):
    """Compute mean DWI-map values inside registered tract masks and optionally save figures.

    Files are discovered with the glob
    ``<input_path>/**/dwi/DSI_studio/*_flipped.nii.gz``. The 5th and 4th path
    components from the end are taken as the subject ID and session folder.
    For each such file, every ``RegisteredTractMasks_adjusted/*dwi_flipped.nii.gz``
    mask next to it is evaluated once per entry in ``dilation_steps``.
    Entry 0 means the undilated mask (voxels > 0); n > 0 means
    ``binary_dilation`` with ``iterations=n``. If a ``*StrokeMask_scaled.nii``
    exists beside the file, its voxels (value >= 1) are removed from every
    mask and the row is labelled "Stroke"; otherwise it is labelled "Sham".

    Args:
        input_path: Root directory searched recursively.
        output_path: Directory where PNG figures are written.
        create_figures: If true, save one overlay PNG per mask/dilation. Each
            figure shows the slice (last axis) where that mask's undilated
            version has the most voxels.
        dilation_steps: Dilation iteration counts to evaluate (default (0, 1, 2)).

    Returns:
        pandas.DataFrame with one row per file x mask x dilation and columns
        ``_RESULT_COLUMNS``. Value is NaN when the mask selects no voxels.

    Raises:
        KeyError: if a session number has no entry in ``_SESSION_MAPPING``.
    """
    csvdata = []
    pattern = os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz")
    # Sorted so the row order (and figure generation order) is reproducible.
    file_paths = sorted(glob.glob(pattern, recursive=True))
    for file_path in tqdm(file_paths, desc="Processing files"):
        parts = file_path.split(os.sep)
        subject_id = parts[-5]
        time_point = parts[-4]
        tqdm.write(file_path)  # print without breaking the progress bar
        int_time_point = _parse_session(time_point)
        try:
            merged_time_point = _SESSION_MAPPING[int_time_point]
        except KeyError:
            raise KeyError(f"No merged timepoint defined for session {int_time_point} "
                           f"({file_path})") from None
        q_type = os.path.basename(file_path).split("_flipped")[0]
        dsi_dir = os.path.dirname(file_path)

        # Load the stroke mask once per DWI file; it is reused for every tract
        # mask and dilation level below instead of being re-read each time.
        stroke_path = _find_stroke_mask(dsi_dir)
        if stroke_path is not None:
            tqdm.write(stroke_path)
            outside_stroke = nib.load(stroke_path).get_fdata() < 1
            stroke_flag = "Stroke"
        else:
            outside_stroke = None
            stroke_flag = "Sham"

        dwi_data = nib.load(file_path).get_fdata()
        mask_dir = os.path.join(dsi_dir, "RegisteredTractMasks_adjusted")
        dwi_masks = sorted(glob.glob(os.path.join(mask_dir, "*dwi_flipped.nii.gz")))

        for mask_path in tqdm(dwi_masks, desc="Processing masks", leave=False):
            mask_name = _clean_mask_name(mask_path)
            base_mask = nib.load(mask_path).get_fdata()
            # Chosen per mask (previously the running maximum leaked across masks).
            selected_slice = _largest_slice(base_mask)
            for pp in dilation_steps:
                if pp != 0:
                    mask_data = binary_dilation(base_mask, iterations=pp)
                else:
                    mask_data = base_mask > 0
                if outside_stroke is not None:
                    # Subtract the stroke region from the (dilated) tract mask.
                    mask_data = mask_data & outside_stroke

                masked_data = dwi_data[mask_data]
                # An empty selection yields NaN explicitly instead of a
                # "Mean of empty slice" RuntimeWarning from nanmean.
                average_value = np.nanmean(masked_data) if masked_data.size else np.nan

                if create_figures:
                    title = f"{subject_id}_tp_{merged_time_point}_{q_type}_{average_value:.2f}_{mask_name}_{pp}"
                    _save_overlay(dwi_data, mask_data, selected_slice, title,
                                  os.path.join(output_path, f"{title}.png"))

                csvdata.append([file_path, subject_id, time_point, int_time_point, merged_time_point,
                                q_type, mask_name, pp, average_value, stroke_flag])

    return pd.DataFrame(csvdata, columns=_RESULT_COLUMNS)
def main(argv=None):
    """Command-line entry point: parse arguments, process files, and write the results CSV.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, so the script's CLI is unchanged.
            Passing a list allows the tool to be driven from Python or tests.
    """
    parser = argparse.ArgumentParser(description="Process DWI files and generate figures.")
    parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files. e.g: proc_data (flipped fa, md etc file must be available. registering of masks must happen before this step")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save output results (csv and figures).")
    parser.add_argument("-f", "--figures", action="store_true", default=False, help="set if you want figures to be saved as pngs")
    args = parser.parse_args(argv)

    # Figures and the CSV are both written into this directory.
    create_output_dir(args.output)
    df = process_files(args.input, args.output, args.figures)

    csv_file = os.path.join(args.output, "Quantitative_results_from_dwi_processing.csv")
    df.to_csv(csv_file, index=False)
    print("CSV file created at:", csv_file)
    print(df)
if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0
    # on success; argparse errors still surface as their own SystemExit.
    raise SystemExit(main())
|