proof_of_modalities.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import nibabel as nib
  2. import numpy as np
  3. import os
  4. import glob
  5. import pandas as pd
  6. import matplotlib.pyplot as plt
  7. from scipy.ndimage import rotate
  8. import math
  9. # Assuming __file__ is defined in your environment; otherwise, replace it with your script's filename
  10. code_dir = os.path.dirname(os.path.abspath(__file__))
  11. parent_dir = os.path.dirname(code_dir)
  12. # Define the search paths for viral tracing and masks
  13. viral_tracing_SearchPath = os.path.join(parent_dir, "output", "Viral_tracing_flipped", "*fli*.nii*")
  14. masks_SearchPath = os.path.join(parent_dir, "output", "Tract_Mask_registered", "*.nii*")
  15. # Get the list of viral tracing files and masks files
  16. viral_tracing_files = glob.glob(viral_tracing_SearchPath)
  17. masks_files = glob.glob(masks_SearchPath)
  18. # Create a folder for saving plots if it doesn't exist
  19. save_folder = os.path.join(parent_dir, "output", "MaskComparisons")
  20. if not os.path.exists(save_folder):
  21. os.makedirs(save_folder)
  22. # Set the FigureFlag
  23. FigureFlag = True
  24. # Create an empty list to store the results as dictionaries
  25. results_list = []
  26. # Loop through each viral tracing file
  27. for vv in viral_tracing_files:
  28. # Exclude the file "average_template_50_from_website_flipped.nii.gz"
  29. if "average_template_50_from_website_flipped.nii.gz" in vv:
  30. continue
  31. # Load the viral tracing NIfTI file
  32. viral_tracing_img = nib.load(vv)
  33. viral_tracing_data = viral_tracing_img.get_fdata()
  34. # Loop through each masks file
  35. for ff in masks_files:
  36. # Load the masks NIfTI file
  37. masks_img = nib.load(ff)
  38. masks_data = masks_img.get_fdata()
  39. # Calculate dynamic thresholds based on maximum and minimum values in the mask, excluding the actual min and max
  40. min_threshold = np.min(viral_tracing_data)
  41. max_threshold = np.max(viral_tracing_data)
  42. thresholds = np.linspace(min_threshold, max_threshold, num=5)[0:-1] # Excludes actual min and max
  43. # Loop through each threshold
  44. for threshold in thresholds:
  45. # Apply threshold to viral tracing data
  46. viral_tracing_thresholded = (viral_tracing_data > threshold)
  47. # Calculate percentage of coverage
  48. overlap = np.logical_and(viral_tracing_thresholded, masks_data)
  49. num_covered_voxels = np.sum(overlap)
  50. total_voxels = np.sum(masks_data)
  51. percentage_covered = (num_covered_voxels / total_voxels) * 100
  52. if FigureFlag:
  53. # Rotate images by 90 degrees clockwise
  54. masks_data_rotated = np.rot90(masks_data, k=-1)
  55. viral_tracing_thresholded_rotated = np.rot90(viral_tracing_thresholded, k=-1)
  56. overlap_rotated = np.rot90(overlap, k=-1)
  57. # Plot middle 50% slices
  58. num_slices = masks_data_rotated.shape[2]
  59. start_index = int(num_slices * 0.4)
  60. end_index = int(num_slices * 0.9)
  61. num_cols = int(math.ceil(math.sqrt(end_index - start_index))) # Calculate number of columns as ceiling of square root
  62. num_rows = int(math.ceil((end_index - start_index) / num_cols)) # Calculate number of rows
  63. fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 20))
  64. for i in range(start_index, end_index):
  65. row = (i - start_index) // num_cols
  66. col = (i - start_index) % num_cols
  67. axes[row, col].imshow(masks_data_rotated[:, :, i], cmap='Reds') # Drawn mask in bright red
  68. axes[row, col].imshow(viral_tracing_thresholded_rotated[:, :, i], alpha=0.5, cmap='Greens') # Viral tracing in grass green
  69. axes[row, col].imshow(overlap_rotated[:, :, i], alpha=0.5, cmap='Blues') # Overlap in light blue
  70. #axes[row, col].set_title(f"Slice {i}")
  71. axes[row, col].axis('off') # Turn off axis for speed
  72. # Save the figure if FigureFlag is True
  73. plt.suptitle(f"Overlap of Masks for {os.path.basename(vv)} and {os.path.basename(ff)} (Threshold {threshold})")
  74. save_path = os.path.join(save_folder, f"{os.path.basename(vv).replace('.nii.gz', '')}_{os.path.basename(ff).replace('.nii', '')}_Threshold_{threshold}_middle_slices.png")
  75. plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=80) # Optimize for speed
  76. plt.close()
  77. # Add results to the list as a dictionary
  78. results_list.append({"Viral Tracing File": os.path.basename(vv).replace(".nii.gz", ""), "Mask File": os.path.basename(ff).replace(".nii", ""), "Threshold": threshold, "Percentage of Coverage": percentage_covered})
  79. # Convert the list of dictionaries into a DataFrame
  80. results_df = pd.DataFrame(results_list)
  81. # Save results to a CSV file
  82. SavePath = os.path.join(parent_dir, "output", "Overlap_metrics_of_VT_and_dwi_mass.csv")
  83. results_df.to_csv(SavePath, index=False)