123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- from pathlib import Path
- from pprint import pprint
- import argparse
- import sys
- import os
- import numpy as np
- import scipy.stats
- import nibabel as nib
- from nipype.interfaces import ants
- import registration
- import tools
- import preprocess
- from PIL import Image, ImageSequence
- import warnings
- warnings.simplefilter(action='ignore', category=FutureWarning)
def find_optimal_index(slice_path, template_path, approx, temporary_directory, dist=2):
    """Return the template slice index that best matches a 2-D slice image.

    Each candidate index ``i`` in ``[approx - dist, approx + dist]`` is handled
    the same way:

    * template slice ``i`` is extracted with ``tools.save_slice``;
    * the input slice is registered onto it with
      ``registration.rigid_registration``;
    * the pair is scored as ``0.7 * mutual_info + 0.3 * dice``;
    * the score is multiplied by a Gaussian prior centred on ``approx``
      (sigma ``10 * dist``, normalised to 1 at ``i == approx``), so candidates
      further from the initial guess are mildly penalised.

    Side effects: for every candidate ``i`` this writes ``t_slice{i}.nii``,
    ``reg_slice{i}.nii`` and registration output prefixed ``{i}-`` into
    ``temporary_directory``. It also prints progress and the score list.

    Args:
        slice_path: Path (str or ``pathlib.Path``) of the slice image to locate.
        template_path: Path of the template volume to slice.
        approx: Initial guess for the template index.
        temporary_directory: ``pathlib.Path`` where intermediate files go
            (it is combined with ``/``).
        dist: Half-width of the search window; must be >= 0.

    Returns:
        The candidate index with the highest weighted score.

    Raises:
        ValueError: If ``dist`` is negative (the search window would be empty).
    """
    if dist < 0:
        raise ValueError(f"dist must be non-negative, got {dist}")

    # The slice header is reused for every extracted template slice and does
    # not change between iterations, so load it once instead of per candidate.
    slice_header = nib.load(str(slice_path)).header
    sigma = dist * 10

    # NOTE(review): approx - dist may be negative; how tools.save_slice treats
    # a negative index is not visible from this file.
    weights = []
    for i in range(approx - dist, approx + dist + 1):
        print(f'Slice index: {i}')
        template_slice_loc = str(temporary_directory / f't_slice{i}.nii')
        registered_loc = str(temporary_directory / f'reg_slice{i}.nii')
        transform_prefix = str(temporary_directory / f'{i}-')

        tools.save_slice(str(template_path), template_slice_loc, i, np.eye(4), slice_header)
        registration.rigid_registration(template_slice_loc, str(slice_path), registered_loc, transform_prefix)

        template_data = nib.load(template_slice_loc).get_fdata()
        registered_data = nib.load(registered_loc).get_fdata()
        # Presumably resamples the registered image onto the template slice's
        # grid so both metrics compare equally-shaped arrays — confirm in tools.
        registered_data = tools.resize_im(registered_data, template_data)
        nib.save(nib.Nifti1Image(registered_data, np.eye(4)), registered_loc)

        mutualinfo = tools.mutual_info_mask(template_data, registered_data)
        # Gaussian distance prior, equal to 1 at i == approx. With dist == 0
        # the scale would be 0 (scipy yields nan), and there is only one
        # candidate anyway, so use a neutral factor.
        if sigma:
            norm_factor = (scipy.stats.norm.pdf(i, approx, sigma)
                           / scipy.stats.norm.pdf(approx, approx, sigma))
        else:
            norm_factor = 1.0
        dice_coef = tools.dice_coef(template_data, registered_data)
        weights.append((i, norm_factor * (0.7 * mutualinfo + 0.3 * dice_coef)))

    pprint(weights)
    optimal_index, _ = max(weights, key=lambda candidate: candidate[1])
    print(optimal_index)
    return optimal_index
def apply_transform(segmentation, fixed, index, out_dir, temporary_directory):
    """Warp slice ``index`` of a template segmentation into the space of ``fixed``.

    Steps:

    1. Extract slice ``index`` of ``segmentation`` to
       ``out_dir/segment_slice.nii``, using the header of ``fixed``.
    2. Run ``registration.full_registration`` between the template slice
       ``temporary_directory/t_slice{index}.nii`` and ``fixed``. That file is
       expected to have been written earlier, e.g. by ``find_optimal_index``.
       Outputs go to ``out_dir/{fixed.stem}_template.nii``, and transforms use
       the prefix ``out_dir/TemplateSlice_``.
    3. Apply ``out_dir/TemplateSlice_InverseComposite.h5`` to the extracted
       segmentation slice with ANTs ``ApplyTransforms``. It uses 'MultiLabel'
       interpolation (intended for label images) and writes
       ``out_dir/Segmentation.nii``.

    NOTE(review): this assumes full_registration writes
    ``<prefix>InverseComposite.h5``; confirm against the registration module.

    Args:
        segmentation: Path of the template segmentation volume.
        fixed: Path (str or ``pathlib.Path``) of the 2-D target image.
        index: Template slice index to extract and warp.
        out_dir: Output directory (str or ``pathlib.Path``).
        temporary_directory: Directory holding ``t_slice{index}.nii``.

    Returns:
        ``pathlib.Path`` of the warped segmentation (``out_dir/Segmentation.nii``).
    """
    # Accept plain strings as well as Path objects: `.stem` and `/` need Path.
    out_dir = Path(out_dir)
    temporary_directory = Path(temporary_directory)
    fixed = Path(fixed)

    seg_slice_loc = out_dir / 'segment_slice.nii'
    tools.save_slice(str(segmentation), str(seg_slice_loc), index, np.eye(4), nib.load(str(fixed)).header)

    post_transform_loc = out_dir / 'Segmentation.nii'
    template_slice_loc = temporary_directory / f't_slice{index}.nii'
    registration.full_registration(str(template_slice_loc), str(fixed),
                                   str(out_dir / f'{fixed.stem}_template.nii'),
                                   str(out_dir / 'TemplateSlice_'))

    reverse_transform = ants.ApplyTransforms(
        input_image=str(seg_slice_loc),
        reference_image=str(fixed),
        transforms=[str(out_dir / 'TemplateSlice_InverseComposite.h5')],
        invert_transform_flags=[False],
        interpolation='MultiLabel',
        dimension=2,
        output_image=str(post_transform_loc))
    reverse_transform.run()
    return post_transform_loc
