"""Show template slices overlaid with a projection density and a mask.

Loads three NIfTI volumes (template, projection density, mask), takes every
``slice_step``-th slice along ``iteration_dimension`` and draws each one as a
panel in a square grid of subplots: the template in grey, with the density
("hot" colormap) and the mask ("cool" colormap) drawn semi-transparently on
top of it.
"""

import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import rotate

# Define the file paths
templateFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\average_template_50_from_website.nii.gz"
densityFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\12_wks_coronal_100141780_50um_projection_density.nii.gz"
maskFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\output\Nifti_Trakte_registered\CC_Mop_desnitymask_50um.nii.gz"

# Load the NIfTI files
templateNifti = nib.load(templateFilePath)
densityNifti = nib.load(densityFilePath)
maskNifti = nib.load(maskFilePath)

# Re-wrap the density volume so that its data has the same dtype as the
# template's floating-point data.
densityNifti = nib.Nifti1Image(
    densityNifti.get_fdata().astype(templateNifti.get_fdata().dtype),
    densityNifti.affine,
)

# Read each volume once; the display loop below only slices these arrays.
template_data = templateNifti.get_fdata()
density_data = densityNifti.get_fdata()

# Load the mask, swap its first and last axes and reverse axis 0.
# NOTE(review): presumably this brings the mask into the template's
# orientation - confirm against how the mask file was written.
mask_data = np.squeeze(maskNifti.get_fdata())
mask_data = mask_data.swapaxes(0, 2)
mask_data = np.flip(mask_data, axis=0)

# Customize the figure DPI
dpi = 400

# Get the number of slices along each dimension (requires a 3-D template)
num_slices_dim0, num_slices_dim1, num_slices_dim2 = templateNifti.shape

# Set the slice step for display (every tenth slice)
slice_step = 10

# Choose the dimension for iteration (0, 1, or 2)
iteration_dimension = 0  # Change this value to select the dimension

# Any value other than 0 or 1 selects the last axis.
slice_axis = iteration_dimension if iteration_dimension in (0, 1) else 2
num_slices = (num_slices_dim0, num_slices_dim1, num_slices_dim2)[slice_axis]

# The slices that are actually drawn.  Their count is a ceiling division of
# num_slices by slice_step, so derive every panel-related quantity from this
# range instead of from ``num_slices // slice_step``; otherwise the last
# panel is either missing from the grid or removed again as "empty".
slice_indices = range(0, num_slices, slice_step)
num_panels = len(slice_indices)

# Smallest square grid that holds all panels (at least 1 x 1).
grid_size = max(1, int(np.ceil(np.sqrt(num_panels))))

# squeeze=False keeps ``axs`` two-dimensional even for a 1 x 1 grid.
fig, axs = plt.subplots(
    grid_size, grid_size, figsize=(20, 20), dpi=dpi, squeeze=False
)

# Hide axis numbers
for ax in axs.ravel():
    ax.axis('off')

# Loop through and display every nth slice as subplots along the selected
# dimension
for i, sliceIdx in enumerate(slice_indices):
    row, col = divmod(i, grid_size)

    # Extract the desired slice from each volume along the selected dimension
    templateSlice = np.squeeze(
        np.take(template_data, sliceIdx, axis=slice_axis)
    ).astype(np.uint16)
    densitySlice = np.squeeze(
        np.take(density_data, sliceIdx, axis=slice_axis)
    ).astype(np.double)
    maskSlice = np.squeeze(np.take(mask_data, sliceIdx, axis=slice_axis))

    # Layers of different shape would be drawn misaligned, so fail early with
    # a clear message.
    if (densitySlice.shape != templateSlice.shape
            or maskSlice.shape != templateSlice.shape):
        raise ValueError(
            f"Slice {sliceIdx} along axis {slice_axis}: template "
            f"{templateSlice.shape}, density {densitySlice.shape} and mask "
            f"{maskSlice.shape} shapes do not match"
        )

    # Normalize the templateSlice to [0, 1] for visualization.  A constant
    # slice has zero range; show it as all zeros instead of dividing by zero.
    template_min = np.min(templateSlice)
    template_range = np.max(templateSlice) - template_min
    if template_range > 0:
        templateSlice = (templateSlice - template_min) / template_range
    else:
        templateSlice = np.zeros(templateSlice.shape, dtype=np.double)

    # Draw the three layers in the corresponding subplot
    ax = axs[row, col]
    ax.imshow(templateSlice, cmap='gray')
    ax.imshow(densitySlice, cmap='hot', alpha=0.3)  # "hot" colormap for the density
    ax.imshow(maskSlice, cmap='cool', alpha=0.3)  # "cool" colormap for the mask

# Remove the grid cells that did not receive a slice
for unused_ax in axs.ravel()[num_panels:]:
    fig.delaxes(unused_ax)

# Adjust spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()