"""Register a 2D coronal brain slice to a 3D template and transfer its segmentation.

Pipeline (see ``main``): cut candidate slices from the template around an
approximate bregma position, affinely register each candidate to the input
slice, pick the best candidate by weighted mutual information, run a full
Rigid + Affine + SyN registration on it, and warp the template segmentation
onto the input slice.
"""
import argparse
import os
import sys

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from nipype.interfaces import ants
from scipy.ndimage import gaussian_filter
from skimage import morphology
from skimage.filters import threshold_otsu

from helper_functions import make_nii, make_dirs, normal_list_1d, mutual_information
from pre_processing import pre_process

# Linear mapping between an approximate bregma coordinate and a template slice
# index along axis 1 (used in both directions below).
_BREGMA_SCALE = 9.7431
_BREGMA_OFFSET = 38.7044


def _strip_extension(path):
    """Return ``path`` without its file extension (``.nii.gz`` counts as one).

    Only the last path component is inspected, so dots in directory names
    (``./out/``, ``v1.2/``) are preserved -- unlike ``path.split('.')[0]``.
    """
    if path.endswith('.nii.gz'):
        return path[:-len('.nii.gz')]
    return os.path.splitext(path)[0]


def _bregma_to_index(bregma):
    """Convert a bregma coordinate (number or numeric string) to a template slice index."""
    return int(_BREGMA_SCALE * float(bregma) + _BREGMA_OFFSET)


def _index_to_bregma(index):
    """Convert a template slice index back to a bregma coordinate, rounded to 2 decimals."""
    return round((index - _BREGMA_OFFSET) / _BREGMA_SCALE, 2)


def slice_template(template, template_seg, out_dir, pos, padding=5):
    """Extract 2D slices around ``pos`` (axis 1) from a template and its segmentation.

    For every index ``i`` in ``[pos - padding, pos + padding]`` the slice
    ``template[:, i, :]`` rotated by 180 degrees is written to
    ``<out_dir>/S<ii>.nii`` and the matching segmentation slice to
    ``<out_dir>/S<ii>_seg.nii``. Files that already exist are reused.

    Args:
        template: Path to the template image.
        template_seg: Path to the template segmentation.
        out_dir: Directory for the slice files (trailing separator optional).
        pos: Central slice index along axis 1.
        padding: Number of slices taken on each side of ``pos``.

    Returns:
        List of the intensity-slice paths, ordered by slice index.

    Raises:
        ValueError: If the requested slice range lies outside the template.
    """
    template_img = nib.load(template)
    template_data = np.asanyarray(template_img.dataobj)
    template_header = template_img.header
    template_seg_data = np.asanyarray(nib.load(template_seg).dataobj)

    first, last = pos - padding, pos + padding
    n_slices = template_data.shape[1]
    # Negative indices would silently wrap to the other end of the template.
    if first < 0 or last >= n_slices:
        raise ValueError(
            'Template slice range [{}, {}] lies outside the template (axis 1 has {} slices); '
            'check the bregma coordinate and padding.'.format(first, last, n_slices))

    # In-plane voxel sizes of the kept axes (0 and 2) of the 3D template.
    slice_pixdim = [1., template_header['pixdim'][1], template_header['pixdim'][3],
                    1., 1., 1., 1., 1.]

    images = []
    for i in range(first, last + 1):
        stem = os.path.join(out_dir, 'S' + str(i).zfill(2))
        output = stem + '.nii'
        seg_output = stem + '_seg.nii'
        if not os.path.exists(output):
            nib.save(make_nii(np.rot90(template_data[:, i, :], 2), pixdim=slice_pixdim), output)
        if not os.path.exists(seg_output):
            nib.save(make_nii(np.rot90(template_seg_data[:, i, :], 2), pixdim=slice_pixdim),
                     seg_output)
        images.append(output)
    return images


