slice_to_template.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import argparse
  2. from nipype.interfaces import ants
  3. import nibabel as nib
  4. import numpy as np
  5. import os
  6. from skimage.filters import threshold_otsu
  7. import matplotlib.pyplot as plt
  8. from helper_functions import make_nii, make_dirs, normal_list_1d, mutual_information
  9. from pre_processing import pre_process
  10. from skimage import morphology
  11. from scipy.ndimage import gaussian_filter
  12. import sys
  13. def slice_template(template, template_seg, out_dir, pos, padding=5):
  14. # Load template data and header.
  15. template = nib.load(template)
  16. template_data = template.get_data()
  17. template_header = template.header
  18. # Load template segmentation data.
  19. template_seg = nib.load(template_seg)
  20. template_seg_data = template_seg.get_data()
  21. # Set pixel dimensions.
  22. slice_pixdim = [1., template_header['pixdim'][1], template_header['pixdim'][3], 1., 1., 1., 1., 1.]
  23. images = []
  24. for i in range(pos-padding, pos+padding+1):
  25. output = out_dir + 'S' + str(i).zfill(2) + '.nii'
  26. seg_output = out_dir + 'S' + str(i).zfill(2) + '_seg.nii'
  27. if not os.path.exists(output):
  28. nib.save(make_nii(np.rot90(template_data[:, i, :], 2), pixdim=slice_pixdim), output)
  29. if not os.path.exists(seg_output):
  30. nib.save(make_nii(np.rot90(template_seg_data[:, i, :], 2), pixdim=slice_pixdim), seg_output)
  31. images.append(output)
  32. # Returns a list of image paths.
  33. return images
  34. def reg_slice(fixed, moving, output):
  35. registration = ants.Registration(
  36. dimension=2,
  37. transforms=['Affine'],
  38. transform_parameters=[(0.15, ), ],
  39. initial_moving_transform_com=1,
  40. metric=['MI'],
  41. interpolation='BSpline',
  42. fixed_image=[fixed],
  43. moving_image=[moving],
  44. metric_weight=[1],
  45. radius_or_number_of_bins=[32],
  46. number_of_iterations=[[1000, 1000, 1000, 1000, 1000]],
  47. smoothing_sigmas=[[4, 4, 4, 2, 2]],
  48. shrink_factors=[[64, 32, 16, 8, 4]],
  49. convergence_threshold=[1.e-7],
  50. sampling_strategy=['Regular'],
  51. use_histogram_matching=True,
  52. winsorize_lower_quantile=0.05,
  53. winsorize_upper_quantile=0.95,
  54. output_warped_image=output,
  55. output_transform_prefix=output.split('.')[0] + '-Transformation',
  56. num_threads=12
  57. )
  58. registration.run()
  59. def full_reg_slice(fixed, moving, output, seg):
  60. output_pre = output.split('.')[0]
  61. registration = ants.Registration(
  62. dimension=2,
  63. transforms=['Rigid', 'Affine', 'SyN'],
  64. transform_parameters=[(0.15,), (0.15, ), (0.3, 0.3, 0.5)],
  65. metric=['MeanSquares']*3,
  66. initial_moving_transform_com=1,
  67. interpolation='BSpline',
  68. fixed_image=[fixed],
  69. moving_image=[moving],
  70. metric_weight=[1]*3,
  71. radius_or_number_of_bins=[32]*3,
  72. number_of_iterations=[[1000, 1000, 1000, 1000], [1000, 1000, 1000, 1000, 1000], [1000, 1000, 1000, 1000, 1000, 1000]],
  73. smoothing_sigmas=[[4, 4, 2, 2], [4, 4, 4, 2, 2], [4, 4, 4, 2, 2, 2]],
  74. shrink_factors=[[32, 16, 8, 4], [64, 32, 16, 8, 4], [64, 32, 16, 8, 4, 2]],
  75. convergence_threshold=[1.e-7],
  76. sampling_strategy=['Regular']*3,
  77. use_histogram_matching=True,
  78. winsorize_lower_quantile=0.05,
  79. winsorize_upper_quantile=0.95,
  80. output_warped_image=output,
  81. output_transform_prefix=output_pre + '-Transformation-',
  82. num_threads=12
  83. )
  84. registration.run()
  85. # Define transformation to be applied on segmentation. Dependent on the system running ANTs.
  86. if os.path.exists(output_pre + '-Transformation-0GenericAffine.mat'):
  87. transform_paths = [output_pre + '-Transformation-1Warp.nii.gz',
  88. output_pre + '-Transformation-0GenericAffine.mat']
  89. elif os.path.exists(output_pre + '-Transformation-0DerivedInitialMovingTranslation.mat'):
  90. transform_paths = [output_pre + '-Transformation-3Warp.nii.gz',
  91. output_pre + '-Transformation-2Affine.mat',
  92. output_pre + '-Transformation-1Rigid.mat',
  93. output_pre + '-Transformation-0DerivedInitialMovingTranslation.mat']
  94. else:
  95. sys.exit('The ANTs-output transformation names/postfix was not recognised.')
  96. # Apply recognised transformations to slice segmentation.
  97. ants.ApplyTransforms(
  98. dimension=2,
  99. input_image=seg,
  100. reference_image=output,
  101. transforms=transform_paths,
  102. interpolation='NearestNeighbor',
  103. output_image=output.split('.')[0] + '_seg.nii',
  104. num_threads=12
  105. ).run()
  106. def main(input_slice_path, template_path, template_seg_path, output_path,
  107. bregma, preprocess_input, inverse_output, padding=6, return_result_dict=False):
  108. # Directory variables.
  109. work_dir = os.path.dirname(output_path) + '/'
  110. temp_dir = work_dir + 'temp/'
  111. template_slice_dir = temp_dir + 'template_slices/'
  112. make_dirs([temp_dir, template_slice_dir])
  113. # Inverse bool
  114. output_inverse = inverse_output != None
  115. print(output_inverse)
  116. # Slice template
  117. coor = int(9.7431 * bregma + 38.7044)
  118. template_slices = slice_template(template_path, template_seg_path, template_slice_dir, coor, padding=padding)
  119. # Pre-process input image.
  120. if preprocess_input in ['true', 'True']:
  121. pre_processed_output = input_slice_path.split('.')[0] + '.nii'
  122. pre_process(input_slice_path, pre_processed_output, print_log=True, shrink=1)
  123. slice_path = pre_processed_output
  124. else:
  125. slice_path = input_slice_path
  126. # Rigid registration
  127. registration_dir = temp_dir + 'rigid_registrations/'
  128. registered_slices = []
  129. make_dirs(registration_dir)
  130. for template_slice in template_slices:
  131. name = template_slice.split('/')[-1].split('.')[0]
  132. registered_slice = registration_dir + name + '.nii'
  133. reg_slice(slice_path, template_slice, registered_slice)
  134. registered_slices.append(registered_slice)
  135. # Similarity extrema
  136. mutual_informations = []
  137. input_nii = nib.load(slice_path)
  138. input_data = input_nii.get_data()
  139. input_data_smooth = gaussian_filter(input_data, 2)
  140. for image in registered_slices:
  141. nii = nib.load(image)
  142. data = nii.get_data()
  143. mutual_informations.append(mutual_information(input_data_smooth, data, 32)-0.5)
  144. # Weighted similarity.
  145. _, weights = normal_list_1d(10, len(mutual_informations), max_is_one=True)
  146. weighted_mi = [mi*w for mi, w in zip(mutual_informations, weights)]
  147. max_y = max(weighted_mi)
  148. max_index = weighted_mi.index(max_y)
  149. maximum_slice = template_slices[max_index]
  150. maximum_slice_seg = maximum_slice.split('.')[0] + '_seg.nii'
  151. # Save pdf of similarities, with minimum marked.
  152. f = plt.figure()
  153. range_min = coor - padding
  154. range_max = coor + padding + 1
  155. max_x = range_min + max_index
  156. plt.plot(range(range_min, range_max), weighted_mi)
  157. plt.plot(range(range_min, range_max), mutual_informations)
  158. plt.plot(max_x, max_y, 'rx')
  159. plt.plot(range_min + mutual_informations.index(max(mutual_informations)), max(mutual_informations), 'rx')
  160. plt.title('Mutual information of template slice registered to input slice.')
  161. f.savefig(temp_dir + 'mutual_information.pdf')
  162. # Run full (rigid, affine and nonlinear) registration on selected template slice.
  163. # Transform segmentation accordingly.
  164. full_reg_slice(slice_path, maximum_slice, output_path, maximum_slice_seg)
  165. # Load segmentation, threshold, and close potential holes.
  166. slice_data = nib.load(slice_path).get_data()
  167. slice_seg_path = output_path.split('.nii')[0] + '_seg.nii'
  168. slice_seg_nii = nib.load(slice_seg_path)
  169. slice_seg_header = slice_seg_nii.header
  170. slice_seg_data = slice_seg_nii.get_data()
  171. # Calculate threshold.
  172. threshold = threshold_otsu(slice_data)/6
  173. foreground = np.where(slice_data > threshold, 1, 0)
  174. foreground = morphology.binary_closing(foreground, selem=morphology.square(4))
  175. slice_seg_data[foreground == 0] = 0
  176. # Save threshold segmentation.
  177. nib.save(nib.Nifti1Image(slice_seg_data, None, slice_seg_header), output_path.split('.nii')[0] + '_seg_thres.nii')
  178. if return_result_dict:
  179. return {'image': slice_path,
  180. 'segmentation': slice_seg_path,
  181. 'threshold_segmentation': output_path.split('.nii')[0] + '_seg_thres.nii',
  182. 'index': max_x,
  183. 'bregma': round((max_x - 38.7044)/9.7431, 2)}
  184. if __name__ == '__main__':
  185. parser = argparse.ArgumentParser(description='Register a 2 dimensional coronal brain slice to a template.')
  186. parser.add_argument("-s", "--slice", required=True, metavar='\b', help="Existing path to the slice.")
  187. parser.add_argument("-t", "--template", required=True, metavar='\b', help="Existing path to the template.")
  188. parser.add_argument("-e", "--segmentation", required=True, metavar='\b', help="Existing path to the template segmentation.")
  189. parser.add_argument("-o", "--output", required=True, metavar='\b', help="Output.")
  190. parser.add_argument("-b", "--bregma", required=True, metavar='\b', help="Approximate bregma coordinate of slice.")
  191. parser.add_argument("-p", "--preprocess", required=True, metavar='\b', help="""Pre-process input image. Includes
  192. finding biggest connected component, normalizing and biasfieldcorrection.""")
  193. parser.add_argument("-i", "--inverse", required=True, metavar='\b', help="""Path to input slice inversely
  194. transformed to template space. Will not transform inversely if not set.""")
  195. args = parser.parse_args()
  196. main(args.slice, args.template, args.segmentation, args.output, args.bregma, args.preprocess, args.inverse)