"""Render a 3D voxel view of a template volume with a density map and a mask overlaid.

The template, projection-density and mask NIfTI volumes are summed slice by
slice (along the third array axis), min-max normalized, mapped to grayscale
RGBA (alpha = intensity) and drawn with matplotlib's ``Axes3D.voxels``.
"""
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import rotate  # noqa: F401  (currently unused)

# Define the file paths
templateFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\average_template_50_from_website.nii.gz"
densityFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\12_wks_coronal_100141780_50um_projection_density.nii.gz"
maskFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\output\Nifti_Trakte_registered\CC_Mop_dilated_registered.nii.gz"

# Customize the figure DPI
dpi = 400

# Plot only every n-th voxel along each axis. ax.voxels builds polygons for
# every filled voxel, so full-resolution volumes can be extremely slow to
# render; raise this (e.g. to 2-4) to make the 3D plot tractable.
voxel_step = 1


def _normalize(arr):
    """Min-max scale *arr* to [0, 1] as float64.

    A constant (or empty) array maps to all zeros instead of the NaNs that a
    plain ``(a - min) / (max - min)`` would produce (0 / 0).
    """
    arr = np.asarray(arr, dtype=np.float64)
    if arr.size == 0:
        return np.zeros_like(arr)
    lo = np.min(arr)
    span = np.max(arr) - lo
    if span == 0:
        return np.zeros_like(arr)
    return (arr - lo) / span


def _reorient_mask(mask_data):
    """Apply the fixed reorientation used to align the mask with the template.

    Swaps array axes 0 and 2, then flips along the (new) axis 0.
    NOTE(review): this orientation fix is specific to how the mask file was
    written — confirm it still matches if the registration output changes.
    """
    return np.flip(np.swapaxes(mask_data, 0, 2), axis=0)


def build_overlay_volume(template_data, density_data, mask_data):
    """Combine three same-shaped 3D arrays into a uint8 grayscale RGBA volume.

    For each slice along axis 2: the template slice is normalized to [0, 1],
    the raw density and mask slices are added, and the sum is normalized to
    [0, 1] again. That intensity is mapped through the gray colormap, used as
    the alpha channel as well, and stored scaled to 0..255.

    Parameters
    ----------
    template_data, density_data, mask_data : numpy.ndarray
        3D arrays of identical shape (mask already reoriented).

    Returns
    -------
    numpy.ndarray
        Array of shape ``template_data.shape + (4,)`` with dtype uint8.

    Raises
    ------
    ValueError
        If the three inputs do not share the same 3D shape.
    """
    if not (template_data.ndim == 3
            and template_data.shape == density_data.shape == mask_data.shape):
        raise ValueError(
            "template, density and mask must be 3D arrays of identical shape; got "
            f"{template_data.shape}, {density_data.shape}, {mask_data.shape}"
        )

    overlay_volume = np.zeros(template_data.shape + (4,), dtype=np.uint8)
    for sliceIdx in range(template_data.shape[2]):
        template_slice = _normalize(template_data[:, :, sliceIdx])
        combined = (template_slice
                    + density_data[:, :, sliceIdx]
                    + mask_data[:, :, sliceIdx])
        intensity = _normalize(combined)

        # plt.cm.gray returns float RGBA in [0, 1]; keep everything in that
        # range and convert to 0..255 exactly once when storing as uint8.
        rgba = plt.cm.gray(intensity)
        rgba[..., 3] = intensity  # alpha based on intensity
        # Round rather than truncate: k/255 * 255 can land just below k.
        overlay_volume[:, :, sliceIdx, :] = np.round(rgba * 255).astype(np.uint8)
    return overlay_volume


def main():
    """Load the volumes, build the RGBA overlay and show it as a 3D voxel plot."""
    # get_fdata() always returns floating-point data.
    template_data = nib.load(templateFilePath).get_fdata()
    density_data = nib.load(densityFilePath).get_fdata()
    # Reorient once, outside any per-slice loop.
    mask_data = _reorient_mask(nib.load(maskFilePath).get_fdata())

    overlay_volume = build_overlay_volume(template_data, density_data, mask_data)
    overlay_volume = overlay_volume[::voxel_step, ::voxel_step, ::voxel_step]

    # Create a figure for the 3D visualization
    fig = plt.figure(figsize=(10, 10), dpi=dpi)
    ax = fig.add_subplot(111, projection='3d')

    # Fill voxels whose gray value is non-zero. Pass all four channels so the
    # intensity-based alpha actually controls transparency.
    ax.voxels(overlay_volume[..., 0] > 0,
              facecolors=overlay_volume / 255.0,
              edgecolor='k')

    # Set the aspect ratio to be equal for all dimensions
    ax.set_box_aspect([1, 1, 1])
    plt.show()


if __name__ == "__main__":
    main()