Переглянути джерело

Merge remote-tracking branch 'origin/master'

lucaruthe 9 місяців тому
батько
коміт
c74fc4ebef
2 змінених файлів з 35 додано та 77 видалено
  1. 1 1
      code/overlayPlot.py
  2. 34 76
      code/overlayPlot3D.py

+ 1 - 1
code/overlayPlot.py

@@ -26,7 +26,7 @@ num_slices_dim0, num_slices_dim1, num_slices_dim2 = templateNifti.shape
 slice_step = 10
 
 # Choose the dimension for iteration (0, 1, or 2)
-iteration_dimension = 1  # Change this value to select the dimension
+iteration_dimension = 0  # Change this value to select the dimension
 
 # Calculate the number of rows and columns for the grid based on the selected dimension
 if iteration_dimension == 0:

+ 34 - 76
code/overlayPlot3D.py

@@ -7,13 +7,11 @@ from scipy.ndimage import rotate
 templateFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\average_template_50_from_website.nii.gz"
 densityFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\12_wks_coronal_100141780_50um_projection_density.nii.gz"
 maskFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\output\Nifti_Trakte_registered\CC_Mop_dilated_registered.nii.gz"
-tractFilePath = r"C:\Users\aswen\Desktop\Code\2024_Ruthe_SND\input\fiber_registered.nii.gz"
 
 # Load the NIfTI files
 templateNifti = nib.load(templateFilePath)
 densityNifti = nib.load(densityFilePath)
 maskNifti = nib.load(maskFilePath)
-tractNifti = nib.load(tractFilePath)
 
 # Ensure that both images have the same data type (e.g., cast density to the data type of template)
 densityNifti = nib.Nifti1Image(densityNifti.get_fdata().astype(templateNifti.get_fdata().dtype), densityNifti.affine)
@@ -24,86 +22,46 @@ dpi = 400
 # Get the number of slices along each dimension
 num_slices_dim0, num_slices_dim1, num_slices_dim2 = templateNifti.shape
 
