# System imports
from pathlib import Path
from pprint import pprint
import argparse
import sys

# 3rd Party imports
import numpy as np
import scipy.stats
import nibabel as nib
from nipype.interfaces import ants

# Local source tree imports
import registration
import tools
import preprocess as pre

TEMP_DIR = Path('temp')
OUTPUT_DIR = Path('output')
IMG_DIR = Path('Images')


def _load_data(path):
    """Return the voxel array of the NIfTI image at ``path``.

    ``np.asanyarray(img.dataobj)`` is the replacement nibabel recommends for the
    deprecated ``get_data()`` (removed in nibabel 5); it keeps the stored dtype
    instead of forcing float64 like ``get_fdata()`` would.
    """
    return np.asanyarray(nib.load(str(path)).dataobj)


def find_optimal_index(slice_path, template_path, approx, dist=3):
    """
    Finds the most appropriate index in template path to register the slice file to.

    For every candidate index in ``[approx - dist, approx + dist]`` (clamped at 0),
    the template slice at that index is written to TEMP_DIR, the slice is rigidly
    registered to it, and the pair is scored as
    ``0.7 * mutual information + 0.3 * dice coefficient``. That score is multiplied
    by a Gaussian weight (sigma = 2 * dist) that is 1 at ``approx`` and decays
    with distance from it, so nearby candidates are preferred.

    Side effects: writes ``t_slice{i}.nii``, ``reg_slice{i}.nii`` and transform
    files prefixed ``{i}-`` into TEMP_DIR (which must already exist).

    :param slice_path: path to the slice image to place within the template
    :param template_path: path to the template volume
    :param approx: Approximate index of slice relative to template file
    :param dist: maximum distance of slices away from approximate to index to check
    :return: index of slice in template which has highest metric - combined
        mutual information and dice coefficient.
    :raises ValueError: if the clamped search window contains no indices
    """
    # The slice header does not depend on the candidate index; load it once.
    slice_header = nib.load(str(slice_path)).header
    sigma = dist * 2
    # Negative template indices are never meaningful slice positions.
    candidates = range(max(0, approx - dist), approx + dist + 1)
    if not candidates:
        raise ValueError(f'No template indices to search around approx={approx} (dist={dist})')

    weights = []
    for i in candidates:
        print(f'Slice index: {i}')
        template_slice_loc = str(TEMP_DIR / f't_slice{i}.nii')
        registered_loc = str(TEMP_DIR / f'reg_slice{i}.nii')
        transform_prefix = str(TEMP_DIR / f'{i}-')

        tools.save_slice(str(template_path), template_slice_loc, i, np.eye(4), slice_header)
        registration.rigid_registration(template_slice_loc, str(slice_path), registered_loc, transform_prefix)

        template_data = _load_data(template_slice_loc)
        registered_data = _load_data(registered_loc)
        # Bring the registered slice onto the template slice's grid before comparing.
        registered_data = tools.resize_im(registered_data, template_data)
        nib.save(nib.Nifti1Image(registered_data, np.eye(4)), registered_loc)

        mutualinfo = tools.mutual_info_mask(template_data, registered_data)
        dice_coef = tools.dice_coef(template_data, registered_data)
        # Gaussian distance penalty normalised to 1 at i == approx. With dist == 0 only
        # approx itself is checked and a zero-width Gaussian is undefined (NaN), so use 1.
        if sigma > 0:
            norm_factor = scipy.stats.norm.pdf(i - approx, 0, sigma) / scipy.stats.norm.pdf(0, 0, sigma)
        else:
            norm_factor = 1.0
        weights.append((i, norm_factor * (0.7 * mutualinfo + 0.3 * dice_coef)))

    pprint(weights)
    optimal = max(weights, key=lambda a: a[1])
    return optimal[0]


