Ver Fonte

gin commit from rMBP-15-Malthe.local

New files: 10
Modified files: 1
malthe.nielsen há 4 anos atrás
pai
commit
61f6b58f1c

+ 1 - 0
Example.zip

@@ -0,0 +1 @@
+/annex/objects/MD5-s495344667--64c2c42dd71431057b2ca65394fc0a1f

+ 1 - 0
Example/data/dapi_template.nii.gz

@@ -0,0 +1 @@
+/annex/objects/MD5-s491624985--e7c61c5afd441c9b5c4f1dfaa36f9ec6

BIN
Example/data/dapi_template_segmentation_full.nii.gz


+ 138 - 0
Example/preprocess.py

@@ -0,0 +1,138 @@
+from pathlib import Path
+import argparse
+import numpy as np
+import nibabel as nib
+from PIL import Image, ImageSequence
+from scipy import ndimage
+from skimage import morphology, filters
+from nipype.interfaces.ants.segmentation import N4BiasFieldCorrection
+
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+
def get_out_paths(out_dir, stem):
    """Build the output paths for the scaled and unscaled NIfTI files.

    Returns a ``(scaled, unscaled)`` tuple of paths inside *out_dir*,
    named ``<stem>.nii`` and ``<stem>_us.nii`` respectively.
    """
    scaled = Path(out_dir) / f'{stem}.nii'
    unscaled = Path(out_dir) / f'{stem}_us.nii'
    return scaled, unscaled
+
+
def bias_field_correction(image, output, dim=2):
    """Run ANTs N4 bias-field correction on *image*, writing to *output*.

    image:  path to the input NIfTI file.
    output: path for the corrected image (may equal *image* to overwrite in place).
    dim:    image dimensionality passed to N4; 2 because the pipeline works
            on single slices.
    """
    N4BiasFieldCorrection(
        input_image=image,
        output_image=output,
        dimension=dim
    ).run()
+
+
def image_data_to_nii(pixdim, image_data, shrink, out_dir, file_path, save_unscaled=False):
    """Downscale a (rows, cols, channels) array and save it as a NIfTI file.

    pixdim:        pixel size of the *input* image; the scaled file stores
                   pixdim * shrink.
    image_data:    source array with channels stacked on axis 2.
    shrink:        downscale factor (output is 1/shrink the input size).
    out_dir:       directory for the output .nii file(s).
    file_path:     source file; its stem names the outputs.
    save_unscaled: also write the original-resolution data as <stem>_us.nii.

    Returns (scaled_path, unscaled_path); the unscaled file only exists when
    save_unscaled is True.
    """
    scale = 1 / shrink
    # PIL's resize takes (width, height); combined with the transpose below,
    # the stored volume is transposed relative to the input orientation.
    new_dim = (round(image_data.shape[1] * scale), round(image_data.shape[0] * scale))

    # Resize each channel and stack once, instead of filling a preallocated
    # (and uninitialised) np.ndarray; cast to float64 to match the dtype the
    # original buffer had.
    channels = [
        np.array(Image.fromarray(image_data[:, :, i]).resize(new_dim)).transpose()
        for i in range(image_data.shape[2])
    ]
    if channels:
        new_arr = np.stack(channels, axis=2).astype(np.float64)
    else:
        new_arr = np.empty(new_dim + (0,))

    path_scaled, path_unscaled = get_out_paths(out_dir, Path(file_path).stem)
    nii_scaled = nib.Nifti1Image(new_arr, np.eye(4))
    # xyzt_units code 3 is micrometres in the NIfTI-1 standard — confirm the
    # intended unit matches pixdim's unit.
    nii_scaled.header['xyzt_units'] = 3
    nii_scaled.header['pixdim'][1:3] = pixdim * shrink, pixdim * shrink
    nib.save(nii_scaled, str(path_scaled))

    if save_unscaled:
        nii_unscaled = nib.Nifti1Image(image_data, np.eye(4))
        nii_unscaled.header['xyzt_units'] = 3
        nii_unscaled.header['pixdim'][1:3] = pixdim, pixdim
        nib.save(nii_unscaled, str(path_unscaled))
    print(f'Preprocessed:  {path_scaled}\n')
    return path_scaled, path_unscaled
+
+
def tiff_to_nii(tif_path, out_dir, pixdim=None, shrink=10):
    """Convert a (multi-page) TIFF into a NIfTI file, one page per channel.

    pixdim: pixel size; when falsy it is derived from TIFF tag 282
            (XResolution).  NOTE(review): the factor 10e6 equals 1e7 — it
            looks like it may have been meant as 1e6; confirm the unit.
    shrink: downscale factor forwarded to image_data_to_nii().

    Returns the (scaled_path, unscaled_path) pair from image_data_to_nii().
    """
    Image.MAX_IMAGE_PIXELS = None  # allow arbitrarily large slide scans
    tif_image = Image.open(tif_path)
    tif_header = dict(tif_image.tag)
    if not pixdim:
        pixdim = 10e6 / tif_header[282][0][0]
    # Collect all pages and stack once: the original repeated np.concatenate
    # inside the loop, which is quadratic in the number of pages.  Cast to
    # float64 to keep the dtype the concatenate-based version produced.
    pages = [np.array(page) for page in ImageSequence.Iterator(tif_image)]
    if pages:
        output = np.stack(pages, axis=2).astype(np.float64)
    else:
        output = np.empty(np.array(tif_image).shape + (0,))
    return image_data_to_nii(pixdim, output, shrink, out_dir, tif_path)
