overlayPlot.py 3.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import nibabel as nib
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from scipy.ndimage import rotate
  5. # Define the file paths
  6. templateFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\average_template_50_from_website.nii.gz"
  7. densityFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\12_wks_coronal_100141780_50um_projection_density.nii.gz"
  8. maskFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\output\Nifti_Trakte_registered\CC_Mop_desnitymask_50um.nii.gz"
  9. # Load the NIfTI files
  10. templateNifti = nib.load(templateFilePath)
  11. densityNifti = nib.load(densityFilePath)
  12. maskNifti = nib.load(maskFilePath)
  13. # Ensure that both images have the same data type (e.g., cast density to the data type of template)
  14. densityNifti = nib.Nifti1Image(densityNifti.get_fdata().astype(templateNifti.get_fdata().dtype), densityNifti.affine)
  15. # Customize the figure DPI
  16. dpi = 400
  17. # Get the number of slices along each dimension
  18. num_slices_dim0, num_slices_dim1, num_slices_dim2 = templateNifti.shape
  19. # Set the slice step for display (every tenth slice)
  20. slice_step = 10
  21. # Choose the dimension for iteration (0, 1, or 2)
  22. iteration_dimension = 0 # Change this value to select the dimension
  23. # Calculate the number of rows and columns for the grid based on the selected dimension
  24. if iteration_dimension == 0:
  25. num_slices = num_slices_dim0
  26. elif iteration_dimension == 1:
  27. num_slices = num_slices_dim1
  28. else:
  29. num_slices = num_slices_dim2
  30. grid_size = int(np.sqrt(num_slices // slice_step))
  31. if grid_size * grid_size < num_slices // slice_step:
  32. grid_size += 1
  33. cm = 1/2.54
  34. # Adjust the figure size to be larger
  35. fig, axs = plt.subplots(grid_size, grid_size, figsize=(20, 20), dpi=dpi)
  36. # Hide axis numbers
  37. for ax_row in axs:
  38. for ax in ax_row:
  39. ax.axis('off')
  40. # Load the mask and perform necessary flips
  41. mask_data = np.squeeze(maskNifti.get_fdata())
  42. # Swap the mask over the x and y axes
  43. mask_data = mask_data.swapaxes(0, 2)
  44. mask_data = np.flip(mask_data, axis=0)
  45. # Loop through and display every nth slice as subplots along the selected dimension
  46. for i, sliceIdx in enumerate(range(0, num_slices, slice_step)):
  47. row = i // grid_size
  48. col = i % grid_size
  49. # Extract the desired slice from the template and density based on the selected dimension
  50. if iteration_dimension == 0:
  51. templateSlice = np.squeeze(templateNifti.get_fdata()[sliceIdx, :, :]).astype(np.uint16)
  52. densitySlice = np.squeeze(densityNifti.get_fdata()[sliceIdx, :, :]).astype(np.double)
  53. maskSlice = np.squeeze(mask_data[sliceIdx, :, :])
  54. elif iteration_dimension == 1:
  55. templateSlice = np.squeeze(templateNifti.get_fdata()[:, sliceIdx, :]).astype(np.uint16)
  56. densitySlice = np.squeeze(densityNifti.get_fdata()[:, sliceIdx, :]).astype(np.double)
  57. maskSlice = np.squeeze(mask_data[:, sliceIdx, :])
  58. else:
  59. templateSlice = np.squeeze(templateNifti.get_fdata()[:, :, sliceIdx]).astype(np.uint16)
  60. densitySlice = np.squeeze(densityNifti.get_fdata()[:, :, sliceIdx]).astype(np.double)
  61. maskSlice = np.squeeze(mask_data[:, :, sliceIdx])
  62. # Normalize the templateSlice to [0, 1] for visualization
  63. templateSlice = (templateSlice - np.min(templateSlice)) / (np.max(templateSlice) - np.min(templateSlice))
  64. # Overlay the density and mask onto the template slice
  65. overlayedImage = templateSlice + densitySlice + maskSlice
  66. # Display the overlaid image in the corresponding subplot
  67. axs[row, col].imshow(templateSlice, cmap='gray')
  68. axs[row, col].imshow(densitySlice, cmap='hot', alpha=0.3)
  69. axs[row, col].imshow(maskSlice, cmap='cool', alpha=0.3) # Apply "hot" colormap to the mask
  70. # Apply "cool" colormap to the mask
  71. # Remove empty subplots if necessary
  72. for i in range(num_slices // slice_step, grid_size * grid_size):
  73. fig.delaxes(axs.flatten()[i])
  74. # Adjust spacing between subplots
  75. plt.tight_layout()
  76. # Show the plot
  77. plt.show()