def apply_transform(segmentation, fixed, index, out_dir):
    """
    Full transforms the slice->template slice, and then applies the reverse
    transform to the segmentation slice, to get a segmentation of the slice.

    Requires ``TEMP_DIR / f't_slice{index}.nii'`` to exist (as written by
    ``find_optimal_index``).

    :param segmentation: path to the segmentation volume aligned with the template
    :param fixed: path to the slice image the segmentation is mapped onto (str or Path)
    :param index: template slice index to extract from ``segmentation``
    :param out_dir: directory for outputs (str or Path)
    :return: Location of the segmentation slice registered to the slice
    """
    out_dir = Path(out_dir)
    fixed = Path(fixed)

    seg_slice_loc = out_dir / "segment_slice.nii"
    tools.save_slice(str(segmentation), str(seg_slice_loc), index, np.eye(4), nib.load(str(fixed)).header)
    post_transform_loc = out_dir / f'{seg_slice_loc.stem}_t.nii'
    template_slice_loc = TEMP_DIR / f't_slice{index}.nii'
    transform_prefix = out_dir / 'Final-'

    registration.full_registration(str(template_slice_loc), str(fixed),
                                   str(out_dir / f'{fixed.stem}_t.nii'), str(transform_prefix))

    # NOTE(review): assumes full_registration writes an ANTs composite transform pair under
    # transform_prefix and that the inverse one maps template space onto `fixed` — confirm
    # against registration.full_registration.
    transform = [str(out_dir / "Final-InverseComposite.h5")]
    reverse_transform = ants.ApplyTransforms(
        input_image=str(seg_slice_loc),
        reference_image=str(fixed),
        transforms=transform,
        invert_transform_flags=[False],
        # MultiLabel keeps segmentation labels discrete instead of blending them.
        interpolation='MultiLabel',
        dimension=2,
        output_image=str(post_transform_loc))
    reverse_transform.run()

    return post_transform_loc


def generate_segmentation(slice_path, segment_path, template_path, approx_index, dapi, out_dir=None):
    """
    Provided path to slice, segmentation, template and an approximate index, will
    apply multiple registrations to create an automatic segmentation of the slice.

    :param slice_path: path to the multi-channel slice image
    :param segment_path: path to the segmentation volume aligned with the template
    :param template_path: path to the template volume
    :param approx_index: approximate template index of the slice
    :param dapi: index of DAPI modality channel; ``dapi - 1`` is used as the channel
        list index, so the default 0 selects the last channel
    :param out_dir: root output directory (a per-slice subdirectory is created in it);
        defaults to OUTPUT_DIR
    :return: Location of the segmentation slice registered to the slice
    """
    root = OUTPUT_DIR if out_dir is None else Path(out_dir)
    subdir_name = Path(slice_path).stem.replace(' ', '').replace(',', '').replace('.', '')
    out_subdir = root / subdir_name
    tools.create_dir(TEMP_DIR)
    tools.create_dir(out_subdir)

    dapi_index = dapi - 1
    try:
        channel_paths = pre.split_nii_channels(slice_path, out_subdir, False, dapi_index)
        dapi_path = channel_paths[dapi_index]
        optimal_index = find_optimal_index(dapi_path, template_path, approx_index)
        seg_loc = apply_transform(segment_path, dapi_path, optimal_index, out_subdir)
    finally:
        # Always clean up intermediate registration files, even if a step fails.
        tools.remove_dir(TEMP_DIR)
    return seg_loc


def main(argv=None):
    """Command-line entry point: parse arguments and run generate_segmentation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("sliceloc", metavar="Slice Location", help="Location of preprocessed slice")
    parser.add_argument("segloc", metavar="Segmentation Location", help="Location of segmentation")
    parser.add_argument("tloc", metavar="Template Location", help="Location of template file")
    parser.add_argument("--approx", type=int, default=-1, help="Approximate index of slice relative to template file")
    parser.add_argument("--bregma", type=float, help="Approx bregma coordinates")
    parser.add_argument("--out", default="output", help="Output directory")
    parser.add_argument("--dapi", type=int, default=0, help="DAPI channel number, default is last channel")
    args = parser.parse_args(argv)

    # Bregma 0.0 is a valid coordinate, so test for absence with `is None`, not truthiness.
    if args.approx == -1 and args.bregma is None:
        print("Error: please specify an approximate location in the template, or a bregma coordinate of the slice",
              file=sys.stderr)
        sys.exit(2)
    if args.bregma is not None:
        args.approx = tools.bregma_to_slice_index(args.bregma)

    return generate_segmentation(args.sliceloc, args.segloc, args.tloc, args.approx, args.dapi,
                                 out_dir=Path(args.out))


if __name__ == "__main__":
    main()