123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
import glob
import math
import os
import warnings

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from scipy.ndimage import rotate
# Quantify how much of each registered tract mask is covered by each
# thresholded viral-tracing volume, optionally saving overlay figures,
# and export the coverage table as a CSV.
#
# NOTE: __file__ is not defined in interactive sessions; run this as a script
# (or replace __file__ with the script's path).
code_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(code_dir)

# Glob patterns for the viral tracing volumes and the masks
viral_tracing_SearchPath = os.path.join(parent_dir, "output", "Viral_tracing_flipped", "*fli*.nii*")
masks_SearchPath = os.path.join(parent_dir, "output", "Tract_Mask_registered", "*.nii*")

# Output locations: one folder for the overlay figures, one CSV for the metrics
save_folder = os.path.join(parent_dir, "output", "MaskComparisons")
SavePath = os.path.join(parent_dir, "output", "Overlap_metrics_of_VT_and_dwi_mass.csv")

# Set to False to skip figure generation and only compute the metrics
FigureFlag = True

# Column order of the exported CSV (also used when there are no results)
RESULT_COLUMNS = ["Viral Tracing File", "Mask File", "Threshold", "Percentage of Coverage"]


def strip_nifti_extension(filename):
    """Return *filename* without a trailing ``.nii.gz`` or ``.nii`` suffix.

    Names that carry neither suffix are returned unchanged.
    """
    for suffix in (".nii.gz", ".nii"):
        if filename.endswith(suffix):
            return filename[: -len(suffix)]
    return filename


def compute_thresholds(volume, num=5):
    """Return evenly spaced thresholds spanning the value range of *volume*.

    ``num`` points are placed from the minimum to the maximum of *volume* and
    the last one (the maximum) is dropped, because ``volume > max`` would
    select nothing. The minimum itself is kept as the first threshold, so the
    first threshold selects every voxel above the global minimum.
    """
    return np.linspace(np.min(volume), np.max(volume), num=num)[:-1]


def coverage_percentage(tracing_mask, mask_data):
    """Compute the overlap of a boolean tracing volume with a mask.

    Every non-zero mask voxel counts as "inside the mask", which is the same
    rule ``np.logical_and`` applies. Counting voxels on both sides (instead of
    summing mask intensities) keeps the percentage correct for masks that are
    not strictly 0/1.

    Returns:
        ``(overlap, percentage)`` where ``overlap`` is the boolean overlap
        volume and ``percentage`` is the share of mask voxels covered, in
        percent. ``percentage`` is NaN when the mask has no non-zero voxel.
    """
    mask_bool = np.asarray(mask_data) != 0
    overlap = np.logical_and(tracing_mask, mask_bool)
    total_voxels = np.count_nonzero(mask_bool)
    if total_voxels == 0:
        return overlap, float("nan")
    return overlap, 100.0 * np.count_nonzero(overlap) / total_voxels


def save_overlay_figure(masks_data, viral_tracing_thresholded, overlap, title, save_path):
    """Save a grid of slices showing mask, thresholded tracing and overlap.

    Slices from 40% to 90% of the third axis are drawn, one panel per slice,
    on a near-square grid. Returns True when a figure was written and False
    when the volume has too few slices to produce any panel.
    """
    # Rotate images by 90 degrees clockwise (np.rot90 returns a view)
    masks_data_rotated = np.rot90(masks_data, k=-1)
    viral_tracing_thresholded_rotated = np.rot90(viral_tracing_thresholded, k=-1)
    overlap_rotated = np.rot90(overlap, k=-1)

    num_slices = masks_data_rotated.shape[2]
    start_index = int(num_slices * 0.4)
    end_index = int(num_slices * 0.9)
    num_panels = end_index - start_index
    if num_panels <= 0:
        return False

    num_cols = int(math.ceil(math.sqrt(num_panels)))
    num_rows = int(math.ceil(num_panels / num_cols))
    # squeeze=False keeps `axes` two-dimensional even for a single row/column
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 20), squeeze=False)
    for ax in axes.ravel():
        ax.axis('off')  # Also hides the unused panels at the end of the grid

    for panel, i in enumerate(range(start_index, end_index)):
        ax = axes[panel // num_cols, panel % num_cols]
        ax.imshow(masks_data_rotated[:, :, i], cmap='Reds')  # Drawn mask in red
        ax.imshow(viral_tracing_thresholded_rotated[:, :, i], alpha=0.5, cmap='Greens')  # Viral tracing in green
        ax.imshow(overlap_rotated[:, :, i], alpha=0.5, cmap='Blues')  # Overlap in blue

    fig.suptitle(title)
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=80)  # Low dpi for speed
    plt.close(fig)
    return True


def main():
    """Run the coverage analysis for every tracing/mask pair and write the CSV.

    Raises:
        ValueError: if a mask and a viral tracing volume differ in shape.
    """
    # Sorted so that the row order of the CSV is reproducible across runs
    viral_tracing_files = sorted(glob.glob(viral_tracing_SearchPath))
    masks_files = sorted(glob.glob(masks_SearchPath))
    if not viral_tracing_files or not masks_files:
        warnings.warn(
            f"No input pairs to process: {len(viral_tracing_files)} viral tracing "
            f"file(s), {len(masks_files)} mask file(s) found."
        )

    # Create the folder for saving plots if it doesn't exist
    os.makedirs(save_folder, exist_ok=True)

    results_list = []
    for vv in viral_tracing_files:
        # Exclude the file "average_template_50_from_website_flipped.nii.gz"
        if "average_template_50_from_website_flipped.nii.gz" in vv:
            continue

        vt_name = os.path.basename(vv)
        vt_label = strip_nifti_extension(vt_name)
        viral_tracing_data = nib.load(vv).get_fdata()
        # Thresholds depend only on the tracing volume, so compute them once per file
        thresholds = compute_thresholds(viral_tracing_data)

        for ff in masks_files:
            mask_name = os.path.basename(ff)
            mask_label = strip_nifti_extension(mask_name)
            masks_data = nib.load(ff).get_fdata()

            if masks_data.shape != viral_tracing_data.shape:
                raise ValueError(
                    f"Shape mismatch: {vt_name} has shape {viral_tracing_data.shape} "
                    f"but {mask_name} has shape {masks_data.shape}"
                )
            if not np.any(masks_data):
                warnings.warn(f"Mask {mask_name} is empty; coverage is reported as NaN.")

            for threshold in thresholds:
                viral_tracing_thresholded = viral_tracing_data > threshold
                overlap, percentage_covered = coverage_percentage(viral_tracing_thresholded, masks_data)

                if FigureFlag:
                    title = f"Overlap of Masks for {vt_name} and {mask_name} (Threshold {threshold})"
                    save_path = os.path.join(
                        save_folder,
                        f"{vt_label}_{mask_label}_Threshold_{threshold}_middle_slices.png",
                    )
                    save_overlay_figure(masks_data, viral_tracing_thresholded, overlap, title, save_path)

                results_list.append({
                    "Viral Tracing File": vt_label,
                    "Mask File": mask_label,
                    "Threshold": threshold,
                    "Percentage of Coverage": percentage_covered,
                })

    # Explicit columns keep the CSV header present even when there are no rows
    results_df = pd.DataFrame(results_list, columns=RESULT_COLUMNS)
    results_df.to_csv(SavePath, index=False)


if __name__ == "__main__":
    main()
|