"""Locate a slice image within a template volume and warp the template's
segmentation onto that slice.

Pipeline (see ``generate_segmentation``): optional TIFF -> NIfTI conversion,
channel splitting, a search over nearby template indices for the best rigid
match, then a full registration whose inverse transform is applied to the
corresponding segmentation slice.
"""
from pathlib import Path
from pprint import pprint
import argparse
import os
import sys
import warnings

import numpy as np
import scipy.stats
import nibabel as nib
from nipype.interfaces import ants
from PIL import Image, ImageSequence

import registration
import tools
import preprocess

# Suppress every FutureWarning for the remainder of the run.
warnings.simplefilter(action='ignore', category=FutureWarning)


def find_optimal_index(slice_path, template_path, approx, temporary_directory, dist=2):
    """Return the template index whose slice best matches ``slice_path``.

    Every candidate index in ``[approx - dist, approx + dist]`` (lower bound
    clamped to 0) is extracted from the template with ``tools.save_slice``,
    the slice is rigidly registered to it, and the pair is scored as
    ``0.7 * mutual_information + 0.3 * dice``. Each score is multiplied by a
    Gaussian weight centred on ``approx`` (sigma ``2 * dist``, peak 1.0), so
    candidates far from the initial guess are penalised.

    Intermediate files ``t_slice{i}.nii``, ``reg_slice{i}.nii`` and the
    ``{i}-`` transform outputs are written into ``temporary_directory``;
    ``apply_transform`` later reads ``t_slice{index}.nii`` from there.

    Args:
        slice_path: path to the NIfTI slice being placed.
        template_path: path to the template volume.
        approx: approximate template index of the slice.
        temporary_directory: directory (``str`` or ``Path``) for intermediates.
        dist: search radius around ``approx``; must be >= 0.

    Returns:
        The candidate index with the highest weighted score (also printed).

    Raises:
        ValueError: if ``dist`` is negative (the search range would be empty).
    """
    if dist < 0:
        raise ValueError(f'dist must be non-negative, got {dist}')
    temporary_directory = Path(temporary_directory)

    # The moving slice does not change between candidates: read its header once.
    slice_header = nib.load(str(slice_path)).header

    sigma = dist * 2
    # Density at the centre, used to normalise the weight to 1.0 at ``approx``.
    # With dist == 0 there is a single candidate and no weighting is needed.
    peak_density = scipy.stats.norm.pdf(dist, dist, sigma) if dist else None

    weights = []
    # A negative index would not refer to a neighbouring slice, so clamp at 0.
    for i in range(max(0, approx - dist), approx + dist + 1):
        print(f'Slice index: {i}')
        template_slice_loc = str(temporary_directory / f't_slice{i}.nii')
        registered_loc = str(temporary_directory / f'reg_slice{i}.nii')
        transform_prefix = str(temporary_directory / f'{i}-')

        tools.save_slice(str(template_path), template_slice_loc, i, np.eye(4), slice_header)
        registration.rigid_registration(template_slice_loc, str(slice_path), registered_loc, transform_prefix)

        template_data = nib.load(template_slice_loc).get_fdata()
        registered_data = nib.load(registered_loc).get_fdata()
        # NOTE(review): resize_im presumably brings registered_data to
        # template_data's shape so the two can be compared voxel-wise — confirm.
        registered_data = tools.resize_im(registered_data, template_data)
        nib.save(nib.Nifti1Image(registered_data, np.eye(4)), registered_loc)

        mutual_info = tools.mutual_info_mask(template_data, registered_data)
        dice = tools.dice_coef(template_data, registered_data)
        if peak_density is None:
            norm_factor = 1.0
        else:
            norm_factor = scipy.stats.norm.pdf(i - approx + dist, dist, sigma) / peak_density
        weights.append((i, norm_factor * (0.7 * mutual_info + 0.3 * dice)))

    pprint(weights)
    best_index, _ = max(weights, key=lambda pair: pair[1])
    print(best_index)
    return best_index


def apply_transform(segmentation, fixed, index, out_dir, temporary_directory):
    """Warp slice ``index`` of ``segmentation`` onto the ``fixed`` image.

    Steps:
      1. Extract segmentation slice ``index`` to ``out_dir/segment_slice.nii``
         (using ``fixed``'s header).
      2. Run ``registration.full_registration`` between the template slice
         ``temporary_directory/t_slice{index}.nii`` (written earlier by
         ``find_optimal_index``) and ``fixed``, with transform prefix
         ``out_dir/TemplateSlice_``.
      3. Apply ``TemplateSlice_InverseComposite.h5`` to the segmentation slice
         with MultiLabel interpolation, writing ``out_dir/Segmentation.nii``.

    Args:
        segmentation: path to the template segmentation volume.
        fixed: path (``str`` or ``Path``) to the target slice image.
        index: template index chosen by ``find_optimal_index``.
        out_dir: output directory (``str`` or ``Path``).
        temporary_directory: directory holding ``t_slice{index}.nii``.

    Returns:
        ``Path`` of the warped segmentation (``out_dir/Segmentation.nii``).
    """
    out_dir = Path(out_dir)
    fixed = Path(fixed)
    seg_slice_loc = out_dir / 'segment_slice.nii'
    post_transform_loc = out_dir / 'Segmentation.nii'
    template_slice_loc = Path(temporary_directory) / f't_slice{index}.nii'
    transform_prefix = out_dir / 'TemplateSlice_'

    tools.save_slice(str(segmentation), str(seg_slice_loc), index, np.eye(4), nib.load(str(fixed)).header)
    registration.full_registration(
        str(template_slice_loc),
        str(fixed),
        str(out_dir / f'{fixed.stem}_template.nii'),
        str(transform_prefix),
    )

    # NOTE(review): assuming full_registration's first argument is the fixed
    # image (as rigid_registration's usage above suggests), the *inverse*
    # composite is what carries template-space data onto ``fixed`` — confirm.
    reverse_transform = ants.ApplyTransforms(
        input_image=str(seg_slice_loc),
        reference_image=str(fixed),
        transforms=[str(out_dir / 'TemplateSlice_InverseComposite.h5')],
        invert_transform_flags=[False],
        interpolation='MultiLabel',  # keep label values discrete
        dimension=2,
        output_image=str(post_transform_loc),
    )
    reverse_transform.run()
    return post_transform_loc