def reg_slice(fixed, moving, output, num_threads=12):
    """Affinely register the 2D ``moving`` image onto ``fixed`` with ANTs.

    Single Affine stage with a mutual-information metric. The warped image is
    written to ``output`` and the transform files use the prefix
    ``<output without extension>-Transformation``.

    Args:
        fixed: Path to the fixed (reference) image.
        moving: Path to the moving image.
        output: Path of the warped output image.
        num_threads: Number of threads passed to ANTs.
    """
    registration = ants.Registration(
        dimension=2,
        transforms=['Affine'],
        transform_parameters=[(0.15,)],
        initial_moving_transform_com=1,
        metric=['MI'],
        interpolation='BSpline',
        fixed_image=[fixed],
        moving_image=[moving],
        metric_weight=[1],
        radius_or_number_of_bins=[32],
        number_of_iterations=[[1000, 1000, 1000, 1000, 1000]],
        smoothing_sigmas=[[4, 4, 4, 2, 2]],
        shrink_factors=[[64, 32, 16, 8, 4]],
        convergence_threshold=[1.e-7],
        sampling_strategy=['Regular'],
        use_histogram_matching=True,
        winsorize_lower_quantile=0.05,
        winsorize_upper_quantile=0.95,
        output_warped_image=output,
        output_transform_prefix=_strip_extension(output) + '-Transformation',
        num_threads=num_threads,
    )
    registration.run()


def full_reg_slice(fixed, moving, output, seg, num_threads=12):
    """Register ``moving`` onto ``fixed`` (Rigid + Affine + SyN) and warp ``seg`` alike.

    The warped image is written to ``output``; the segmentation ``seg`` is
    resampled with nearest-neighbour interpolation onto ``output`` and written
    to ``<output without extension>_seg.nii``.

    Args:
        fixed: Path to the fixed (reference) image.
        moving: Path to the moving image.
        output: Path of the warped output image.
        seg: Path to the segmentation defined on ``moving``.
        num_threads: Number of threads passed to ANTs.

    Raises:
        SystemExit: If the ANTs transform files are not found under either of
            the two known naming schemes.
    """
    output_pre = _strip_extension(output)
    transform_prefix = output_pre + '-Transformation-'
    registration = ants.Registration(
        dimension=2,
        transforms=['Rigid', 'Affine', 'SyN'],
        transform_parameters=[(0.15,), (0.15,), (0.3, 0.3, 0.5)],
        metric=['MeanSquares'] * 3,
        initial_moving_transform_com=1,
        interpolation='BSpline',
        fixed_image=[fixed],
        moving_image=[moving],
        metric_weight=[1] * 3,
        radius_or_number_of_bins=[32] * 3,
        number_of_iterations=[[1000] * 4, [1000] * 5, [1000] * 6],
        smoothing_sigmas=[[4, 4, 2, 2], [4, 4, 4, 2, 2], [4, 4, 4, 2, 2, 2]],
        shrink_factors=[[32, 16, 8, 4], [64, 32, 16, 8, 4], [64, 32, 16, 8, 4, 2]],
        convergence_threshold=[1.e-7],
        sampling_strategy=['Regular'] * 3,
        use_histogram_matching=True,
        winsorize_lower_quantile=0.05,
        winsorize_upper_quantile=0.95,
        output_warped_image=output,
        output_transform_prefix=transform_prefix,
        num_threads=num_threads,
    )
    registration.run()

    # The names of the transform files ANTs writes differ between setups, so
    # probe for both known schemes. ANTs applies a transform list last-to-first,
    # hence the warp is listed first.
    if os.path.exists(transform_prefix + '0GenericAffine.mat'):
        transform_paths = [transform_prefix + '1Warp.nii.gz',
                           transform_prefix + '0GenericAffine.mat']
    elif os.path.exists(transform_prefix + '0DerivedInitialMovingTranslation.mat'):
        transform_paths = [transform_prefix + '3Warp.nii.gz',
                           transform_prefix + '2Affine.mat',
                           transform_prefix + '1Rigid.mat',
                           transform_prefix + '0DerivedInitialMovingTranslation.mat']
    else:
        sys.exit('The ANTs-output transformation names/postfix was not recognised.')

    ants.ApplyTransforms(
        dimension=2,
        input_image=seg,
        reference_image=output,
        transforms=transform_paths,
        interpolation='NearestNeighbor',  # keep segmentation labels discrete
        output_image=output_pre + '_seg.nii',
        num_threads=num_threads,
    ).run()