+
+
def split_nii_channels(nii_path, out_dir=None, flip=False, mask_index=-1, bias=False):
    """Write each channel of a multi-channel 2-D NIfTI as its own file.

    nii_path:   Path to the source .nii (channels stacked on axis 2).
    out_dir:    Path of the output directory; defaults to the source's parent.
    flip:       flip every channel along axis 1 before saving.
    mask_index: channel to foreground-mask; -1 means the last channel.
    bias:       additionally N4 bias-correct and mean-normalise the masked
                channel.

    Returns the list of written channel paths (im_c1.nii, im_c2.nii, ...).
    """
    if out_dir is None:
        out_dir = nii_path.parent
    nii = nib.load(str(nii_path))
    nii_data = nii.get_fdata()
    nii_header = nii.header

    if mask_index == -1:
        mask_index = nii_data.shape[2] - 1
    paths = []

    for i in range(nii_data.shape[2]):
        out_path = out_dir / f'im_c{i+1}.nii'
        channel_data = nii_data[:, :, i]

        if flip:
            channel_data = np.flip(channel_data, 1)
        if i == mask_index:
            channel_data = mask_foreground(channel_data)

        # Copy the header instead of aliasing it, so the loaded image's
        # header is not mutated as a side effect of this loop.
        new_header = nii_header.copy()
        new_header['dim'][0] = 2
        nib.save(nib.Nifti1Image(channel_data, np.eye(4), header=new_header), str(out_path))

        if i == mask_index and bias:
            bias_field_correction(str(out_path), str(out_path))
            corrected = nib.load(str(out_path))
            # get_fdata() replaces get_data(), which was deprecated and then
            # removed from nibabel.
            corrected_data = corrected.get_fdata()
            corrected_normalized = corrected_data / np.mean(corrected_data[corrected_data != 0])
            nii_corrected = nib.Nifti1Image(corrected_normalized, corrected.affine, corrected.header)
            nib.save(nii_corrected, str(out_path))
        paths.append(out_path)
    return paths
+
+
def mask_foreground(raw_data):
    """Zero out background pixels, keeping one large connected component.

    The image is normalised, blurred, thresholded at half the Otsu level,
    and labelled with 8-connectivity; one of the two largest components is
    kept and morphologically closed.  Returns the input intensities with
    everything outside that component set to 0.
    """
    raw_max = raw_data.max()
    if raw_max == 0:
        # Blank slice: nothing to mask (also avoids division by zero).
        return raw_data
    raw_data = raw_data / raw_max
    blurred_data = ndimage.gaussian_filter(raw_data, 4)

    # Half the Otsu threshold: deliberately generous so dim tissue survives.
    threshold = filters.threshold_otsu(raw_data) / 2
    threshold_data = blurred_data > threshold
    connected_structure = ndimage.generate_binary_structure(2, 2)  # Connects adjacent and diagonal.
    padded_comp, padded_nr = ndimage.label(threshold_data, structure=connected_structure)

    comps, comps_count = np.unique(padded_comp, return_counts=True)
    comps_count, comps = zip(*sorted(zip(comps_count, comps), reverse=True))

    if len(comps) < 2:
        # Only one label present (all foreground or all background); the
        # original code raised IndexError here.
        biggest_cc = comps[0]
    else:
        # NOTE(review): np.average of a scalar label is the label itself, so
        # this picks the higher label id of the two largest components — in
        # practice the non-background one, since background is label 0.
        # Confirm this is the intended heuristic (not mean intensity).
        two_biggest_cc = ((comps[0], np.average(comps[0])), (comps[1], np.average(comps[1])))
        biggest_cc = max(two_biggest_cc, key=lambda a: a[1])[0]

    foreground_mask = np.where(padded_comp == biggest_cc, True, False)
    closed = morphology.binary_closing(foreground_mask, selem=morphology.square(30))
    raw_data = np.where(closed, raw_data, 0)
    return raw_data * raw_max
+    
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("file", help="Location of file to process")
    parser.add_argument("dir", help="Directory for preprocessed Files")
    parser.add_argument("-s", "--series", type=int, help="Series to extract")
    # argparse's type=bool treats any non-empty string (including "False")
    # as True; parse the common false-y spellings explicitly instead.  The
    # default stays True for backward compatibility.
    parser.add_argument('-b', type=lambda v: v.lower() not in ('false', '0', 'no'),
                        default=True, help='Bias field correct image')
    parser.add_argument('--pdim', type=float,
                        help='Pixel dimensions (Retrieved from Tiff file if not set)')
    args = parser.parse_args()

    filetype = Path(args.file).suffix.lower()
    if filetype in ('.tif', '.tiff'):
        # Forward --pdim; the original parsed it but never used it.
        n_path = tiff_to_nii(Path(args.file), Path(args.dir), pixdim=args.pdim)
    elif filetype == '.nd2':
        # NOTE(review): nd2_to_nii is not defined anywhere in this module;
        # this branch raises NameError as written — confirm where it lives.
        n_path = nd2_to_nii(Path(args.file), Path(args.dir), args.series)
    elif filetype == '.nii':
        n_path = Path(args.file)
    else:
        parser.error(f'Unsupported file type: {filetype}')

+ 64 - 0
Example/registration.py

@@ -0,0 +1,64 @@
+from nipype.interfaces import ants
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+
# Parameters for a full registration (Rigid, Affine, SyN)  --- SLOW ---
def full_registration(fixed, moving, output='full_trans.nii', transform_pre=''):
    """Deformably register *moving* onto *fixed* with ANTs (Rigid+Affine+SyN).

    fixed/moving:  paths to 2-D NIfTI images.
    output:        path of the warped moving image; also the return value.
    transform_pre: prefix for the transform files ANTs writes; with
                   write_composite_transform=True this produces
                   <prefix>Composite.h5 / <prefix>InverseComposite.h5.
    Also writes the inverse-warped image as a side effect.
    """
    ants.Registration(
        dimension=2,
        transforms=['Rigid', 'Affine', 'SyN'],
        transform_parameters=[(0.15,), (0.15,), (0.3, 3, 0.5)],
        # Two metrics (CC and MI) are combined at every stage — hence six
        # entries in metric / metric_weight / radius_or_number_of_bins.
        metric=['CC', 'MI'] * 3,
        convergence_threshold=[1.e-7],
        number_of_iterations=[[1000, 1000, 1000, 1000], [1000, 1000, 1000, 1000],
                              [1000, 1000, 1000, 1000, 1000]],
        smoothing_sigmas=[[4, 4, 2, 2], [4, 4, 2, 2], [4, 2, 2, 2, 1]],
        shrink_factors=[[32, 16, 8, 4], [32, 16, 8, 4], [16, 8, 4, 2, 1]],
        initial_moving_transform_com=1,  # initialise by centre of mass
        interpolation='BSpline',
        metric_weight=[0.5] * 6,
        radius_or_number_of_bins=[4,32] * 3,  # CC radius 4, MI 32 bins
        sampling_strategy=['Regular'] * 3,
        use_histogram_matching=True,
        winsorize_lower_quantile=0.05,
        winsorize_upper_quantile=0.95,
        collapse_output_transforms=True,
        output_transform_prefix=transform_pre,
        fixed_image=[fixed],
        moving_image=[moving],
        output_warped_image=output,
        num_threads=12,
        verbose=False,
        write_composite_transform=True,
        output_inverse_warped_image=True
    ).run()
    return output
