|
@@ -7,13 +7,11 @@ from scipy.ndimage import rotate
|
|
|
templateFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\average_template_50_from_website.nii.gz"
|
|
|
densityFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\12_wks_coronal_100141780_50um_projection_density.nii.gz"
|
|
|
maskFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\output\Nifti_Trakte_registered\CC_Mop_dilated_registered.nii.gz"
|
|
|
-tractFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\fiber_registered.nii.gz"
|
|
|
|
|
|
# Load the NIfTI files
|
|
|
templateNifti = nib.load(templateFilePath)
|
|
|
densityNifti = nib.load(densityFilePath)
|
|
|
maskNifti = nib.load(maskFilePath)
|
|
|
-tractNifti = nib.load(tractFilePath)
|
|
|
|
|
|
# Ensure that both images have the same data type (e.g., cast density to the data type of template)
|
|
|
densityNifti = nib.Nifti1Image(densityNifti.get_fdata().astype(templateNifti.get_fdata().dtype), densityNifti.affine)
|
|
@@ -24,86 +22,46 @@ dpi = 400
|
|
|
# Get the number of slices along each dimension
|
|
|
num_slices_dim0, num_slices_dim1, num_slices_dim2 = templateNifti.shape
|
|
|
|
|
|
-# Set the slice step for display (every tenth slice)
|
|
|
-slice_step = 10
|
|
|
-
|
|
|
-# Choose the dimension for iteration (0, 1, or 2)
|
|
|
-iteration_dimension = 0 # Change this value to select the dimension
|
|
|
-
|
|
|
-# Calculate the number of rows and columns for the grid based on the selected dimension
|
|
|
-if iteration_dimension == 0:
|
|
|
- num_slices = num_slices_dim0
|
|
|
-elif iteration_dimension == 1:
|
|
|
- num_slices = num_slices_dim1
|
|
|
-else:
|
|
|
- num_slices = num_slices_dim2
|
|
|
-
|
|
|
-grid_size = int(np.sqrt(num_slices // slice_step))
|
|
|
-if grid_size * grid_size < num_slices // slice_step:
|
|
|
- grid_size += 1
|
|
|
-
|
|
|
-cm = 1/2.54
|
|
|
-# Adjust the figure size to be larger
|
|
|
-fig, axs = plt.subplots(grid_size, grid_size, figsize=(20, 20), dpi=dpi)
|
|
|
-
|
|
|
-# Hide axis numbers
|
|
|
-for ax_row in axs:
|
|
|
- for ax in ax_row:
|
|
|
- ax.axis('off')
|
|
|
-
|
|
|
-# Load the mask and perform necessary flips
|
|
|
-mask_data = np.squeeze(maskNifti.get_fdata())
|
|
|
-# Swap the mask over the x and y axes
|
|
|
-mask_data = mask_data.swapaxes(0, 2)
|
|
|
-mask_data = np.flip(mask_data, axis=0)
|
|
|
-
|
|
|
-# Load the tract data and perform necessary flips
|
|
|
-tract_data0 = np.squeeze(tractNifti.get_fdata())
|
|
|
-tract_data = tract_data0[:,:,:,1]
|
|
|
-# Swap the tract over the x and y axes
|
|
|
-tract_data = tract_data.swapaxes(0, 2)
|
|
|
-tract_data = np.flip(tract_data, axis=0)
|
|
|
-
|
|
|
-# Loop through and display every nth slice as subplots along the selected dimension
|
|
|
-for i, sliceIdx in enumerate(range(0, num_slices, slice_step)):
|
|
|
- row = i // grid_size
|
|
|
- col = i % grid_size
|
|
|
-
|
|
|
- # Extract the desired slice from the template, density, mask, and tract based on the selected dimension
|
|
|
- if iteration_dimension == 0:
|
|
|
- templateSlice = np.squeeze(templateNifti.get_fdata()[sliceIdx, :, :]).astype(np.uint16)
|
|
|
- densitySlice = np.squeeze(densityNifti.get_fdata()[sliceIdx, :, :]).astype(np.double)
|
|
|
- maskSlice = np.squeeze(mask_data[sliceIdx, :, :])
|
|
|
- tractSlice = np.squeeze(tract_data[sliceIdx, :, :])
|
|
|
- elif iteration_dimension == 1:
|
|
|
- templateSlice = np.squeeze(templateNifti.get_fdata()[:, sliceIdx, :]).astype(np.uint16)
|
|
|
- densitySlice = np.squeeze(densityNifti.get_fdata()[:, sliceIdx, :]).astype(np.double)
|
|
|
- maskSlice = np.squeeze(mask_data[:, sliceIdx, :])
|
|
|
- tractSlice = np.squeeze(tract_data[:, sliceIdx, :])
|
|
|
- else:
|
|
|
- templateSlice = np.squeeze(templateNifti.get_fdata()[:, :, sliceIdx]).astype(np.uint16)
|
|
|
- densitySlice = np.squeeze(densityNifti.get_fdata()[:, :, sliceIdx]).astype(np.double)
|
|
|
- maskSlice = np.squeeze(mask_data[:, :, sliceIdx])
|
|
|
- tractSlice = np.squeeze(tract_data[:, :, sliceIdx])
|
|
|
+# Create a 3D array to hold the overlay
|
|
|
+overlay_volume = np.zeros((num_slices_dim0, num_slices_dim1, num_slices_dim2, 4), dtype=np.uint8)
|
|
|
+
|
|
|
+# Loop through and overlay all slices
|
|
|
+for sliceIdx in range(num_slices_dim2):
|
|
|
+ # Extract the desired slice from the template, density, and mask
|
|
|
+ templateSlice = np.squeeze(templateNifti.get_fdata()[:, :, sliceIdx]).astype(np.uint16)
|
|
|
+ densitySlice = np.squeeze(densityNifti.get_fdata()[:, :, sliceIdx]).astype(np.double)
|
|
|
+
|
|
|
+ mask_data = maskNifti.get_fdata()
|
|
|
+ mask_data = mask_data.swapaxes(0, 2)
|
|
|
+ mask_data = np.flip(mask_data, axis=0)
|
|
|
+
|
|
|
+ maskSlice = np.squeeze(mask_data[:, :, sliceIdx])
|
|
|
|
|
|
# Normalize the templateSlice to [0, 1] for visualization
|
|
|
templateSlice = (templateSlice - np.min(templateSlice)) / (np.max(templateSlice) - np.min(templateSlice))
|
|
|
|
|
|
- # Overlay the density, mask, and tract onto the template slice
|
|
|
- overlayedImage = templateSlice + densitySlice + maskSlice + tractSlice
|
|
|
+ # Overlay the density and mask onto the template slice
|
|
|
+ overlayedImage = templateSlice + densitySlice + maskSlice
|
|
|
+
|
|
|
+ # Normalize the overlay to [0, 1]
|
|
|
+ overlayedImage = (overlayedImage - np.min(overlayedImage)) / (np.max(overlayedImage) - np.min(overlayedImage))
|
|
|
+
|
|
|
+ # Convert the overlay image to RGBA format (4 channels for transparency)
|
|
|
+ overlay_image_rgba = plt.cm.gray(overlayedImage)
|
|
|
+ overlay_image_rgba[:, :, 3] = (overlayedImage * 255).astype(np.uint8) # Set alpha channel based on intensity
|
|
|
+
|
|
|
+ # Update the 3D overlay volume
|
|
|
+ overlay_volume[:, :, sliceIdx, :] = overlay_image_rgba
|
|
|
|
|
|
- # Display the overlaid image in the corresponding subplot
|
|
|
- #axs[row, col].imshow(templateSlice, cmap='gray')
|
|
|
- #axs[row, col].imshow(densitySlice, cmap='hot', alpha=0.3)
|
|
|
- #axs[row, col].imshow(maskSlice, cmap='YlGn', alpha=0.20, vmin=0.9, vmax=1) # Apply "hot" colormap to the mask
|
|
|
- axs[row, col].imshow(tractSlice, cmap='gray', alpha=1) # Apply "Blues" colormap to the tract
|
|
|
+# Create a figure for the 3D visualization
|
|
|
+fig = plt.figure(figsize=(10, 10), dpi=dpi)
|
|
|
+ax = fig.add_subplot(111, projection='3d')
|
|
|
|
|
|
-# Remove empty subplots if necessary
|
|
|
-for i in range(num_slices // slice_step, grid_size * grid_size):
|
|
|
- fig.delaxes(axs.flatten()[i])
|
|
|
+# Plot the 3D overlay volume
|
|
|
+ax.voxels(overlay_volume[:, :, :, 0] > 0, facecolors=overlay_volume[:, :, :, :3] / 255., edgecolor='k')
|
|
|
|
|
|
-# Adjust spacing between subplots
|
|
|
-plt.tight_layout()
|
|
|
+# Set the aspect ratio to be equal for all dimensions
|
|
|
+ax.set_box_aspect([1, 1, 1])
|
|
|
|
|
|
-# Show the plot
|
|
|
+# Show the 3D plot
|
|
|
plt.show()
|