def _save_similarity_plot(indices, weighted_mi, mutual_informations, path):
    """Plot weighted and raw MI against slice index, mark each curve's maximum, save to ``path``."""
    figure = plt.figure()
    try:
        ax = figure.add_subplot(111)
        ax.plot(indices, weighted_mi)
        ax.plot(indices, mutual_informations)
        for values in (weighted_mi, mutual_informations):
            best = values.index(max(values))
            ax.plot(indices[best], values[best], 'rx')
        ax.set_title('Mutual information of template slice registered to input slice.')
        figure.savefig(path)
    finally:
        # pyplot keeps every figure alive until closed; without this, repeated
        # calls to main() accumulate figures.
        plt.close(figure)


def _mask_background(slice_path, seg_path, out_path):
    """Zero segmentation labels outside the thresholded foreground of the slice and save.

    Foreground is ``slice > otsu(slice) / 6`` (a lower cut than plain Otsu, so
    more pixels count as foreground), followed by a binary closing with a 4x4
    square to fill small gaps.
    """
    slice_data = np.asanyarray(nib.load(slice_path).dataobj)
    seg_nii = nib.load(seg_path)
    # Writable copy: the array behind the image proxy may be memory-mapped.
    seg_data = np.array(seg_nii.dataobj)

    threshold = threshold_otsu(slice_data) / 6
    # Footprint passed positionally: the keyword was ``selem`` in older
    # scikit-image and ``footprint`` in newer releases.
    foreground = morphology.binary_closing(slice_data > threshold, morphology.square(4))
    seg_data[~foreground] = 0

    nib.save(nib.Nifti1Image(seg_data, None, seg_nii.header), out_path)