+
+
# A faster rigid registration
def rigid_registration(fixed, moving, output='affine_trans.nii', transform_pre=''):
    """Rigidly register *moving* onto *fixed* with ANTs.

    Used for the quick per-index scoring pass (see runner.find_optimal_index).
    Writes the warped image to *output* (also the return value) and transform
    files with prefix *transform_pre*.
    """
    ants.Registration(
        dimension=2,
        transforms=['Rigid'],
        transform_parameters=[(0.15,)],
        metric=['MI'],
        smoothing_sigmas=[[4, 4, 2]],
        convergence_threshold=[1.e-6],
        number_of_iterations=[[1000, 1000, 1000]],
        shrink_factors=[[32, 16, 8]],
        initial_moving_transform_com=0,  # initialise by geometric centre
        interpolation='BSpline',
        metric_weight=[1],
        radius_or_number_of_bins=[32],  # 32 histogram bins for MI
        sampling_strategy=['Regular'],
        use_histogram_matching=True,
        winsorize_lower_quantile=0.05,
        winsorize_upper_quantile=0.95,
        terminal_output='none',
        output_transform_prefix=transform_pre,
        fixed_image=[fixed],
        moving_image=[moving],
        output_warped_image=output
    ).run()
    return output

+ 115 - 0
Example/runner.py

@@ -0,0 +1,115 @@
+from pathlib import Path
+from pprint import pprint
+import argparse
+import sys
+import os
+import numpy as np
+import scipy.stats
+import nibabel as nib
+from nipype.interfaces import ants
+import registration
+import tools
+import preprocess
+
+from PIL import Image, ImageSequence
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
def find_optimal_index(slice_path, template_path, approx, temporary_directory, dist=2):
    """Find the template slice index that best matches the sample slice.

    Rigidly registers the slice against template indices
    [approx-dist, approx+dist] and scores each candidate with a Gaussian
    distance prior times a 0.7/0.3 mix of mutual information and Dice.
    Returns the best-scoring index.
    """
    # The sample slice never changes inside the loop — load it once instead
    # of once per candidate index.
    _slice = nib.load(str(slice_path))
    weights = []
    for i in range(approx - dist, approx + dist + 1):
        print(f'Slice index: {i}')
        template_slice_loc = str(temporary_directory / f't_slice{i}.nii')
        registered_loc = str(temporary_directory / f'reg_slice{i}.nii')
        transform_prefix = str(temporary_directory / f'{i}-')

        tools.save_slice(str(template_path), template_slice_loc, i, np.eye(4), _slice.header)
        registration.rigid_registration(template_slice_loc, str(slice_path), registered_loc, transform_prefix)

        # get_fdata() replaces the removed nibabel get_data() API.
        template_data = nib.load(str(template_slice_loc)).get_fdata()
        registered_data = nib.load(registered_loc).get_fdata()
        registered_data = tools.resize_im(registered_data, template_data)
        nib.save(nib.Nifti1Image(registered_data, np.eye(4)), registered_loc)

        mutualinfo = tools.mutual_info_mask(template_data, registered_data)
        # Prior decaying with distance from the approximate index,
        # normalised to 1 at the approximate index itself.
        norm_factor = scipy.stats.norm.pdf(i - approx + dist, dist, dist * 2) \
            / scipy.stats.norm.pdf(dist, dist, dist * 2)
        dice_coef = tools.dice_coef(template_data, registered_data)
        weights.append((i, norm_factor * (0.7 * mutualinfo + 0.3 * dice_coef)))

    pprint(weights)
    optimal = max(weights, key=lambda a: a[1])
    print(optimal[0])
    return optimal[0]
+
def apply_transform(segmentation, fixed, index, out_dir, temporary_directory):
    """Warp template-slice *index* of *segmentation* onto the sample slice.

    Runs the slow full (SyN) registration from the template slice to the
    sample, then pushes the segmentation slice through the resulting inverse
    composite transform.  Returns the path of the warped segmentation
    (<out_dir>/Segmentation.nii).  Side effects: writes segment_slice.nii,
    <fixed stem>_template.nii and the TemplateSlice_*.h5 transforms into
    *out_dir*.
    """
    seg_slice_loc = out_dir / "segment_slice.nii" 
    tools.save_slice(str(segmentation), str(seg_slice_loc), index, np.eye(4), nib.load(str(fixed)).header)
    #  post_transform_loc = out_dir / f'{seg_slice_loc.stem}_t.nii'
    post_transform_loc = out_dir / f'Segmentation.nii'

    # t_slice<index>.nii was produced earlier by find_optimal_index().
    template_slice_loc = str(temporary_directory / f't_slice{index}.nii')
    registration.full_registration(str(template_slice_loc), str(fixed),
                                   str(Path(out_dir, f'{fixed.stem}_template.nii')), str(out_dir / f'TemplateSlice_'))

    # Template->sample mapping for the label image: the inverse composite
    # transform with MultiLabel interpolation so label values are not blended.
    transform = [str(out_dir / f"TemplateSlice_InverseComposite.h5")]
    reverse_transform = ants.ApplyTransforms(
            input_image=str(seg_slice_loc),
            reference_image=str(fixed),
            transforms=transform,
            invert_transform_flags=[False],
            interpolation='MultiLabel',
            dimension=2,
            output_image=str(post_transform_loc))
    reverse_transform.run()

    return post_transform_loc
