# runner.py
  1. from pathlib import Path
  2. from pprint import pprint
  3. import argparse
  4. import sys
  5. import os
  6. import numpy as np
  7. import scipy.stats
  8. import nibabel as nib
  9. from nipype.interfaces import ants
  10. import registration
  11. import tools
  12. import preprocess
  13. from PIL import Image, ImageSequence
  14. import warnings
  15. warnings.simplefilter(action='ignore', category=FutureWarning)
  16. def find_optimal_index(slice_path, template_path, approx, temporary_directory, dist = 2):
  17. weights = []
  18. for i in range(approx-dist,approx+dist+1):
  19. print(f'Slice index: {i}')
  20. _slice = nib.load(str(slice_path))
  21. template_slice_loc = str(temporary_directory / f't_slice{i}.nii')
  22. registered_loc = str(temporary_directory / f'reg_slice{i}.nii')
  23. transform_prefix = str(temporary_directory / f'{i}-')
  24. tools.save_slice(str(template_path), template_slice_loc, i, np.eye(4), _slice.header)
  25. registration.rigid_registration(template_slice_loc, str(slice_path), registered_loc, transform_prefix)
  26. template_data = nib.load(str(template_slice_loc)).get_fdata()
  27. registered_data = nib.load(registered_loc).get_fdata()
  28. registered_data = tools.resize_im(registered_data, template_data)
  29. nib.save(nib.Nifti1Image(registered_data, np.eye(4)), registered_loc)
  30. mutualinfo = tools.mutual_info_mask(template_data, registered_data)
  31. norm_factor = scipy.stats.norm.pdf(i-approx+dist, dist, dist*10) / scipy.stats.norm.pdf(dist, dist, dist*10)
  32. dice_coef = tools.dice_coef(template_data, registered_data)
  33. weights.append((i, norm_factor * (0.7 * mutualinfo + 0.3 * dice_coef)))
  34. pprint(weights)
  35. optimal = max(weights, key=lambda a: a[1])
  36. print(optimal[0])
  37. return optimal[0]
  38. def apply_transform(segmentation, fixed, index, out_dir, temporary_directory):
  39. seg_slice_loc = out_dir / "segment_slice.nii"
  40. tools.save_slice(str(segmentation), str(seg_slice_loc), index, np.eye(4), nib.load(str(fixed)).header)
  41. post_transform_loc = out_dir / f'Segmentation.nii'
  42. template_slice_loc = str(temporary_directory / f't_slice{index}.nii')
  43. registration.full_registration(str(template_slice_loc), str(fixed),
  44. str(Path(out_dir, f'{fixed.stem}_template.nii')), str(out_dir / f'TemplateSlice_'))
  45. transform = [str(out_dir / f"TemplateSlice_InverseComposite.h5")]
  46. reverse_transform = ants.ApplyTransforms(
  47. input_image=str(seg_slice_loc),
  48. reference_image=str(fixed),
  49. transforms=transform,
  50. invert_transform_flags=[False],
  51. interpolation='MultiLabel',
  52. dimension=2,
  53. output_image=str(post_transform_loc))
  54. reverse_transform.run()
  55. return post_transform_loc
  56. def generate_segmentation(slice_path, segment_path, template_path, approx_index, dapi, output_directory):
  57. tools.create_dir(output_directory)
  58. if slice_path.endswith('.tiff') or slice_path.endswith('.tif'):
  59. Image.MAX_IMAGE_PIXELS = None
  60. tif_path = preprocess.tiff_to_nii(slice_path, output_directory)
  61. slice_path = tif_path[0]
  62. elif not slice_path.endswith(".nii"):
  63. print('Selected slice file was neither a nifti or a TIFF file. Please check slice file format')
  64. sys.exit(2)
  65. out_subdir = Path(output_directory) / Path(slice_path).stem.replace(' ', '_').replace(',','_').replace('.','')
  66. tools.create_dir(out_subdir)
  67. temporary_directory = Path(out_subdir) / 'tmp'
  68. tools.create_dir(temporary_directory)
  69. channel_paths = preprocess.split_nii_channels(slice_path, out_subdir, True, dapi-1)
  70. optimal_index = find_optimal_index(channel_paths[dapi-1], template_path, approx_index, temporary_directory)
  71. seg_loc = apply_transform(segment_path, channel_paths[dapi-1], optimal_index, out_subdir, temporary_directory)
  72. tools.remove_dir(temporary_directory)
  73. os.remove(Path(out_subdir) / 'segment_slice.nii')
  74. os.rename(Path(out_subdir) / 'TemplateSlice_Composite.h5', Path(out_subdir) / 'Transformation_Composite.h5')
  75. os.rename(Path(out_subdir) / 'TemplateSlice_InverseComposite.h5', Path(out_subdir) / 'Transformation_InverseComposite.h5')
  76. return seg_loc
  77. if __name__ == "__main__":
  78. parser = argparse.ArgumentParser()
  79. parser.add_argument("sliceloc", metavar="Slice Location", help="Location of preprocessed slice")
  80. parser.add_argument("segloc", metavar="Segmentation Location", help="Location of segmentation")
  81. parser.add_argument("tloc", metavar="Template Location", help="Location of template file")
  82. parser.add_argument("--approx", type = int, default=-1, help="Approximate index of slice relative to template file")
  83. parser.add_argument("bregma", metavar = 'Bregma index', type = float, help="Approx bregma coordinates")
  84. parser.add_argument("out", metavar = 'Output directory', default="output", help="Output directory")
  85. parser.add_argument("--dapi", type=int, default=0, help="DAPI channel number, default is last channel")
  86. args = parser.parse_args()
  87. if args.approx == -1 and not args.bregma:
  88. print("Error: Please specify an approximate location in the template, or a bregma coordinate of the slice")
  89. sys.exit(2)
  90. if args.bregma:
  91. args.approx = tools.bregma_to_slice_index(args.bregma)
  92. generate_segmentation(args.sliceloc, args.segloc, args.tloc, args.approx, args.dapi, args.out)