def generate_segmentation(slice_path, segment_path, template_path, approx_index, dapi, output_directory):
    """Produce a segmentation for ``slice_path`` by registering it to a template.

    TIFF inputs (``.tif``/``.tiff``, any case) are first converted with
    ``preprocess.tiff_to_nii``; other inputs must end in ``.nii``, otherwise
    the process exits with status 2.

    Results go to ``output_directory/<sanitised slice name>/``. Intermediates
    live in a ``tmp`` subdirectory that is removed afterwards, even if a step
    fails. The final transforms are renamed to ``Transformation_Composite.h5``
    and ``Transformation_InverseComposite.h5``.

    Args:
        slice_path: path to the slice image (TIFF or NIfTI).
        segment_path: path to the template segmentation volume.
        template_path: path to the template volume.
        approx_index: approximate template index of the slice.
        dapi: 1-based DAPI channel number; 0 selects the last channel
            (the channel is picked with ``channel_paths[dapi - 1]``).
        output_directory: root output directory.

    Returns:
        ``Path`` to the warped segmentation file.
    """
    slice_path = str(slice_path)
    if slice_path.lower().endswith(('.tif', '.tiff')):
        # Lift PIL's decompression-bomb size limit; preprocess.tiff_to_nii
        # presumably opens the (potentially very large) TIFF with PIL.
        Image.MAX_IMAGE_PIXELS = None
        # TODO: validate ``dapi`` against the TIFF's channel count before converting.
        slice_path = preprocess.tiff_to_nii(slice_path, output_directory)[0]
    elif not slice_path.endswith('.nii'):
        print('Selected slice file was neither a nifti or a TIFF file. Please check slice file format',
              file=sys.stderr)
        sys.exit(2)

    # Spaces, commas and dots are stripped from the name to keep paths simple.
    sanitised_name = Path(slice_path).stem.replace(' ', '_').replace(',', '_').replace('.', '')
    out_subdir = Path(output_directory) / sanitised_name
    tools.create_dir(out_subdir)
    # out_subdir already contains output_directory; joining it a second time
    # (as before) produced output/output/<name>/tmp for relative paths.
    temporary_directory = out_subdir / 'tmp'
    tools.create_dir(temporary_directory)

    try:
        channel_paths = preprocess.split_nii_channels(slice_path, out_subdir, True, dapi - 1)
        dapi_path = channel_paths[dapi - 1]
        optimal_index = find_optimal_index(dapi_path, template_path, approx_index, temporary_directory)
        seg_loc = apply_transform(segment_path, dapi_path, optimal_index, out_subdir, temporary_directory)
    finally:
        # Do not leave intermediates behind when registration fails.
        tools.remove_dir(temporary_directory)

    os.remove(out_subdir / 'segment_slice.nii')
    # os.replace (unlike os.rename) also overwrites an existing destination on
    # Windows, so re-running into the same output directory works everywhere.
    os.replace(out_subdir / 'TemplateSlice_Composite.h5', out_subdir / 'Transformation_Composite.h5')
    os.replace(out_subdir / 'TemplateSlice_InverseComposite.h5', out_subdir / 'Transformation_InverseComposite.h5')
    return seg_loc


def main(argv=None):
    """Command-line entry point; ``argv`` defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("sliceloc", metavar="Slice Location", help="Location of preprocessed slice")
    parser.add_argument("segloc", metavar="Segmentation Location", help="Location of segmentation")
    parser.add_argument("tloc", metavar="Template Location", help="Location of template file")
    parser.add_argument("--approx", type=int, default=-1,
                        help="Approximate index of slice relative to template file")
    parser.add_argument("bregma", metavar='Bregma index', type=float, help="Approx bregma coordinates")
    # nargs='?' makes the declared default reachable; a plain positional is required.
    parser.add_argument("out", metavar='Output directory', nargs='?', default="output",
                        help="Output directory")
    parser.add_argument("--dapi", type=int, default=0, help="DAPI channel number, default is last channel")
    args = parser.parse_args(argv)

    # An explicit --approx wins; otherwise derive the index from bregma.
    # bregma is a required positional, so it is always present. A truthiness
    # test would wrongly treat a bregma of 0.0 as missing.
    if args.approx != -1:
        approx = args.approx
    else:
        approx = tools.bregma_to_slice_index(args.bregma)
    generate_segmentation(args.sliceloc, args.segloc, args.tloc, approx, args.dapi, args.out)


if __name__ == "__main__":
    main()