+
def generate_segmentation(slice_path, segment_path, template_path, approx_index, dapi, output_directory):
    """End-to-end pipeline: preprocess the slice, find the best matching
    template index, and warp the template segmentation onto the slice.

    slice_path:   str path to a .tif/.tiff or a preprocessed .nii slice
                  (must be a string — .endswith is called on it).
    dapi:         1-based DAPI channel number; 0 selects the last channel,
                  because dapi-1 == -1 then indexes the last element below.
    Returns the path of the warped segmentation file.
    """
    if slice_path.endswith('.tiff') or slice_path.endswith('.tif'):
        Image.MAX_IMAGE_PIXELS = None
        #  tif_image = Image.open(slice_path)
        #  ch_nm = (list(tif_image.tag_v2[270].split('='))[2])
        #  if int(ch_nm[0]) < dapi:
        #      print(f'DAPI channel out of range, max range is {ch_nm[0]} channels')
        #      sys.exit()
        tif_path = preprocess.tiff_to_nii(slice_path, output_directory)
        # tiff_to_nii returns (scaled_path, unscaled_path); continue with the
        # scaled file.
        slice_path = tif_path[0]
    elif not slice_path.endswith(".nii"):
        print('Selected slice file was neither a nifti or a TIFF file. Please check slice file format')
        sys.exit(2)

    # Sanitise the stem (spaces/commas/dots) so it is safe as a directory name.
    out_subdir = Path(output_directory) / Path(slice_path).stem.replace(' ', '_').replace(',','_').replace('.','')
    tools.create_dir(out_subdir)
    temporary_directory = Path(output_directory) / Path(out_subdir) / 'tmp'
    tools.create_dir(temporary_directory)
    channel_paths = preprocess.split_nii_channels(slice_path, out_subdir, True, dapi-1)
    optimal_index = find_optimal_index(channel_paths[dapi-1], template_path, approx_index, temporary_directory)
    seg_loc = apply_transform(segment_path, channel_paths[dapi-1], optimal_index, out_subdir, temporary_directory)
    # Clean up intermediates and give the transforms stable, user-facing names.
    tools.remove_dir(temporary_directory)
    os.remove(Path(out_subdir) / 'segment_slice.nii')
    os.rename(Path(out_subdir) / 'TemplateSlice_Composite.h5', Path(out_subdir) / 'Transformation_Composite.h5')
    os.rename(Path(out_subdir) / 'TemplateSlice_InverseComposite.h5', Path(out_subdir) / 'Transformation_InverseComposite.h5')
    return seg_loc
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("sliceloc", metavar="Slice Location", help="Location of preprocessed slice")
    parser.add_argument("segloc", metavar="Segmentation Location", help="Location of segmentation")
    parser.add_argument("tloc", metavar="Template Location", help="Location of template file")
    parser.add_argument("--approx", type = int, default=-1, help="Approximate index of slice relative to template file")
    parser.add_argument("bregma", metavar = 'Bregma index', type = float, help="Approx bregma coordinates")
    parser.add_argument("out", metavar = 'Output directory', default="output", help="Output directory")
    parser.add_argument("--dapi", type=int, default=0, help="DAPI channel number, default is last channel")

    args = parser.parse_args()
    # NOTE(review): 'bregma' is a required positional, so it is always set;
    # this error path is only reachable when bregma is exactly 0.0 (falsy)
    # and --approx was not given.  A bregma of 0.0 also does NOT override
    # --approx below — confirm that is intended.
    if args.approx == -1 and not args.bregma:
        print("Error: Please specify an approximate location in the template, or a bregma coordinate of the slice")
        sys.exit(2)
    if args.bregma:
        args.approx = tools.bregma_to_slice_index(args.bregma)
    generate_segmentation(args.sliceloc, args.segloc, args.tloc, args.approx, args.dapi, args.out)

+ 49 - 0
Example/tools.py

@@ -0,0 +1,49 @@
+import os
+import nibabel as nib
+from PIL import Image
+import numpy as np
+from sklearn.metrics import mutual_info_score
+import shutil
+
+from preprocess import mask_foreground
def save_slice(template, filename, index, affine=np.eye(4), header=None):
    """Extract slice *index* (along axis 1) from *template* and save as NIfTI.

    template: path of a 3-D NIfTI volume.
    filename: destination path; returned unchanged.
    header:   optional header whose pixdim is copied onto the new image;
              when None, pixel dimensions are derived from the template.
    """
    _slice = nib.load(str(template))
    _slice_data = _slice.get_data()

    newimg = nib.Nifti1Image(_slice_data[:, index, :], affine, header)
    if header is not None:
        newimg.header['pixdim'] = header['pixdim']
    else:
        # Bug fix: the original read header['pixdim'] unconditionally, which
        # raised TypeError whenever header was None (the default) and made
        # this fallback branch unreachable.
        newimg.header['pixdim'][1:3] = _slice.header['pixdim'][1], _slice.header['pixdim'][3]
    newimg.to_filename(str(filename))
    return filename
+
def create_dir(name):
    """Create directory *name* (including missing parents) if needed.

    os.makedirs(..., exist_ok=True) is atomic with respect to the
    exists/create race the original exists-then-mkdir pattern had.
    """
    os.makedirs(name, exist_ok=True)
def remove_dir(loc):
    """Recursively delete the directory tree rooted at *loc*."""
    shutil.rmtree(loc)
+    
def mutual_info_mask(slice1,slice2, bins=32):
    """Mutual information between two slices over their common foreground.

    Pixels outside the intersection of the two foreground masks are zeroed,
    a joint (bins x bins) histogram is built, and it is scored with sklearn
    as a contingency table.  Returns the mutual information in nats.
    """
    slice1_masked = mask_foreground(slice1)
    slice2_masked = mask_foreground(slice2)
    # Nonzero-intensity overlap of the two masked images.
    common = np.logical_and(slice1_masked, slice2_masked)
    slice1, slice2 = np.where(common, slice1, 0), np.where(common, slice2, 0)
    hist = np.histogram2d(slice1.ravel(), slice2.ravel(), bins=bins)[0]
    return mutual_info_score(None, None, contingency=hist)
+
+
def dice_coef(slice1, slice2):
    """Dice overlap (0..1) between the foreground masks of two slices."""
    # np.bool was removed in NumPy 1.24; the builtin bool is the documented
    # replacement and behaves identically here.
    mask_1 = mask_foreground(slice1).astype(bool)
    mask_2 = mask_foreground(slice2).astype(bool)
    intersection = np.logical_and(mask_1, mask_2)
    return 2. * intersection.sum() / (mask_1.sum() + mask_2.sum())
+
+
def resize_im(image_data, image_with_dimensions):
    """Resize *image_data* to the (rows, cols) shape of *image_with_dimensions*."""
    # PIL expects (width, height), i.e. the array shape reversed.
    target = (image_with_dimensions.shape[1], image_with_dimensions.shape[0])
    return np.array(Image.fromarray(image_data).resize(target))
