123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- # System imports
- from pathlib import Path
- from pprint import pprint
- import argparse
- import sys
- # 3rd Party imports
- import numpy as np
- import scipy.stats
- import nibabel as nib
- from nipype.interfaces import ants
- # Local source tree imports
- import registration
- import tools
- import preprocess as pre
# Scratch directory for per-candidate template slices, registrations and transforms;
# created and removed by generate_segmentation.
TEMP_DIR = Path('temp')
# Default root under which one output folder per slice is created.
OUTPUT_DIR = Path('output')
# Output root actually read by generate_segmentation. The __main__ block rebinds it
# from --out; defining it here means importing this module and calling
# generate_segmentation directly no longer raises NameError.
OUT_DIR = OUTPUT_DIR
# Not referenced by the functions in this module.
IMG_DIR = Path('Images')
def _proximity_weight(offset, dist):
    """
    Return a Gaussian weight in (0, 1] for a candidate ``offset`` slices away from the estimate.

    The weight is the N(0, (2 * dist)^2) density at ``offset`` divided by its peak value,
    so it is exactly 1 at ``offset == 0``. A non-positive ``dist`` (only the estimate
    itself is searched) yields 1 instead of the NaN a zero standard deviation produces.
    """
    if dist <= 0:
        return 1.0
    sd = 2 * dist
    return scipy.stats.norm.pdf(offset, 0, sd) / scipy.stats.norm.pdf(0, 0, sd)


def find_optimal_index(slice_path, template_path, approx, dist=2):
    """
    Find the most appropriate index in the template to register the slice file to.

    Every candidate index in ``[approx - dist, approx + dist]`` (clipped at 0) is
    extracted from the template, the slice is rigidly registered onto it, and the
    pair is scored as ``0.7 * mutual information + 0.3 * Dice coefficient``. Each
    score is multiplied by a Gaussian weight centred on ``approx`` (sd ``2 * dist``)
    so candidates closer to the estimate are preferred.

    Side effect: writes ``t_slice{i}.nii``, ``reg_slice{i}.nii`` and ``{i}-*``
    transform files into TEMP_DIR. ``apply_transform`` later reads
    ``t_slice{index}.nii`` from there, so TEMP_DIR must not be cleared in between.

    :param slice_path: path of the slice image to place
    :param template_path: path of the template volume
    :param approx: approximate index of slice relative to template file
    :param dist: maximum distance of candidate slices away from ``approx``
    :return: candidate index with the highest weighted score (first one on ties)
    :raises ValueError: if no non-negative candidate index exists
    """
    # The slice header is identical for every candidate, so load it only once.
    slice_header = nib.load(str(slice_path)).header
    # Negative indices do not mean "near approx"; with numpy-style indexing
    # (presumably what tools.save_slice uses -- confirm) they wrap to the far end.
    candidates = range(max(0, approx - dist), approx + dist + 1)
    if not candidates:
        raise ValueError(f'No non-negative slice index within {dist} of {approx}')

    weights = []
    for i in candidates:
        print(f'Slice index: {i}')
        template_slice_loc = str(TEMP_DIR / f't_slice{i}.nii')
        registered_loc = str(TEMP_DIR / f'reg_slice{i}.nii')
        transform_prefix = str(TEMP_DIR / f'{i}-')
        # Extract template slice i to template_slice_loc, then rigidly register the slice onto it.
        tools.save_slice(str(template_path), template_slice_loc, i, np.eye(4), slice_header)
        registration.rigid_registration(template_slice_loc, str(slice_path), registered_loc, transform_prefix)
        # np.asanyarray(img.dataobj) replaces the deprecated get_data() with the same dtype behaviour.
        template_data = np.asanyarray(nib.load(template_slice_loc).dataobj)
        registered_data = np.asanyarray(nib.load(registered_loc).dataobj)
        registered_data = tools.resize_im(registered_data, template_data)
        nib.save(nib.Nifti1Image(registered_data, np.eye(4)), registered_loc)
        mutualinfo = tools.mutual_info_mask(template_data, registered_data)
        dice_coef = tools.dice_coef(template_data, registered_data)
        score = 0.7 * mutualinfo + 0.3 * dice_coef
        weights.append((i, _proximity_weight(i - approx, dist) * score))
    pprint(weights)
    # max() returns the first maximal element, i.e. the lowest index on ties.
    best_index, _ = max(weights, key=lambda item: item[1])
    return best_index