def generate_segmentation(slice_path, segment_path, template_path, approx_index, dapi, output_directory):
    """Register a template segmentation onto a slice image and return its path.

    Pipeline:

    * TIFF input is first converted with ``preprocess.tiff_to_nii``.
    * The NIfTI slice is split into channels.
    * The DAPI channel is matched against the template with
      ``find_optimal_index``.
    * The segmentation slice at that index is warped with
      ``apply_transform``.

    Results go in ``output_directory/<sanitised slice stem>/``. Afterwards the
    temporary files are removed, and the ``TemplateSlice_*.h5`` transforms are
    renamed to ``Transformation_*.h5``.

    Args:
        slice_path: Path (str or ``pathlib.Path``) to a ``.nii``, ``.tif`` or
            ``.tiff`` slice. The extension match is case-insensitive.
        segment_path: Path of the template segmentation volume.
        template_path: Path of the template volume.
        approx_index: Initial guess of the slice's template index.
        dapi: 1-based DAPI channel number. 0 selects index -1, i.e. the last
            channel.
        output_directory: Root output directory.

    Returns:
        Whatever ``apply_transform`` returns: the path of the warped segmentation.

    Exits:
        Calls ``sys.exit(2)`` if the slice file is neither NIfTI nor TIFF.
    """
    slice_path = str(slice_path)
    lowered = slice_path.lower()
    is_tiff = lowered.endswith(('.tiff', '.tif'))
    # Validate before touching the filesystem, so a bad path leaves no directory behind.
    if not is_tiff and not lowered.endswith('.nii'):
        print('Selected slice file was neither a nifti or a TIFF file. Please check slice file format',
              file=sys.stderr)
        sys.exit(2)

    tools.create_dir(output_directory)
    if is_tiff:
        # Large microscopy TIFFs can exceed PIL's decompression-bomb pixel
        # limit; this disables the limit for the whole process.
        Image.MAX_IMAGE_PIXELS = None
        slice_path = preprocess.tiff_to_nii(slice_path, output_directory)[0]

    subdir_name = Path(slice_path).stem.replace(' ', '_').replace(',', '_').replace('.', '')
    out_subdir = Path(output_directory) / subdir_name
    tools.create_dir(out_subdir)
    temporary_directory = out_subdir / 'tmp'
    tools.create_dir(temporary_directory)

    dapi_index = dapi - 1
    channel_paths = preprocess.split_nii_channels(slice_path, out_subdir, True, dapi_index)
    dapi_channel = channel_paths[dapi_index]
    optimal_index = find_optimal_index(dapi_channel, template_path, approx_index, temporary_directory)
    seg_loc = apply_transform(segment_path, dapi_channel, optimal_index, out_subdir, temporary_directory)

    tools.remove_dir(temporary_directory)
    os.remove(out_subdir / 'segment_slice.nii')
    # os.replace overwrites an existing destination on every platform;
    # os.rename raises on Windows when re-running into the same output dir.
    os.replace(out_subdir / 'TemplateSlice_Composite.h5', out_subdir / 'Transformation_Composite.h5')
    os.replace(out_subdir / 'TemplateSlice_InverseComposite.h5', out_subdir / 'Transformation_InverseComposite.h5')

    return seg_loc
if __name__ == "__main__":
    # Command-line entry point. The slice position in the template comes from
    # an explicit --approx index or, failing that, from the bregma coordinate.
    parser = argparse.ArgumentParser()
    parser.add_argument("sliceloc", metavar="Slice Location", help="Location of preprocessed slice")
    parser.add_argument("segloc", metavar="Segmentation Location", help="Location of segmentation")
    parser.add_argument("tloc", metavar="Template Location", help="Location of template file")
    parser.add_argument("--approx", type=int, default=-1, help="Approximate index of slice relative to template file")
    # Optional positional: previously it was required, which made --approx
    # effectively unusable. Invocations that pass it still parse as before.
    parser.add_argument("bregma", metavar='Bregma index', type=float, nargs='?', default=None,
                        help="Approx bregma coordinates")
    parser.add_argument("out", metavar='Output directory', default="output", help="Output directory")
    parser.add_argument("--dapi", type=int, default=0, help="DAPI channel number, default is last channel")
    args = parser.parse_args()

    # -1 is the "not given" sentinel for --approx. Test bregma with `is None`,
    # because 0.0 is a valid bregma coordinate and must not count as missing.
    if args.approx == -1:
        if args.bregma is None:
            print("Error: Please specify an approximate location in the template, or a bregma coordinate of the slice",
                  file=sys.stderr)
            sys.exit(2)
        args.approx = tools.bregma_to_slice_index(args.bregma)
    generate_segmentation(args.sliceloc, args.segloc, args.tloc, args.approx, args.dapi, args.out)
-
|