"""Compute how much of each tract mask is covered by thresholded viral tracing.

For every viral-tracing NIfTI in ``output/Viral_tracing_flipped`` and every
mask in ``output/Tract_Mask_registered`` (both relative to the parent of this
script's directory), the viral volume is binarised at four thresholds. For
each threshold, the percentage of mask voxels that are also above threshold is
recorded. Results are written to
``output/Overlap_metrics_of_VT_and_dwi_mass.csv``. When ``FigureFlag`` is
True, an overlay figure per (viral file, mask, threshold) is also saved to
``output/MaskComparisons``.
"""
import glob
import math
import os
import warnings

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from scipy.ndimage import rotate  # noqa: F401  (unused; kept from the original script)


def _strip_nifti_ext(path):
    """Return the basename of *path* with a trailing ``.nii.gz`` or ``.nii`` removed."""
    name = os.path.basename(path)
    for ext in (".nii.gz", ".nii"):
        if name.endswith(ext):
            return name[: -len(ext)]
    return name


def _thresholds_for(data, num=5):
    """Return ``num - 1`` evenly spaced thresholds for *data*.

    The values run from ``min(data)`` (inclusive) up to ``max(data)``
    (exclusive): ``linspace(...)[:-1]`` drops only the final point. The first
    threshold therefore keeps every voxel strictly above the volume minimum.
    """
    return np.linspace(np.min(data), np.max(data), num=num)[:-1]


def _coverage_percentage(tracing_mask, roi_mask):
    """Return ``(overlap, percentage)`` for two boolean volumes of equal shape.

    ``overlap`` is the voxel-wise AND of the two masks. ``percentage`` is the
    share of ``roi_mask`` voxels that are also set in ``tracing_mask``, in
    percent. It is NaN when ``roi_mask`` is empty, where coverage is undefined.
    """
    overlap = np.logical_and(tracing_mask, roi_mask)
    total_voxels = np.count_nonzero(roi_mask)
    if total_voxels == 0:
        return overlap, float("nan")
    return overlap, (np.count_nonzero(overlap) / total_voxels) * 100


def _save_overlap_figure(mask_rot, tracing_rot, overlap_rot, title, save_path,
                         start_frac=0.4, end_frac=0.9):
    """Save a grid of axial overlays for slices ``[start_frac, end_frac)`` of axis 2.

    Each panel draws the mask (Reds), the thresholded tracing (Greens,
    alpha 0.5) and their overlap (Blues, alpha 0.5). When the slice range is
    empty, the figure is skipped with a warning.
    """
    num_slices = mask_rot.shape[2]
    start_index = int(num_slices * start_frac)
    end_index = int(num_slices * end_frac)
    n_panels = end_index - start_index
    if n_panels <= 0:
        warnings.warn(f"No slices in plotting range for {save_path}; figure skipped.")
        return

    # Near-square grid: columns = ceil(sqrt(n)), rows = ceil(n / columns).
    num_cols = int(math.ceil(math.sqrt(n_panels)))
    num_rows = int(math.ceil(n_panels / num_cols))
    # squeeze=False keeps `axes` 2-D even for 1xN / Nx1 grids, so [row, col] always works.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 20), squeeze=False)
    try:
        for ax in axes.flat:  # also hides any unused trailing cells
            ax.axis("off")
        for panel, i in enumerate(range(start_index, end_index)):
            ax = axes[panel // num_cols, panel % num_cols]
            ax.imshow(mask_rot[:, :, i], cmap="Reds")
            ax.imshow(tracing_rot[:, :, i], alpha=0.5, cmap="Greens")
            ax.imshow(overlap_rot[:, :, i], alpha=0.5, cmap="Blues")
        fig.suptitle(title)
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0, dpi=80)
    finally:
        plt.close(fig)  # release the figure even if saving fails


# Assumes __file__ is defined (run as a script); otherwise replace it with your script's filename.
code_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(code_dir)

# Search patterns for viral tracing volumes and tract masks.
viral_tracing_SearchPath = os.path.join(parent_dir, "output", "Viral_tracing_flipped", "*fli*.nii*")
masks_SearchPath = os.path.join(parent_dir, "output", "Tract_Mask_registered", "*.nii*")

# Sorted so output order (CSV rows, figure creation) is deterministic.
viral_tracing_files = sorted(glob.glob(viral_tracing_SearchPath))
masks_files = sorted(glob.glob(masks_SearchPath))

# Folder for overlay figures (also creates parent_dir/output if missing).
save_folder = os.path.join(parent_dir, "output", "MaskComparisons")
os.makedirs(save_folder, exist_ok=True)

# Set to False to skip figure generation (much faster).
FigureFlag = True

# One dict per (viral file, mask, threshold).
results_list = []

for vv in viral_tracing_files:
    # The template volume matches the search pattern but is not a tracing.
    if "average_template_50_from_website_flipped.nii.gz" in vv:
        continue

    viral_tracing_img = nib.load(vv)
    viral_tracing_data = viral_tracing_img.get_fdata()
    vt_name = _strip_nifti_ext(vv)
    # Thresholds depend only on the viral data, so compute them once per file.
    thresholds = _thresholds_for(viral_tracing_data)

    for ff in masks_files:
        # NOTE: masks are re-read for every viral file to bound peak memory.
        masks_img = nib.load(ff)
        masks_data = masks_img.get_fdata()
        mask_name = _strip_nifti_ext(ff)

        if masks_data.shape != viral_tracing_data.shape:
            # Mismatched volumes would crash or, worse, silently broadcast.
            warnings.warn(
                f"Shape mismatch {viral_tracing_data.shape} vs {masks_data.shape} "
                f"for {os.path.basename(vv)} / {os.path.basename(ff)}; pair skipped."
            )
            continue

        # Binarise once so numerator and denominator both count voxels,
        # even for interpolated (non-binary) registered masks.
        mask_bool = masks_data != 0
        if FigureFlag:
            # Rotate 90 degrees clockwise in the first two axes for display.
            masks_data_rotated = np.rot90(mask_bool, k=-1)

        for threshold in thresholds:
            viral_tracing_thresholded = viral_tracing_data > threshold
            overlap, percentage_covered = _coverage_percentage(viral_tracing_thresholded, mask_bool)

            if FigureFlag:
                save_path = os.path.join(
                    save_folder,
                    f"{vt_name}_{mask_name}_Threshold_{threshold}_middle_slices.png",
                )
                _save_overlap_figure(
                    masks_data_rotated,
                    np.rot90(viral_tracing_thresholded, k=-1),
                    np.rot90(overlap, k=-1),
                    f"Overlap of Masks for {os.path.basename(vv)} and {os.path.basename(ff)} (Threshold {threshold})",
                    save_path,
                )

            results_list.append({"Viral Tracing File": vt_name,
                                 "Mask File": mask_name,
                                 "Threshold": threshold,
                                 "Percentage of Coverage": percentage_covered})

# Explicit columns so an empty run still produces a CSV with a header.
results_df = pd.DataFrame(
    results_list,
    columns=["Viral Tracing File", "Mask File", "Threshold", "Percentage of Coverage"],
)

SavePath = os.path.join(parent_dir, "output", "Overlap_metrics_of_VT_and_dwi_mass.csv")
results_df.to_csv(SavePath, index=False)