123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- import nibabel as nib
- import numpy as np
- import matplotlib.pyplot as plt
- from scipy.ndimage import rotate
# Input volumes: the average template used as the anatomical background, the
# projection-density map to overlay on it, and the mask to be reoriented and
# overlaid as well.
templateFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\average_template_50_from_website.nii.gz"
densityFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\12_wks_coronal_100141780_50um_projection_density.nii.gz"
maskFilePath = r"\\10.209.5.114\Projects\Student_projects\14_Aref_Kalantari_2021\Luca\CC_Mop_desnitymap_50um.nii.gz"

# Load the NIfTI files (voxel data is read later via get_fdata()).
templateNifti = nib.load(templateFilePath)
densityNifti = nib.load(densityFilePath)
maskNifti = nib.load(maskFilePath)

# No dtype harmonisation is needed here. All voxel data in this script is read
# through get_fdata(), which returns float64 for every image, so template and
# density already share a dtype when they are combined. Casting the density to
# `templateNifti.get_fdata().dtype` would just be float64 -> float64 while
# reading the whole template volume for nothing. Casting to the template's
# on-disk dtype instead would truncate any fractional density values if that
# dtype is an integer type.
# Figure resolution in dots per inch (the figure below is 20 x 20 inches).
dpi = 400

# Size of the template volume along each of its three voxel axes.
num_slices_dim0, num_slices_dim1, num_slices_dim2 = templateNifti.shape

# Display every `slice_step`-th slice: 0, slice_step, 2 * slice_step, ...
slice_step = 10

# Voxel axis to step through (0, 1, or 2).
iteration_dimension = 0  # Change this value to select the dimension
if iteration_dimension not in (0, 1, 2):
    # Previously any other value silently fell back to axis 2.
    raise ValueError(
        f"iteration_dimension must be 0, 1 or 2, got {iteration_dimension!r}"
    )
num_slices = (num_slices_dim0, num_slices_dim1, num_slices_dim2)[iteration_dimension]

# Number of panels that will actually be drawn:
# len(range(0, n, step)) == ceil(n / step). Using n // step instead undercounts
# by one whenever n is not a multiple of step, leaving no cell for the last slice.
num_panels = len(range(0, num_slices, slice_step))

# Smallest square grid that holds every panel, and never smaller than 1 x 1
# (plt.subplots rejects a 0 x 0 grid).
grid_size = int(np.sqrt(num_panels))
if grid_size * grid_size < num_panels:
    grid_size += 1
grid_size = max(grid_size, 1)

# squeeze=False keeps `axs` a 2-D array even for a 1 x 1 grid, so it can
# always be iterated and indexed as axs[row, col].
fig, axs = plt.subplots(
    grid_size, grid_size, figsize=(20, 20), dpi=dpi, squeeze=False
)

# Hide axis ticks and labels on every panel.
for ax in axs.flat:
    ax.axis('off')
# Reorient the mask volume's voxel order and save the result.
mask_data = np.squeeze(maskNifti.get_fdata())  # drop any length-1 axes
# Swap voxel axes 0 and 2 (first and last), then reverse the new axis 0.
mask_data = mask_data.swapaxes(0, 2)
mask_data = np.flip(mask_data, axis=0)

# Wrap the reoriented data in a new NIfTI image and save it.
# NOTE(review): only the voxel order changes. The affine is copied unchanged
# from the input mask, so the saved file's voxel-to-world mapping does not
# account for the swap/flip. Confirm this is intended (e.g. that the result
# is meant to be read on the template's voxel grid).
adjusted_mask_nifti = nib.Nifti1Image(mask_data, maskNifti.affine)
output_mask_filepath = r"\\10.209.5.114\Projects\Student_projects\14_Aref_Kalantari_2021\Luca\CC_Mop_desnitymap_50um_reoriented.nii.gz"  # Specify the desired output file path
nib.save(adjusted_mask_nifti, output_mask_filepath)

# The overlay indexes the mask with the template's slice indices, so the
# reoriented mask must have exactly the template's voxel-grid shape. Otherwise
# slicing fails with an IndexError or the overlay is silently misaligned.
if mask_data.shape != templateNifti.shape:
    raise ValueError(
        f"Reoriented mask shape {mask_data.shape} does not match template "
        f"shape {templateNifti.shape}; the overlay would be misaligned."
    )
# Draw every `slice_step`-th slice along `iteration_dimension` into its own
# panel: grey template underneath, density and mask layered on top.

# Read each volume once; get_fdata() returns float64 arrays.
template_volume = templateNifti.get_fdata()
density_volume = densityNifti.get_fdata()

# Row-major flattening of the axes grid: panel i is axs[i // grid_size,
# i % grid_size]. np.ravel also copes with a bare Axes from a 1 x 1 grid.
axes_flat = np.ravel(axs)

for i, sliceIdx in enumerate(range(0, num_slices, slice_step)):
    ax = axes_flat[i]

    # np.take(vol, k, axis=d) selects index k along axis d, i.e. vol[k, :, :],
    # vol[:, k, :] or vol[:, :, k] for d = 0, 1, 2.
    templateSlice = np.squeeze(np.take(template_volume, sliceIdx, axis=iteration_dimension))
    densitySlice = np.squeeze(np.take(density_volume, sliceIdx, axis=iteration_dimension))
    maskSlice = np.squeeze(np.take(mask_data, sliceIdx, axis=iteration_dimension))

    # Min-max normalise the template slice to [0, 1] for display. It stays
    # float; a uint16 cast would truncate fractional intensities and wrap
    # negative ones. A constant slice (e.g. an empty edge slice) has zero
    # range, so map it to 0 rather than dividing by zero and getting an
    # all-NaN, blank panel.
    slice_min = templateSlice.min()
    slice_range = templateSlice.max() - slice_min
    if slice_range > 0:
        templateSlice = (templateSlice - slice_min) / slice_range
    else:
        templateSlice = np.zeros_like(templateSlice)

    # Layer the images in the subplot: grey template, then density and mask
    # with the "hot" colormap at partial opacity.
    ax.imshow(templateSlice, cmap='gray')
    ax.imshow(densitySlice, cmap='hot', alpha=0.4)
    ax.imshow(maskSlice, cmap='hot', alpha=0.2)
# Delete the grid cells that did not receive a slice. The number of drawn
# panels is len(range(0, num_slices, slice_step)) == ceil(num_slices /
# slice_step). Starting the deletion at num_slices // slice_step removed the
# panel holding the last slice whenever num_slices was not a multiple of
# slice_step.
num_drawn_panels = len(range(0, num_slices, slice_step))
for unused_ax in np.ravel(axs)[num_drawn_panels:]:
    fig.delaxes(unused_ax)

# Tighten the spacing between subplots, then display the figure.
plt.tight_layout()
plt.show()
|