+
def bregma_to_slice_index(bregma):
    """Map a bregma coordinate to a template slice index via a linear fit."""
    slope, intercept = 27.908, 116.831
    return round(slope * bregma + intercept)

+ 138 - 0
automatic_segmentation_program/preprocess.py

@@ -0,0 +1,138 @@
+from pathlib import Path
+import argparse
+import numpy as np
+import nibabel as nib
+from PIL import Image, ImageSequence
+from scipy import ndimage
+from skimage import morphology, filters
+from nipype.interfaces.ants.segmentation import N4BiasFieldCorrection
+
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+
def get_out_paths(out_dir, stem):
    """Build the output paths for the scaled and unscaled NIfTI files.

    Returns a ``(scaled, unscaled)`` tuple of paths inside *out_dir*,
    named ``<stem>.nii`` and ``<stem>_us.nii`` respectively.
    """
    scaled = Path(out_dir) / f'{stem}.nii'
    unscaled = Path(out_dir) / f'{stem}_us.nii'
    return scaled, unscaled
+
+
def bias_field_correction(image, output, dim=2):
    """Run ANTs N4 bias-field correction on *image*, writing to *output*.

    image:  path to the input NIfTI file.
    output: path for the corrected image (may equal *image* to overwrite in place).
    dim:    image dimensionality passed to N4; 2 because the pipeline works
            on single slices.
    """
    N4BiasFieldCorrection(
        input_image=image,
        output_image=output,
        dimension=dim
    ).run()
+
+
def image_data_to_nii(pixdim, image_data, shrink, out_dir, file_path, save_unscaled=False):
    """Downscale a (rows, cols, channels) array and save it as a NIfTI file.

    pixdim:        pixel size of the *input* image; the scaled file stores
                   pixdim * shrink.
    image_data:    source array with channels stacked on axis 2.
    shrink:        downscale factor (output is 1/shrink the input size).
    out_dir:       directory for the output .nii file(s).
    file_path:     source file; its stem names the outputs.
    save_unscaled: also write the original-resolution data as <stem>_us.nii.

    Returns (scaled_path, unscaled_path); the unscaled file only exists when
    save_unscaled is True.
    """
    scale = 1 / shrink
    # PIL's resize takes (width, height); combined with the transpose below,
    # the stored volume is transposed relative to the input orientation.
    new_dim = (round(image_data.shape[1] * scale), round(image_data.shape[0] * scale))

    # Resize each channel and stack once, instead of filling a preallocated
    # (and uninitialised) np.ndarray; cast to float64 to match the dtype the
    # original buffer had.
    channels = [
        np.array(Image.fromarray(image_data[:, :, i]).resize(new_dim)).transpose()
        for i in range(image_data.shape[2])
    ]
    if channels:
        new_arr = np.stack(channels, axis=2).astype(np.float64)
    else:
        new_arr = np.empty(new_dim + (0,))

    path_scaled, path_unscaled = get_out_paths(out_dir, Path(file_path).stem)
    nii_scaled = nib.Nifti1Image(new_arr, np.eye(4))
    # xyzt_units code 3 is micrometres in the NIfTI-1 standard — confirm the
    # intended unit matches pixdim's unit.
    nii_scaled.header['xyzt_units'] = 3
    nii_scaled.header['pixdim'][1:3] = pixdim * shrink, pixdim * shrink
    nib.save(nii_scaled, str(path_scaled))

    if save_unscaled:
        nii_unscaled = nib.Nifti1Image(image_data, np.eye(4))
        nii_unscaled.header['xyzt_units'] = 3
        nii_unscaled.header['pixdim'][1:3] = pixdim, pixdim
        nib.save(nii_unscaled, str(path_unscaled))
    print(f'Preprocessed:  {path_scaled}\n')
    return path_scaled, path_unscaled
+
+
def tiff_to_nii(tif_path, out_dir, pixdim=None, shrink=10):
    """Convert a (multi-page) TIFF into a NIfTI file, one page per channel.

    pixdim: pixel size; when falsy it is derived from TIFF tag 282
            (XResolution).  NOTE(review): the factor 10e6 equals 1e7 — it
            looks like it may have been meant as 1e6; confirm the unit.
    shrink: downscale factor forwarded to image_data_to_nii().

    Returns the (scaled_path, unscaled_path) pair from image_data_to_nii().
    """
    Image.MAX_IMAGE_PIXELS = None  # allow arbitrarily large slide scans
    tif_image = Image.open(tif_path)
    tif_header = dict(tif_image.tag)
    if not pixdim:
        pixdim = 10e6 / tif_header[282][0][0]
    # Collect all pages and stack once: the original repeated np.concatenate
    # inside the loop, which is quadratic in the number of pages.  Cast to
    # float64 to keep the dtype the concatenate-based version produced.
    pages = [np.array(page) for page in ImageSequence.Iterator(tif_image)]
    if pages:
        output = np.stack(pages, axis=2).astype(np.float64)
    else:
        output = np.empty(np.array(tif_image).shape + (0,))
    return image_data_to_nii(pixdim, output, shrink, out_dir, tif_path)
+
+
def split_nii_channels(nii_path, out_dir=None, flip=False, mask_index=-1, bias=False):
    """Write each channel of a multi-channel 2-D NIfTI as its own file.

    nii_path:   Path to the source .nii (channels stacked on axis 2).
    out_dir:    Path of the output directory; defaults to the source's parent.
    flip:       flip every channel along axis 1 before saving.
    mask_index: channel to foreground-mask; -1 means the last channel.
    bias:       additionally N4 bias-correct and mean-normalise the masked
                channel.

    Returns the list of written channel paths (im_c1.nii, im_c2.nii, ...).
    """
    if out_dir is None:
        out_dir = nii_path.parent
    nii = nib.load(str(nii_path))
    nii_data = nii.get_fdata()
    nii_header = nii.header

    if mask_index == -1:
        mask_index = nii_data.shape[2] - 1
    paths = []

    for i in range(nii_data.shape[2]):
        out_path = out_dir / f'im_c{i+1}.nii'
        channel_data = nii_data[:, :, i]

        if flip:
            channel_data = np.flip(channel_data, 1)
        if i == mask_index:
            channel_data = mask_foreground(channel_data)

        # Copy the header instead of aliasing it, so the loaded image's
        # header is not mutated as a side effect of this loop.
        new_header = nii_header.copy()
        new_header['dim'][0] = 2
        nib.save(nib.Nifti1Image(channel_data, np.eye(4), header=new_header), str(out_path))

        if i == mask_index and bias:
            bias_field_correction(str(out_path), str(out_path))
            corrected = nib.load(str(out_path))
            # get_fdata() replaces get_data(), which was deprecated and then
            # removed from nibabel.
            corrected_data = corrected.get_fdata()
            corrected_normalized = corrected_data / np.mean(corrected_data[corrected_data != 0])
            nii_corrected = nib.Nifti1Image(corrected_normalized, corrected.affine, corrected.header)
            nib.save(nii_corrected, str(out_path))
        paths.append(out_path)
    return paths
