# overlayPlot3D.py

  1. import nibabel as nib
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from scipy.ndimage import rotate
  5. # Define the file paths
  6. templateFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\average_template_50_from_website.nii.gz"
  7. densityFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\12_wks_coronal_100141780_50um_projection_density.nii.gz"
  8. maskFilePath = r"\\10.209.5.114\Projects\Student_projects\14_Aref_Kalantari_2021\Luca\CC_Mop_desnitymap_50um.nii.gz"
  9. # Load the NIfTI files
  10. templateNifti = nib.load(templateFilePath)
  11. densityNifti = nib.load(densityFilePath)
  12. maskNifti = nib.load(maskFilePath)
  13. # Ensure that both images have the same data type (e.g., cast density to the data type of template)
  14. densityNifti = nib.Nifti1Image(densityNifti.get_fdata().astype(templateNifti.get_fdata().dtype), densityNifti.affine)
  15. # Customize the figure DPI
  16. dpi = 400
  17. # Get the number of slices along each dimension
  18. num_slices_dim0, num_slices_dim1, num_slices_dim2 = templateNifti.shape
  19. # Set the slice step for display (every tenth slice)
  20. slice_step = 10
  21. # Choose the dimension for iteration (0, 1, or 2)
  22. iteration_dimension = 0 # Change this value to select the dimension
  23. # Calculate the number of rows and columns for the grid based on the selected dimension
  24. if iteration_dimension == 0:
  25. num_slices = num_slices_dim0
  26. elif iteration_dimension == 1:
  27. num_slices = num_slices_dim1
  28. else:
  29. num_slices = num_slices_dim2
  30. grid_size = int(np.sqrt(num_slices // slice_step))
  31. if grid_size * grid_size < num_slices // slice_step:
  32. grid_size += 1
  33. cm = 1/2.54
  34. # Adjust the figure size to be larger
  35. fig, axs = plt.subplots(grid_size, grid_size, figsize=(20, 20), dpi=dpi)
  36. # Hide axis numbers
  37. for ax_row in axs:
  38. for ax in ax_row:
  39. ax.axis('off')
  40. # Load the mask and perform necessary flips
  41. mask_data = np.squeeze(maskNifti.get_fdata())
  42. # Swap the mask over the x and y axes
  43. mask_data = mask_data.swapaxes(0, 2)
  44. mask_data = np.flip(mask_data, axis=0)
  45. # Create a new NIfTI image with the adjusted mask data
  46. adjusted_mask_nifti = nib.Nifti1Image(mask_data, maskNifti.affine)
  47. # Save the new NIfTI image with adjusted header
  48. output_mask_filepath = r"\\10.209.5.114\Projects\Student_projects\14_Aref_Kalantari_2021\Luca\CC_Mop_desnitymap_50um_reoriented.nii.gz" # Specify the desired output file path
  49. nib.save(adjusted_mask_nifti, output_mask_filepath)
  50. # Loop through and display every nth slice as subplots along the selected dimension
  51. for i, sliceIdx in enumerate(range(0, num_slices, slice_step)):
  52. row = i // grid_size
  53. col = i % grid_size
  54. # Extract the desired slice from the template and density based on the selected dimension
  55. if iteration_dimension == 0:
  56. templateSlice = np.squeeze(templateNifti.get_fdata()[sliceIdx, :, :]).astype(np.uint16)
  57. densitySlice = np.squeeze(densityNifti.get_fdata()[sliceIdx, :, :]).astype(np.double)
  58. maskSlice = np.squeeze(mask_data[sliceIdx, :, :])
  59. elif iteration_dimension == 1:
  60. templateSlice = np.squeeze(templateNifti.get_fdata()[:, sliceIdx, :]).astype(np.uint16)
  61. densitySlice = np.squeeze(densityNifti.get_fdata()[:, sliceIdx, :]).astype(np.double)
  62. maskSlice = np.squeeze(mask_data[:, sliceIdx, :])
  63. else:
  64. templateSlice = np.squeeze(templateNifti.get_fdata()[:, :, sliceIdx]).astype(np.uint16)
  65. densitySlice = np.squeeze(densityNifti.get_fdata()[:, :, sliceIdx]).astype(np.double)
  66. maskSlice = np.squeeze(mask_data[:, :, sliceIdx])
  67. # Normalize the templateSlice to [0, 1] for visualization
  68. templateSlice = (templateSlice - np.min(templateSlice)) / (np.max(templateSlice) - np.min(templateSlice))
  69. # Overlay the density and mask onto the template slice
  70. overlayedImage = templateSlice + densitySlice + maskSlice
  71. # Display the overlaid image in the corresponding subplot
  72. axs[row, col].imshow(templateSlice, cmap='gray')
  73. axs[row, col].imshow(densitySlice, cmap='hot', alpha=0.4)
  74. axs[row, col].imshow(maskSlice, cmap='hot', alpha=0.2) # Apply "hot" colormap to the mask
  75. # Apply "cool" colormap to the mask
  76. # Remove empty subplots if necessary
  77. for i in range(num_slices // slice_step, grid_size * grid_size):
  78. fig.delaxes(axs.flatten()[i])
  79. # Adjust spacing between subplots
  80. plt.tight_layout()
  81. # Show the plot
  82. plt.show()