def apply_transform(segmentation, fixed, index, out_dir):
    """
    Map segmentation slice ``index`` onto the slice image ``fixed``.

    Fully registers template slice ``index`` (written to TEMP_DIR by
    ``find_optimal_index``) with ``fixed`` via ``registration.full_registration``.
    It then warps the matching segmentation slice into ``fixed``'s grid with the
    resulting inverse composite transform, using ANTs' label-aware ``MultiLabel``
    interpolation.

    :param segmentation: path of the segmentation volume aligned with the template
    :param fixed: path (str or Path) of the slice image the segmentation is mapped onto
    :param index: template/segmentation slice index to use
    :param out_dir: directory (str or Path) receiving intermediate images, transforms and the result
    :return: Path of the segmentation slice registered to the slice
    """
    # Accept str or Path; `.stem` and the `/` operator below require Path objects.
    fixed = Path(fixed)
    out_dir = Path(out_dir)

    seg_slice_loc = out_dir / 'segment_slice.nii'
    tools.save_slice(str(segmentation), str(seg_slice_loc), index, np.eye(4), nib.load(str(fixed)).header)
    post_transform_loc = out_dir / f'{seg_slice_loc.stem}_t.nii'

    # Produced by find_optimal_index; TEMP_DIR must still exist at this point.
    template_slice_loc = TEMP_DIR / f't_slice{index}.nii'
    transform_prefix = out_dir / 'Final-'
    registration.full_registration(str(template_slice_loc), str(fixed),
                                   str(out_dir / f'{fixed.stem}_t.nii'), str(transform_prefix))

    # Assuming full_registration's first argument is the fixed image (as with the
    # rigid call in find_optimal_index), the inverse composite maps template space
    # onto `fixed`. That is the direction needed to carry the segmentation over.
    inverse_transform = out_dir / 'Final-InverseComposite.h5'
    reverse_transform = ants.ApplyTransforms(
        input_image=str(seg_slice_loc),
        reference_image=str(fixed),
        transforms=[str(inverse_transform)],
        invert_transform_flags=[False],
        interpolation='MultiLabel',
        dimension=2,
        output_image=str(post_transform_loc))
    reverse_transform.run()
    return post_transform_loc
def generate_segmentation(slice_path, segment_path, template_path, approx_index, dapi, out_dir=None):
    """
    Create an automatic segmentation of a slice by registering it to a template.

    The slice is split into channel images with ``pre.split_nii_channels``. The DAPI
    channel is used to pick the best-matching template slice near ``approx_index``.
    The corresponding segmentation slice is then warped onto that channel.
    TEMP_DIR holds intermediate files and is removed afterwards, even if a step fails.

    :param slice_path: path of the preprocessed slice
    :param segment_path: path of the segmentation volume aligned with the template
    :param template_path: path of the template volume
    :param approx_index: approximate index of the slice relative to the template
    :param dapi: 1-based number of the DAPI modality channel; 0 selects the last channel
    :param out_dir: root output directory; defaults to the module-level OUT_DIR
    :return: Location of the segmentation slice registered to the slice
    """
    root = Path(OUT_DIR if out_dir is None else out_dir)
    # Drop spaces, commas and dots so the slice's file stem makes a tidy folder name.
    out_subdir = root / Path(slice_path).stem.translate(str.maketrans('', '', ' ,.'))
    # 1-based channel number -> list index; dapi=0 gives -1, i.e. the last channel.
    dapi_idx = dapi - 1

    tools.create_dir(TEMP_DIR)
    try:
        tools.create_dir(out_subdir)
        channel_paths = pre.split_nii_channels(slice_path, out_subdir, False, dapi_idx)
        dapi_path = channel_paths[dapi_idx]
        optimal_index = find_optimal_index(dapi_path, template_path, approx_index)
        seg_loc = apply_transform(segment_path, dapi_path, optimal_index, out_subdir)
    finally:
        # Always clean up, so a failed run does not leave stale scratch files behind.
        tools.remove_dir(TEMP_DIR)
    return seg_loc
if __name__ == "__main__":
    # Command-line entry point: segment one slice using a template and its segmentation.
    parser = argparse.ArgumentParser()
    parser.add_argument("sliceloc", metavar="Slice Location", help="Location of preprocessed slice")
    parser.add_argument("segloc", metavar="Segmentation Location", help="Location of segmentation")
    parser.add_argument("tloc", metavar="Template Location", help="Location of template file")
    parser.add_argument("--approx", type=int, default=-1, help="Approximate index of slice relative to template file")
    parser.add_argument("--bregma", type=float, help="Approx bregma coordinates")
    parser.add_argument("--out", default="output", help="Output directory")
    parser.add_argument("--dapi", type=int, default=0, help="DAPI channel number, default is last channel")
    args = parser.parse_args()
    # Module-level rebinding: generate_segmentation reads the global OUT_DIR.
    OUT_DIR = Path(args.out)
    # Compare bregma against None: 0.0 is a valid bregma coordinate but is falsy.
    if args.approx == -1 and args.bregma is None:
        print("Error: please specify an approximate location in the template, or a bregma coordinate of the slice",
              file=sys.stderr)
        sys.exit(2)
    # When a bregma coordinate is given, it takes precedence over --approx.
    if args.bregma is not None:
        args.approx = tools.bregma_to_slice_index(args.bregma)
    generate_segmentation(args.sliceloc, args.segloc, args.tloc, args.approx, args.dapi)
|