def main(input_slice_path, template_path, template_seg_path, output_path, bregma,
         preprocess_input, inverse_output, padding=6, return_result_dict=False):
    """Register an input slice to its best-matching template slice and segment it.

    Args:
        input_slice_path: Path to the 2D input slice.
        template_path: Path to the 3D template image.
        template_seg_path: Path to the 3D template segmentation.
        output_path: Path of the registered output image. Intermediate files
            go to ``<dirname(output_path)>/temp/``; the segmentations are
            written next to it as ``*_seg.nii`` and ``*_seg_thres.nii``.
        bregma: Approximate bregma coordinate (number or numeric string).
        preprocess_input: ``'true'``/``'True'`` (or boolean ``True``) to run
            ``pre_process`` on the input first.
        inverse_output: Currently unused (no inverse transform is performed);
            kept for interface compatibility.
        padding: Number of candidate template slices on each side of the
            slice predicted from ``bregma``.
        return_result_dict: If true, return a dict describing the result.

    Returns:
        ``None``; or, when ``return_result_dict`` is set, a dict with keys
        ``image`` (the input or pre-processed slice path), ``segmentation``,
        ``threshold_segmentation``, ``index`` (chosen template slice index)
        and ``bregma`` (that index converted back to a bregma coordinate).
    """
    # Working directories. ``dirname`` is '' for a bare file name, which
    # os.path.join resolves relative to the CWD instead of the filesystem root.
    work_dir = os.path.dirname(output_path)
    temp_dir = os.path.join(work_dir, 'temp', '')
    template_slice_dir = os.path.join(temp_dir, 'template_slices', '')
    make_dirs([temp_dir, template_slice_dir])

    # Candidate template slices around the bregma-predicted index.
    coor = _bregma_to_index(bregma)
    template_slices = slice_template(template_path, template_seg_path, template_slice_dir,
                                     coor, padding=padding)

    # Optional pre-processing of the input image.
    # NOTE(review): for a ``.nii`` input this writes over the input file itself.
    if str(preprocess_input) in ('true', 'True'):
        pre_processed_output = _strip_extension(input_slice_path) + '.nii'
        pre_process(input_slice_path, pre_processed_output, print_log=True, shrink=1)
        slice_path = pre_processed_output
    else:
        slice_path = input_slice_path

    # Affine pre-registration of every candidate slice onto the input slice.
    registration_dir = os.path.join(temp_dir, 'rigid_registrations', '')
    make_dirs(registration_dir)
    registered_slices = []
    for template_slice in template_slices:
        name = os.path.basename(_strip_extension(template_slice))
        registered_slice = os.path.join(registration_dir, name + '.nii')
        reg_slice(slice_path, template_slice, registered_slice)
        registered_slices.append(registered_slice)

    # Similarity of each registered candidate to the (smoothed) input slice.
    input_data_smooth = gaussian_filter(np.asanyarray(nib.load(slice_path).dataobj), 2)
    mutual_informations = [
        mutual_information(input_data_smooth, np.asanyarray(nib.load(image).dataobj), 32) - 0.5
        for image in registered_slices
    ]

    # Weight the similarities by position in the candidate list and pick the best.
    _, weights = normal_list_1d(10, len(mutual_informations), max_is_one=True)
    weighted_mi = [mi * w for mi, w in zip(mutual_informations, weights)]
    max_index = weighted_mi.index(max(weighted_mi))
    maximum_slice = template_slices[max_index]
    maximum_slice_seg = _strip_extension(maximum_slice) + '_seg.nii'

    range_min = coor - padding
    max_x = range_min + max_index
    _save_similarity_plot(list(range(range_min, coor + padding + 1)), weighted_mi,
                          mutual_informations,
                          os.path.join(temp_dir, 'mutual_information.pdf'))

    # Full (rigid, affine and nonlinear) registration of the chosen template
    # slice; its segmentation is transformed accordingly.
    full_reg_slice(slice_path, maximum_slice, output_path, maximum_slice_seg)

    # Same naming as full_reg_slice uses for the segmentation it writes.
    output_stem = _strip_extension(output_path)
    slice_seg_path = output_stem + '_seg.nii'
    threshold_seg_path = output_stem + '_seg_thres.nii'
    _mask_background(slice_path, slice_seg_path, threshold_seg_path)

    if return_result_dict:
        return {'image': slice_path,
                'segmentation': slice_seg_path,
                'threshold_segmentation': threshold_seg_path,
                'index': max_x,
                'bregma': _index_to_bregma(max_x)}


def _build_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description='Register a 2 dimensional coronal brain slice to a template.')
    parser.add_argument("-s", "--slice", required=True, metavar='\b',
                        help="Existing path to the slice.")
    parser.add_argument("-t", "--template", required=True, metavar='\b',
                        help="Existing path to the template.")
    parser.add_argument("-e", "--segmentation", required=True, metavar='\b',
                        help="Existing path to the template segmentation.")
    parser.add_argument("-o", "--output", required=True, metavar='\b',
                        help="Output.")
    # type=float: main() does arithmetic on this value, which fails on a str.
    parser.add_argument("-b", "--bregma", required=True, type=float, metavar='\b',
                        help="Approximate bregma coordinate of slice.")
    parser.add_argument("-p", "--preprocess", required=True, metavar='\b',
                        help="Pre-process input image. Includes finding biggest connected "
                             "component, normalizing and biasfieldcorrection.")
    # Optional, as its help text states.
    parser.add_argument("-i", "--inverse", required=False, default=None, metavar='\b',
                        help="Path to input slice inversely transformed to template space. "
                             "Will not transform inversely if not set.")
    return parser


if __name__ == '__main__':
    args = _build_parser().parse_args()
    main(args.slice, args.template, args.segmentation, args.output, args.bregma,
         args.preprocess, args.inverse)