# overlayPlot3D.py

from pathlib import Path

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from scipy.ndimage import rotate
  5. # Define the file paths
  6. templateFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\average_template_50_from_website.nii.gz"
  7. densityFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\12_wks_coronal_100141780_50um_projection_density.nii.gz"
  8. maskFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\output\Nifti_Trakte_registered\CC_Mop_dilated_registered.nii.gz"
  9. # Load the NIfTI files
  10. templateNifti = nib.load(templateFilePath)
  11. densityNifti = nib.load(densityFilePath)
  12. maskNifti = nib.load(maskFilePath)
  13. # Ensure that both images have the same data type (e.g., cast density to the data type of template)
  14. densityNifti = nib.Nifti1Image(densityNifti.get_fdata().astype(templateNifti.get_fdata().dtype), densityNifti.affine)
  15. # Customize the figure DPI
  16. dpi = 400
  17. # Get the number of slices along each dimension
  18. num_slices_dim0, num_slices_dim1, num_slices_dim2 = templateNifti.shape
  19. # Create a 3D array to hold the overlay
  20. overlay_volume = np.zeros((num_slices_dim0, num_slices_dim1, num_slices_dim2, 4), dtype=np.uint8)
  21. # Loop through and overlay all slices
  22. for sliceIdx in range(num_slices_dim2):
  23. # Extract the desired slice from the template, density, and mask
  24. templateSlice = np.squeeze(templateNifti.get_fdata()[:, :, sliceIdx]).astype(np.uint16)
  25. densitySlice = np.squeeze(densityNifti.get_fdata()[:, :, sliceIdx]).astype(np.double)
  26. mask_data = maskNifti.get_fdata()
  27. mask_data = mask_data.swapaxes(0, 2)
  28. mask_data = np.flip(mask_data, axis=0)
  29. maskSlice = np.squeeze(mask_data[:, :, sliceIdx])
  30. # Normalize the templateSlice to [0, 1] for visualization
  31. templateSlice = (templateSlice - np.min(templateSlice)) / (np.max(templateSlice) - np.min(templateSlice))
  32. # Overlay the density and mask onto the template slice
  33. overlayedImage = templateSlice + densitySlice + maskSlice
  34. # Normalize the overlay to [0, 1]
  35. overlayedImage = (overlayedImage - np.min(overlayedImage)) / (np.max(overlayedImage) - np.min(overlayedImage))
  36. # Convert the overlay image to RGBA format (4 channels for transparency)
  37. overlay_image_rgba = plt.cm.gray(overlayedImage)
  38. overlay_image_rgba[:, :, 3] = (overlayedImage * 255).astype(np.uint8) # Set alpha channel based on intensity
  39. # Update the 3D overlay volume
  40. overlay_volume[:, :, sliceIdx, :] = overlay_image_rgba
  41. # Create a figure for the 3D visualization
  42. fig = plt.figure(figsize=(10, 10), dpi=dpi)
  43. ax = fig.add_subplot(111, projection='3d')
  44. # Plot the 3D overlay volume
  45. ax.voxels(overlay_volume[:, :, :, 0] > 0, facecolors=overlay_volume[:, :, :, :3] / 255., edgecolor='k')
  46. # Set the aspect ratio to be equal for all dimensions
  47. ax.set_box_aspect([1, 1, 1])
  48. # Show the 3D plot
  49. plt.show()