+
+
def mask_foreground(raw_data):
    """Zero out background pixels, keeping one large connected component.

    The image is normalised, blurred, thresholded at half the Otsu level,
    and labelled with 8-connectivity; one of the two largest components is
    kept and morphologically closed.  Returns the input intensities with
    everything outside that component set to 0.
    """
    raw_max = raw_data.max()
    if raw_max == 0:
        # Blank slice: nothing to mask (also avoids division by zero).
        return raw_data
    raw_data = raw_data / raw_max
    blurred_data = ndimage.gaussian_filter(raw_data, 4)

    # Half the Otsu threshold: deliberately generous so dim tissue survives.
    threshold = filters.threshold_otsu(raw_data) / 2
    threshold_data = blurred_data > threshold
    connected_structure = ndimage.generate_binary_structure(2, 2)  # Connects adjacent and diagonal.
    padded_comp, padded_nr = ndimage.label(threshold_data, structure=connected_structure)

    comps, comps_count = np.unique(padded_comp, return_counts=True)
    comps_count, comps = zip(*sorted(zip(comps_count, comps), reverse=True))

    if len(comps) < 2:
        # Only one label present (all foreground or all background); the
        # original code raised IndexError here.
        biggest_cc = comps[0]
    else:
        # NOTE(review): np.average of a scalar label is the label itself, so
        # this picks the higher label id of the two largest components — in
        # practice the non-background one, since background is label 0.
        # Confirm this is the intended heuristic (not mean intensity).
        two_biggest_cc = ((comps[0], np.average(comps[0])), (comps[1], np.average(comps[1])))
        biggest_cc = max(two_biggest_cc, key=lambda a: a[1])[0]

    foreground_mask = np.where(padded_comp == biggest_cc, True, False)
    closed = morphology.binary_closing(foreground_mask, selem=morphology.square(30))
    raw_data = np.where(closed, raw_data, 0)
    return raw_data * raw_max
+    
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("file", help="Location of file to process")
    parser.add_argument("dir", help="Directory for preprocessed Files")
    parser.add_argument("-s", "--series", type=int, help="Series to extract")
    # argparse's type=bool treats any non-empty string (including "False")
    # as True; parse the common false-y spellings explicitly instead.  The
    # default stays True for backward compatibility.
    parser.add_argument('-b', type=lambda v: v.lower() not in ('false', '0', 'no'),
                        default=True, help='Bias field correct image')
    parser.add_argument('--pdim', type=float,
                        help='Pixel dimensions (Retrieved from Tiff file if not set)')
    args = parser.parse_args()

    filetype = Path(args.file).suffix.lower()
    if filetype in ('.tif', '.tiff'):
        # Forward --pdim; the original parsed it but never used it.
        n_path = tiff_to_nii(Path(args.file), Path(args.dir), pixdim=args.pdim)
    elif filetype == '.nd2':
        # NOTE(review): nd2_to_nii is not defined anywhere in this module;
        # this branch raises NameError as written — confirm where it lives.
        n_path = nd2_to_nii(Path(args.file), Path(args.dir), args.series)
    elif filetype == '.nii':
        n_path = Path(args.file)
    else:
        parser.error(f'Unsupported file type: {filetype}')

+ 64 - 0
automatic_segmentation_program/registration.py

@@ -0,0 +1,64 @@
+from nipype.interfaces import ants
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+
# Parameters for a full registration (Rigid, Affine, SyN)  --- SLOW ---
def full_registration(fixed, moving, output='full_trans.nii', transform_pre=''):
    """Deformably register *moving* onto *fixed* with ANTs (Rigid+Affine+SyN).

    fixed/moving:  paths to 2-D NIfTI images.
    output:        path of the warped moving image; also the return value.
    transform_pre: prefix for the transform files ANTs writes; with
                   write_composite_transform=True this produces
                   <prefix>Composite.h5 / <prefix>InverseComposite.h5.
    Also writes the inverse-warped image as a side effect.
    """
    ants.Registration(
        dimension=2,
        transforms=['Rigid', 'Affine', 'SyN'],
        transform_parameters=[(0.15,), (0.15,), (0.3, 3, 0.5)],
        # Two metrics (CC and MI) are combined at every stage — hence six
        # entries in metric / metric_weight / radius_or_number_of_bins.
        metric=['CC', 'MI'] * 3,
        convergence_threshold=[1.e-7],
        number_of_iterations=[[1000, 1000, 1000, 1000], [1000, 1000, 1000, 1000],
                              [1000, 1000, 1000, 1000, 1000]],
        smoothing_sigmas=[[4, 4, 2, 2], [4, 4, 2, 2], [4, 2, 2, 2, 1]],
        shrink_factors=[[32, 16, 8, 4], [32, 16, 8, 4], [16, 8, 4, 2, 1]],
        initial_moving_transform_com=1,  # initialise by centre of mass
        interpolation='BSpline',
        metric_weight=[0.5] * 6,
        radius_or_number_of_bins=[4,32] * 3,  # CC radius 4, MI 32 bins
        sampling_strategy=['Regular'] * 3,
        use_histogram_matching=True,
        winsorize_lower_quantile=0.05,
        winsorize_upper_quantile=0.95,
        collapse_output_transforms=True,
        output_transform_prefix=transform_pre,
        fixed_image=[fixed],
        moving_image=[moving],
        output_warped_image=output,
        num_threads=12,
        verbose=False,
        write_composite_transform=True,
        output_inverse_warped_image=True
    ).run()
    return output