-# Set the slice step for display (every tenth slice)
-slice_step = 10
-
-# Choose the dimension for iteration (0, 1, or 2)
-iteration_dimension = 0  # Change this value to select the dimension
-
-# Calculate the number of rows and columns for the grid based on the selected dimension
-if iteration_dimension == 0:
-    num_slices = num_slices_dim0
-elif iteration_dimension == 1:
-    num_slices = num_slices_dim1
-else:
-    num_slices = num_slices_dim2
-
-grid_size = int(np.sqrt(num_slices // slice_step))
-if grid_size * grid_size < num_slices // slice_step:
-    grid_size += 1
-
-cm = 1/2.54
-# Adjust the figure size to be larger
-fig, axs = plt.subplots(grid_size, grid_size, figsize=(20, 20), dpi=dpi)
-
-# Hide axis numbers
-for ax_row in axs:
-    for ax in ax_row:
-        ax.axis('off')
-
-# Load the mask and perform necessary flips
-mask_data = np.squeeze(maskNifti.get_fdata())
-# Swap the mask over the x and y axes
-mask_data = mask_data.swapaxes(0, 2)
-mask_data = np.flip(mask_data, axis=0)
-
-# Load the tract data and perform necessary flips
-tract_data0 = np.squeeze(tractNifti.get_fdata())
-tract_data = tract_data0[:,:,:,1]
-# Swap the tract over the x and y axes
-tract_data = tract_data.swapaxes(0, 2)
-tract_data = np.flip(tract_data, axis=0)
-
-# Loop through and display every nth slice as subplots along the selected dimension
-for i, sliceIdx in enumerate(range(0, num_slices, slice_step)):
-    row = i // grid_size
-    col = i % grid_size
-
-    # Extract the desired slice from the template, density, mask, and tract based on the selected dimension
-    if iteration_dimension == 0:
-        templateSlice = np.squeeze(templateNifti.get_fdata()[sliceIdx, :, :]).astype(np.uint16)
-        densitySlice = np.squeeze(densityNifti.get_fdata()[sliceIdx, :, :]).astype(np.double)
-        maskSlice = np.squeeze(mask_data[sliceIdx, :, :])
-        tractSlice = np.squeeze(tract_data[sliceIdx, :, :])
-    elif iteration_dimension == 1:
-        templateSlice = np.squeeze(templateNifti.get_fdata()[:, sliceIdx, :]).astype(np.uint16)
-        densitySlice = np.squeeze(densityNifti.get_fdata()[:, sliceIdx, :]).astype(np.double)
-        maskSlice = np.squeeze(mask_data[:, sliceIdx, :])
-        tractSlice = np.squeeze(tract_data[:, sliceIdx, :])
-    else:
-        templateSlice = np.squeeze(templateNifti.get_fdata()[:, :, sliceIdx]).astype(np.uint16)
-        densitySlice = np.squeeze(densityNifti.get_fdata()[:, :, sliceIdx]).astype(np.double)
-        maskSlice = np.squeeze(mask_data[:, :, sliceIdx])
-        tractSlice = np.squeeze(tract_data[:, :, sliceIdx])
+# Create a 3D array to hold the overlay
+overlay_volume = np.zeros((num_slices_dim0, num_slices_dim1, num_slices_dim2, 4), dtype=np.uint8)
+
+# Loop through and overlay all slices
+for sliceIdx in range(num_slices_dim2):
+    # Extract the desired slice from the template, density, and mask
+    templateSlice = np.squeeze(templateNifti.get_fdata()[:, :, sliceIdx]).astype(np.uint16)
+    densitySlice = np.squeeze(densityNifti.get_fdata()[:, :, sliceIdx]).astype(np.double)
+    
+    mask_data = maskNifti.get_fdata()
+    mask_data = mask_data.swapaxes(0, 2)
+    mask_data = np.flip(mask_data, axis=0)
+
+    maskSlice = np.squeeze(mask_data[:, :, sliceIdx])
 
     # Normalize the templateSlice to [0, 1] for visualization
     templateSlice = (templateSlice - np.min(templateSlice)) / (np.max(templateSlice) - np.min(templateSlice))
 
-    # Overlay the density, mask, and tract onto the template slice
-    overlayedImage = templateSlice + densitySlice + maskSlice + tractSlice
+    # Overlay the density and mask onto the template slice
+    overlayedImage = templateSlice + densitySlice + maskSlice
+
+    # Normalize the overlay to [0, 1]
+    overlayedImage = (overlayedImage - np.min(overlayedImage)) / (np.max(overlayedImage) - np.min(overlayedImage))
+
+    # Convert the overlay image to RGBA format (4 channels for transparency)
+    overlay_image_rgba = plt.cm.gray(overlayedImage)
+    overlay_image_rgba[:, :, 3] = (overlayedImage * 255).astype(np.uint8)  # Set alpha channel based on intensity
+
+    # Update the 3D overlay volume
+    overlay_volume[:, :, sliceIdx, :] = overlay_image_rgba
 
-    # Display the overlaid image in the corresponding subplot
-    #axs[row, col].imshow(templateSlice, cmap='gray')
-    #axs[row, col].imshow(densitySlice, cmap='hot', alpha=0.3)
-    #axs[row, col].imshow(maskSlice, cmap='YlGn', alpha=0.20, vmin=0.9, vmax=1)  # Apply "hot" colormap to the mask
-    axs[row, col].imshow(tractSlice, cmap='gray', alpha=1)  # Apply "Blues" colormap to the tract
+# Create a figure for the 3D visualization
+fig = plt.figure(figsize=(10, 10), dpi=dpi)
+ax = fig.add_subplot(111, projection='3d')
 
-# Remove empty subplots if necessary
-for i in range(num_slices // slice_step, grid_size * grid_size):
-    fig.delaxes(axs.flatten()[i])
+# Plot the 3D overlay volume
+ax.voxels(overlay_volume[:, :, :, 0] > 0, facecolors=overlay_volume[:, :, :, :3] / 255., edgecolor='k')
 
-# Adjust spacing between subplots
-plt.tight_layout()
+# Set the aspect ratio to be equal for all dimensions
+ax.set_box_aspect([1, 1, 1])
 
-# Show the plot
+# Show the 3D plot
 plt.show()