auto_seg.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # System imports
  2. from pathlib import Path
  3. from pprint import pprint
  4. import argparse
  5. import sys
  6. # 3rd Party imports
  7. import numpy as np
  8. import scipy.stats
  9. import nibabel as nib
  10. from nipype.interfaces import ants
  11. # Local source tree imports
  12. import registration
  13. import tools
  14. import preprocess as pre
# Scratch directory for intermediate files: per-index template slices, registered slices and transform
# prefixes are written here. generate_segmentation creates it at the start of a run and removes it at the end.
TEMP_DIR = Path('temp')
# Default output directory; same value as the default of the --out command-line option.
OUTPUT_DIR = Path('output')
# NOTE(review): not referenced by any function in this file as shown; presumably the input image
# directory - confirm before relying on it.
IMG_DIR = Path('Images')
  18. def find_optimal_index(slice_path, template_path, approx, dist=2):
  19. """
  20. Finds the most appropriate index in template path to register the slice file to.
  21. :param approx: Approximate index of slice relative to template file
  22. :param dist: maximum distance of slices away from approximate to index to check
  23. :return: index of slice in template which has highest metric - combined mutual information and dice coefficient.
  24. """
  25. weights = []
  26. for i in range(approx-dist,approx+dist+1):
  27. print(f'Slice index: {i}')
  28. _slice = nib.load(str(slice_path))
  29. template_slice_loc = str(TEMP_DIR/f't_slice{i}.nii')
  30. registered_loc = str(TEMP_DIR/ f'reg_slice{i}.nii')
  31. transform_prefix = str(TEMP_DIR/f'{i}-')
  32. tools.save_slice(str(template_path), template_slice_loc, i, np.eye(4), _slice.header)
  33. registration.rigid_registration(template_slice_loc, str(slice_path), registered_loc, transform_prefix)
  34. template_data = nib.load(str(template_slice_loc)).get_data()
  35. registered_data = nib.load(registered_loc).get_data()
  36. registered_data = tools.resize_im(registered_data, template_data)
  37. nib.save(nib.Nifti1Image(registered_data, np.eye(4)), registered_loc)
  38. mutualinfo = tools.mutual_info_mask(template_data, registered_data)
  39. norm_factor = scipy.stats.norm.pdf(i-approx+dist, dist, dist*2) / scipy.stats.norm.pdf(dist, dist, dist*2)
  40. dice_coef = tools.dice_coef(template_data, registered_data)
  41. weights.append((i, norm_factor * (0.7 * mutualinfo + 0.3 * dice_coef)))
  42. pprint(weights)
  43. optimal = max(weights, key=lambda a: a[1])
  44. return optimal[0]
  45. def apply_transform(segmentation, fixed, index, out_dir):
  46. """
  47. Full transforms the slice->template slice, and then applies the reverse transform to the segmentation slice, to get
  48. a segmentation of the slice.
  49. :return: Location of the segmentation slice registered to the slice
  50. """
  51. seg_slice_loc = out_dir / "segment_slice.nii"
  52. tools.save_slice(str(segmentation), str(seg_slice_loc), index, np.eye(4), nib.load(str(fixed)).header)
  53. post_transform_loc = out_dir / f'{seg_slice_loc.stem}_t.nii'
  54. template_slice_loc = str(TEMP_DIR / f't_slice{index}.nii')
  55. registration.full_registration(str(template_slice_loc), str(fixed),
  56. str(Path(out_dir, f'{fixed.stem}_t.nii')), str(out_dir / f'Final-'))
  57. transform = [str(out_dir / f"Final-InverseComposite.h5")]
  58. reverse_transform = ants.ApplyTransforms(
  59. input_image=str(seg_slice_loc),
  60. reference_image=str(fixed),
  61. transforms=transform,
  62. invert_transform_flags=[False],
  63. interpolation='MultiLabel',
  64. dimension=2,
  65. output_image=str(post_transform_loc))
  66. reverse_transform.run()
  67. return post_transform_loc
  68. def generate_segmentation(slice_path, segment_path, template_path, approx_index, dapi):
  69. """
  70. Provided path to slice, segmentation, template and an approximate index, will apply multiple registrations to create
  71. an automatic segmentation of the slice
  72. :param dapi: index of DAPI modality channel
  73. :return: Location of the segmentation slice registered to the slice
  74. """
  75. out_subdir = Path(OUT_DIR) / Path(slice_path).stem.replace(' ', '').replace(',', '').replace('.', '')
  76. tools.create_dir(TEMP_DIR)
  77. tools.create_dir(out_subdir)
  78. channel_paths = pre.split_nii_channels(slice_path, out_subdir, False, dapi-1)
  79. optimal_index = find_optimal_index(channel_paths[dapi-1], template_path, approx_index)
  80. seg_loc = apply_transform(segment_path, channel_paths[dapi-1], optimal_index, out_subdir)
  81. tools.remove_dir(TEMP_DIR)
  82. return seg_loc
  83. if __name__ == "__main__":
  84. parser = argparse.ArgumentParser()
  85. parser.add_argument("sliceloc", metavar="Slice Location", help="Location of preprocessed slice")
  86. parser.add_argument("segloc", metavar="Segmentation Location", help="Location of segmentation")
  87. parser.add_argument("tloc", metavar="Template Location", help="Location of template file")
  88. parser.add_argument("--approx", type=int, default=-1, help="Approximate index of slice relative to template file")
  89. parser.add_argument("--bregma", type=float, help="Approx bregma coordinates")
  90. parser.add_argument("--out", default="output", help="Output directory")
  91. parser.add_argument("--dapi", type=int, default=0, help="DAPI channel number, default is last channel")
  92. args = parser.parse_args()
  93. OUT_DIR = Path(args.out)
  94. if args.approx == -1 and not args.bregma:
  95. print("Error: please specify an approximate location in the template, or a bregma coordinate of the slice")
  96. sys.exit(2)
  97. if args.bregma:
  98. args.approx = tools.bregma_to_slice_index(args.bregma)
  99. generate_segmentation(args.sliceloc, args.segloc, args.tloc, args.approx, args.dapi)