+
+
# A faster rigid registration
def rigid_registration(fixed, moving, output='affine_trans.nii', transform_pre=''):
    """Rigidly register *moving* onto *fixed* with ANTs.

    Used for the quick per-index scoring pass (see runner.find_optimal_index).
    Writes the warped image to *output* (also the return value) and transform
    files with prefix *transform_pre*.
    """
    ants.Registration(
        dimension=2,
        transforms=['Rigid'],
        transform_parameters=[(0.15,)],
        metric=['MI'],
        smoothing_sigmas=[[4, 4, 2]],
        convergence_threshold=[1.e-6],
        number_of_iterations=[[1000, 1000, 1000]],
        shrink_factors=[[32, 16, 8]],
        initial_moving_transform_com=0,  # initialise by geometric centre
        interpolation='BSpline',
        metric_weight=[1],
        radius_or_number_of_bins=[32],  # 32 histogram bins for MI
        sampling_strategy=['Regular'],
        use_histogram_matching=True,
        winsorize_lower_quantile=0.05,
        winsorize_upper_quantile=0.95,
        terminal_output='none',
        output_transform_prefix=transform_pre,
        fixed_image=[fixed],
        moving_image=[moving],
        output_warped_image=output
    ).run()
    return output

+ 115 - 0
automatic_segmentation_program/runner.py

@@ -0,0 +1,115 @@
+from pathlib import Path
+from pprint import pprint
+import argparse
+import sys
+import os
+import numpy as np
+import scipy.stats
+import nibabel as nib
+from nipype.interfaces import ants
+import registration
+import tools
+import preprocess
+
+from PIL import Image, ImageSequence
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
def find_optimal_index(slice_path, template_path, approx, temporary_directory, dist=2):
    """Find the template slice index that best matches the sample slice.

    Rigidly registers the slice against template indices
    [approx-dist, approx+dist] and scores each candidate with a Gaussian
    distance prior times a 0.7/0.3 mix of mutual information and Dice.
    Returns the best-scoring index.
    """
    # The sample slice never changes inside the loop — load it once instead
    # of once per candidate index.
    _slice = nib.load(str(slice_path))
    weights = []
    for i in range(approx - dist, approx + dist + 1):
        print(f'Slice index: {i}')
        template_slice_loc = str(temporary_directory / f't_slice{i}.nii')
        registered_loc = str(temporary_directory / f'reg_slice{i}.nii')
        transform_prefix = str(temporary_directory / f'{i}-')

        tools.save_slice(str(template_path), template_slice_loc, i, np.eye(4), _slice.header)
        registration.rigid_registration(template_slice_loc, str(slice_path), registered_loc, transform_prefix)

        # get_fdata() replaces the removed nibabel get_data() API.
        template_data = nib.load(str(template_slice_loc)).get_fdata()
        registered_data = nib.load(registered_loc).get_fdata()
        registered_data = tools.resize_im(registered_data, template_data)
        nib.save(nib.Nifti1Image(registered_data, np.eye(4)), registered_loc)

        mutualinfo = tools.mutual_info_mask(template_data, registered_data)
        # Prior decaying with distance from the approximate index,
        # normalised to 1 at the approximate index itself.
        norm_factor = scipy.stats.norm.pdf(i - approx + dist, dist, dist * 2) \
            / scipy.stats.norm.pdf(dist, dist, dist * 2)
        dice_coef = tools.dice_coef(template_data, registered_data)
        weights.append((i, norm_factor * (0.7 * mutualinfo + 0.3 * dice_coef)))

    pprint(weights)
    optimal = max(weights, key=lambda a: a[1])
    print(optimal[0])
    return optimal[0]
+
def apply_transform(segmentation, fixed, index, out_dir, temporary_directory):
    """Warp template-slice *index* of *segmentation* onto the sample slice.

    Runs the slow full (SyN) registration from the template slice to the
    sample, then pushes the segmentation slice through the resulting inverse
    composite transform.  Returns the path of the warped segmentation
    (<out_dir>/Segmentation.nii).  Side effects: writes segment_slice.nii,
    <fixed stem>_template.nii and the TemplateSlice_*.h5 transforms into
    *out_dir*.
    """
    seg_slice_loc = out_dir / "segment_slice.nii" 
    tools.save_slice(str(segmentation), str(seg_slice_loc), index, np.eye(4), nib.load(str(fixed)).header)
    #  post_transform_loc = out_dir / f'{seg_slice_loc.stem}_t.nii'
    post_transform_loc = out_dir / f'Segmentation.nii'

    # t_slice<index>.nii was produced earlier by find_optimal_index().
    template_slice_loc = str(temporary_directory / f't_slice{index}.nii')
    registration.full_registration(str(template_slice_loc), str(fixed),
                                   str(Path(out_dir, f'{fixed.stem}_template.nii')), str(out_dir / f'TemplateSlice_'))

    # Template->sample mapping for the label image: the inverse composite
    # transform with MultiLabel interpolation so label values are not blended.
    transform = [str(out_dir / f"TemplateSlice_InverseComposite.h5")]
    reverse_transform = ants.ApplyTransforms(
            input_image=str(seg_slice_loc),
            reference_image=str(fixed),
            transforms=transform,
            invert_transform_flags=[False],
            interpolation='MultiLabel',
            dimension=2,
            output_image=str(post_transform_loc))
    reverse_transform.run()

    return post_transform_loc
+
def generate_segmentation(slice_path, segment_path, template_path, approx_index, dapi, output_directory):
    """End-to-end pipeline: preprocess the slice, find the best matching
    template index, and warp the template segmentation onto the slice.

    slice_path:   str path to a .tif/.tiff or a preprocessed .nii slice
                  (must be a string — .endswith is called on it).
    dapi:         1-based DAPI channel number; 0 selects the last channel,
                  because dapi-1 == -1 then indexes the last element below.
    Returns the path of the warped segmentation file.
    """
    if slice_path.endswith('.tiff') or slice_path.endswith('.tif'):
        Image.MAX_IMAGE_PIXELS = None
        #  tif_image = Image.open(slice_path)
        #  ch_nm = (list(tif_image.tag_v2[270].split('='))[2])
        #  if int(ch_nm[0]) < dapi:
        #      print(f'DAPI channel out of range, max range is {ch_nm[0]} channels')
        #      sys.exit()
        tif_path = preprocess.tiff_to_nii(slice_path, output_directory)
        # tiff_to_nii returns (scaled_path, unscaled_path); continue with the
        # scaled file.
        slice_path = tif_path[0]
    elif not slice_path.endswith(".nii"):
        print('Selected slice file was neither a nifti or a TIFF file. Please check slice file format')
        sys.exit(2)

    # Sanitise the stem (spaces/commas/dots) so it is safe as a directory name.
    out_subdir = Path(output_directory) / Path(slice_path).stem.replace(' ', '_').replace(',','_').replace('.','')
    tools.create_dir(out_subdir)
    temporary_directory = Path(output_directory) / Path(out_subdir) / 'tmp'
    tools.create_dir(temporary_directory)
    channel_paths = preprocess.split_nii_channels(slice_path, out_subdir, True, dapi-1)
    optimal_index = find_optimal_index(channel_paths[dapi-1], template_path, approx_index, temporary_directory)
    seg_loc = apply_transform(segment_path, channel_paths[dapi-1], optimal_index, out_subdir, temporary_directory)
    # Clean up intermediates and give the transforms stable, user-facing names.
    tools.remove_dir(temporary_directory)
    os.remove(Path(out_subdir) / 'segment_slice.nii')
    os.rename(Path(out_subdir) / 'TemplateSlice_Composite.h5', Path(out_subdir) / 'Transformation_Composite.h5')
    os.rename(Path(out_subdir) / 'TemplateSlice_InverseComposite.h5', Path(out_subdir) / 'Transformation_InverseComposite.h5')
    return seg_loc
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("sliceloc", metavar="Slice Location", help="Location of preprocessed slice")
    parser.add_argument("segloc", metavar="Segmentation Location", help="Location of segmentation")
    parser.add_argument("tloc", metavar="Template Location", help="Location of template file")
    parser.add_argument("--approx", type = int, default=-1, help="Approximate index of slice relative to template file")
    parser.add_argument("bregma", metavar = 'Bregma index', type = float, help="Approx bregma coordinates")
    parser.add_argument("out", metavar = 'Output directory', default="output", help="Output directory")
    parser.add_argument("--dapi", type=int, default=0, help="DAPI channel number, default is last channel")

    args = parser.parse_args()
    # NOTE(review): 'bregma' is a required positional, so it is always set;
    # this error path is only reachable when bregma is exactly 0.0 (falsy)
    # and --approx was not given.  A bregma of 0.0 also does NOT override
    # --approx below — confirm that is intended.
    if args.approx == -1 and not args.bregma:
        print("Error: Please specify an approximate location in the template, or a bregma coordinate of the slice")
        sys.exit(2)
    if args.bregma:
        args.approx = tools.bregma_to_slice_index(args.bregma)
    generate_segmentation(args.sliceloc, args.segloc, args.tloc, args.approx, args.dapi, args.out)
+    

+ 49 - 0
automatic_segmentation_program/tools.py

@@ -0,0 +1,49 @@
+import os
+import nibabel as nib
+from PIL import Image
+import numpy as np
+from sklearn.metrics import mutual_info_score
+import shutil
+
+from preprocess import mask_foreground
def save_slice(template, filename, index, affine=np.eye(4), header=None):
    """Extract slice *index* (along axis 1) from *template* and save as NIfTI.

    template: path of a 3-D NIfTI volume.
    filename: destination path; returned unchanged.
    header:   optional header whose pixdim is copied onto the new image;
              when None, pixel dimensions are derived from the template.
    """
    _slice = nib.load(str(template))
    _slice_data = _slice.get_data()

    newimg = nib.Nifti1Image(_slice_data[:, index, :], affine, header)
    if header is not None:
        newimg.header['pixdim'] = header['pixdim']
    else:
        # Bug fix: the original read header['pixdim'] unconditionally, which
        # raised TypeError whenever header was None (the default) and made
        # this fallback branch unreachable.
        newimg.header['pixdim'][1:3] = _slice.header['pixdim'][1], _slice.header['pixdim'][3]
    newimg.to_filename(str(filename))
    return filename
+
def create_dir(name):
    """Create directory *name* (including missing parents) if needed.

    os.makedirs(..., exist_ok=True) is atomic with respect to the
    exists/create race the original exists-then-mkdir pattern had.
    """
    os.makedirs(name, exist_ok=True)
def remove_dir(loc):
    """Recursively delete the directory tree rooted at *loc*."""
    shutil.rmtree(loc)
+    
def mutual_info_mask(slice1,slice2, bins=32):
    """Mutual information between two slices over their common foreground.

    Pixels outside the intersection of the two foreground masks are zeroed,
    a joint (bins x bins) histogram is built, and it is scored with sklearn
    as a contingency table.  Returns the mutual information in nats.
    """
    slice1_masked = mask_foreground(slice1)
    slice2_masked = mask_foreground(slice2)
    # Nonzero-intensity overlap of the two masked images.
    common = np.logical_and(slice1_masked, slice2_masked)
    slice1, slice2 = np.where(common, slice1, 0), np.where(common, slice2, 0)
    hist = np.histogram2d(slice1.ravel(), slice2.ravel(), bins=bins)[0]
    return mutual_info_score(None, None, contingency=hist)
+
+
def dice_coef(slice1, slice2):
    """Dice overlap (0..1) between the foreground masks of two slices."""
    # np.bool was removed in NumPy 1.24; the builtin bool is the documented
    # replacement and behaves identically here.
    mask_1 = mask_foreground(slice1).astype(bool)
    mask_2 = mask_foreground(slice2).astype(bool)
    intersection = np.logical_and(mask_1, mask_2)
    return 2. * intersection.sum() / (mask_1.sum() + mask_2.sum())
+
+
def resize_im(image_data, image_with_dimensions):
    """Resize *image_data* to the (rows, cols) shape of *image_with_dimensions*."""
    # PIL expects (width, height), i.e. the array shape reversed.
    target = (image_with_dimensions.shape[1], image_with_dimensions.shape[0])
    return np.array(Image.fromarray(image_data).resize(target))
+
def bregma_to_slice_index(bregma):
    """Map a bregma coordinate to a template slice index via a linear fit."""
    slope, intercept = 27.908, 116.831
    return round(slope * bregma + intercept)