
gin commit from SUND33778

New files: 4
Modified files: 1
Frederik Filip Stæger 5 years ago
parent
commit
37fa63a08a
5 files changed with 462 additions and 0 deletions
  1. README.md  +16 -0
  2. auto_seg.py  +121 -0
  3. preprocess.py  +169 -0
  4. registration.py  +62 -0
  5. tools.py  +94 -0

+ 16 - 0
README.md

@@ -0,0 +1,16 @@
+# Automatic slice segmentation
+
+Program for automatically registering a segmentation to a DAPI-stained slice.
+
+## Dependencies
+* Python 3.6
+* ANTs (https://github.com/ANTsX/ANTs)
+
+## Prerequisites
+* A high-resolution, DAPI-modality TIFF image of a brain slice (may have multiple channels)
+* A template file (.nii or .nii.gz)
+* A segmentation of the template file (.nii or .nii.gz)
+
+## Usage
+* Use preprocess.py to prepare the input slice for automatic registration
+* Use auto_seg.py on the preprocessed slice; it outputs a folder containing the segmentation registered to the slice (a usage sketch follows below)
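
To make the two Usage steps above concrete, here is a minimal sketch of the same pipeline driven from Python instead of the command line. The file names, approximate index and channel value are hypothetical placeholders; output lands under auto_seg.OUTPUT_DIR ('output' by default).

```python
# Hypothetical end-to-end run of the pipeline described in the README.
# Assumes slice.tif, template.nii.gz and segmentation.nii.gz exist in the working directory.
from pathlib import Path

import auto_seg
import preprocess

pre_dir = Path('preprocessed')
pre_dir.mkdir(exist_ok=True)
auto_seg.OUTPUT_DIR.mkdir(exist_ok=True)  # generate_segmentation only creates the per-slice subfolder

# Step 1: convert the high-resolution TIFF slice to a downscaled NIfTI file.
scaled_nii, _ = preprocess.tiff_to_nii(Path('slice.tif'), pre_dir)

# Step 2: register the template segmentation to the preprocessed slice.
seg_path = auto_seg.generate_segmentation(
    slice_path=scaled_nii,
    segment_path='segmentation.nii.gz',
    template_path='template.nii.gz',
    approx_index=100,  # rough position of the slice in the template volume
    dapi=0,            # 0 selects the last channel, matching the CLI default
)
print(f'Registered segmentation written to {seg_path}')
```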

+ 121 - 0
auto_seg.py

@@ -0,0 +1,121 @@
+# System imports
+from pathlib import Path
+from pprint import pprint
+import argparse
+import sys
+# 3rd Party imports
+import numpy as np
+import scipy.stats
+import nibabel as nib
+from nipype.interfaces import ants
+# Local source tree imports
+import registration
+import tools
+import preprocess as pre
+
+TEMP_DIR = Path('temp')
+OUTPUT_DIR = Path('output')
+IMG_DIR = Path('Images')
+
+
+def find_optimal_index(slice_path, template_path, approx, dist=2):
+    """
+    Finds the template index to which the slice file registers best.
+    :param approx: approximate index of the slice relative to the template file
+    :param dist: maximum distance from the approximate index to search in either direction
+    :return: index of the template slice with the highest score (a weighted combination of mutual information and Dice coefficient)
+    """
+    weights = []
+    for i in range(approx - dist, approx + dist + 1):
+        print(f'Slice index: {i}')
+        _slice = nib.load(str(slice_path))
+        template_slice_loc = str(TEMP_DIR/f't_slice{i}.nii')
+        registered_loc = str(TEMP_DIR/ f'reg_slice{i}.nii')
+        transform_prefix = str(TEMP_DIR/f'{i}-')
+
+        tools.save_slice(str(template_path), template_slice_loc, i, np.eye(4), _slice.header)
+        registration.rigid_registration(template_slice_loc, str(slice_path), registered_loc, transform_prefix)
+
+        template_data = nib.load(str(template_slice_loc)).get_data()
+        registered_data = nib.load(registered_loc).get_data()
+        registered_data = tools.resize_im(registered_data, template_data)
+        nib.save(nib.Nifti1Image(registered_data, np.eye(4)), registered_loc)
+
+        mutualinfo = tools.mutual_info_mask(template_data, registered_data)
+        norm_factor = scipy.stats.norm.pdf(i-approx+dist, dist, dist*2) / scipy.stats.norm.pdf(dist, dist, dist*2)
+        dice_coef = tools.dice_coef(template_data, registered_data)
+        weights.append((i, norm_factor * (0.7 * mutualinfo + 0.3 * dice_coef)))
+
+    pprint(weights)
+    optimal = max(weights, key=lambda a: a[1])
+    return optimal[0]
+
+
+def apply_transform(segmentation, fixed, index, out_dir):
+    """
+    Runs a full registration of the slice to the chosen template slice, then applies the inverse transform to the
+    corresponding segmentation slice to obtain a segmentation aligned with the slice.
+    :return: location of the segmentation slice registered to the slice
+    """
+    seg_slice_loc = out_dir / "segment_slice.nii"
+    tools.save_slice(str(segmentation), str(seg_slice_loc), index, np.eye(4), nib.load(str(fixed)).header)
+    post_transform_loc = out_dir / f'{seg_slice_loc.stem}_t.nii'
+
+    template_slice_loc = str(TEMP_DIR / f't_slice{index}.nii')
+
+    registration.full_registration(str(template_slice_loc), str(fixed),
+                                   str(Path(out_dir, f'{fixed.stem}_t.nii')), str(out_dir / 'Final-'))
+    transform = [str(out_dir / "Final-InverseComposite.h5")]
+
+    reverse_transform = ants.ApplyTransforms(
+            input_image=str(seg_slice_loc),
+            reference_image=str(fixed),
+            transforms=transform,
+            invert_transform_flags=[False],
+            interpolation='MultiLabel',
+            dimension=2,
+            output_image=str(post_transform_loc))
+
+    reverse_transform.run()
+
+    return post_transform_loc
+
+
+def generate_segmentation(slice_path, segment_path, template_path, approx_index, dapi):
+    """
+    Given paths to the slice, segmentation and template plus an approximate index, applies a series of registrations
+    to create an automatic segmentation of the slice.
+    :param dapi: index of the DAPI modality channel (0 selects the last channel)
+    :return: location of the segmentation slice registered to the slice
+    """
+    out_subdir = OUTPUT_DIR / Path(slice_path).stem.replace(' ', '').replace(',', '').replace('.', '')
+    tools.create_dir(TEMP_DIR)
+    tools.create_dir(out_subdir)
+
+    channel_paths = pre.split_nii_channels(slice_path, out_subdir, False, dapi-1)
+
+    optimal_index = find_optimal_index(channel_paths[dapi-1], template_path, approx_index)
+
+    seg_loc = apply_transform(segment_path, channel_paths[dapi-1], optimal_index, out_subdir)
+    tools.remove_dir(TEMP_DIR)
+    return seg_loc
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("sliceloc", metavar="Slice Location", help="Location of preprocessed slice")
+    parser.add_argument("segloc", metavar="Segmentation Location", help="Location of segmentation")
+    parser.add_argument("tloc", metavar="Template Location", help="Location of template file")
+    parser.add_argument("--approx", type=int, default=-1, help="Approximate index of slice relative to template file")
+    parser.add_argument("--bregma", type=float, help="Approx bregma coordinates")
+    parser.add_argument("--out", default="output", help="Output directory")
+    parser.add_argument("--dapi", type=int, default=0, help="DAPI channel number, default is last channel")
+
+    args = parser.parse_args()
+    OUTPUT_DIR = Path(args.out)
+    if args.approx == -1 and not args.bregma:
+        print("Error: please specify an approximate location in the template, or a bregma coordinate of the slice")
+        sys.exit(2)
+    if args.bregma:
+        args.approx = tools.bregma_to_slice_index(args.bregma)
+    generate_segmentation(args.sliceloc, args.segloc, args.tloc, args.approx, args.dapi)
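
The scoring loop in find_optimal_index above combines a Gaussian prior on the distance from the approximate index with a 0.7/0.3 blend of masked mutual information and Dice coefficient. The following standalone sketch (not part of the repository) isolates that weighting to show how quickly the prior discounts neighbouring template slices.

```python
# Standalone sketch of the candidate scoring used in find_optimal_index().
# A Gaussian prior centred on the approximate index, normalised to 1.0 at the centre,
# scales a 0.7/0.3 blend of mutual information and Dice coefficient.
import scipy.stats


def candidate_score(i, approx, mutualinfo, dice, dist=2):
    # norm.pdf(i - approx + dist, loc=dist, scale=2 * dist) peaks when i == approx;
    # dividing by the peak value makes the prior exactly 1.0 at the approximate index.
    prior = (scipy.stats.norm.pdf(i - approx + dist, dist, dist * 2)
             / scipy.stats.norm.pdf(dist, dist, dist * 2))
    return prior * (0.7 * mutualinfo + 0.3 * dice)


# With dist=2 the prior falls off gently (about 0.97 one slice away, 0.88 two slices away),
# so a neighbouring template slice only wins if its similarity metrics are clearly better.
for i in range(98, 103):
    print(i, round(candidate_score(i, approx=100, mutualinfo=1.0, dice=1.0), 3))
```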

+ 169 - 0
preprocess.py

@@ -0,0 +1,169 @@
+from pathlib import Path
+import xmltodict
+import argparse
+
+import javabridge
+import bioformats
+import numpy as np
+import nibabel as nib
+from PIL import Image, ImageSequence
+from scipy import ndimage
+from skimage import morphology, filters
+from nipype.interfaces.ants.segmentation import N4BiasFieldCorrection
+
+
+def get_out_paths(out_dir, stem):
+    out_path_scaled = Path(out_dir, f'{stem}.nii')
+    out_path_unscaled = Path(out_dir, f'{stem}_us.nii')
+    return out_path_scaled, out_path_unscaled
+
+
+def bias_field_correction(image, output, dim=2):
+    N4BiasFieldCorrection(
+        input_image=image,
+        output_image=output,
+        dimension=dim
+    ).run()
+
+
+def image_data_to_nii(pixdim, image_data, shrink, out_dir, file_path, save_unscaled=False):
+    """
+    Saves an array of image data as a NIfTI file with the correct pixel dimensions in the header.
+    :param save_unscaled: whether to save an unscaled NIfTI file in addition to the scaled one
+    :param shrink: factor by which to shrink the image dimensions
+    :return: locations of the scaled and (optionally) unscaled NIfTI files
+    """
+    image_dim = (image_data.shape[0], image_data.shape[1])
+    scale = 1/shrink
+    new_dim = (round(image_dim[1] * scale), round(image_dim[0] * scale))
+    new_arr = np.ndarray(new_dim + (image_data.shape[2],))
+
+    for i in range(0, image_data.shape[2]):
+        cur_channel = image_data[:, :, i]
+        resized = np.array(Image.fromarray(cur_channel).resize(new_dim)).transpose()
+        new_arr[:, :, i] = resized
+    path_scaled, path_unscaled = get_out_paths(out_dir, Path(file_path).stem)
+    nii_scaled = nib.Nifti1Image(new_arr, np.eye(4))
+    nii_scaled.header['xyzt_units'] = 3
+    nii_scaled.header['pixdim'][1:3] = pixdim * shrink, pixdim * shrink
+    nib.save(nii_scaled, str(path_scaled))
+
+    if save_unscaled:
+        nii_unscaled = nib.Nifti1Image(image_data, np.eye(4))
+        nii_unscaled.header['xyzt_units'] = 3
+        nii_unscaled.header['pixdim'][1:3] = pixdim, pixdim
+        nib.save(nii_unscaled, str(path_unscaled))
+    print(f'Preprocessed:  {path_scaled}\n')
+    return path_scaled, path_unscaled
+
+
+def nd2_to_nii(nd2_path, out_dir, series, shrink=10):
+    """
+    Wrapper function for image_data_to_nii, for converting nd2 files to nifti
+    """
+    javabridge.start_vm(bioformats.JARS)
+    image = np.array(bioformats.load_image(str(nd2_path), series=series-1, rescale=False))
+    meta_dict = xmltodict.parse(bioformats.get_omexml_metadata(str(nd2_path)))
+    vox_size_x = float(meta_dict['OME']['Image'][series-1]['Pixels']['@PhysicalSizeX'])
+    javabridge.kill_vm()
+    return image_data_to_nii(vox_size_x, image, shrink, out_dir, nd2_path)
+
+
+def tiff_to_nii(tif_path, out_dir, pixdim=None, shrink=10):
+    """
+    Wrapper function for image_data_to_nii, for converting tiff files to nifti
+    """
+    tif_image = Image.open(tif_path)
+    tif_header = dict(tif_image.tag)
+    output = np.empty(np.array(tif_image).shape + (0,))
+    if not pixdim:
+        pixdim = 10e6/tif_header[282][0][0]
+    for i, page in enumerate(ImageSequence.Iterator(tif_image)):
+        page_data = np.expand_dims(np.array(page), 2)
+        output = np.concatenate((output, page_data), 2)
+    return image_data_to_nii(pixdim, output, shrink, out_dir, tif_path)
+
+
+def split_nii_channels(nii_path, out_dir=None, flip=False, mask_index=-1, bias=False):
+    """
+    Splits a multi-channel NIfTI file into multiple single-channel NIfTI files and masks the foreground.
+    :param flip: whether to vertically flip the image so the NIfTI file aligns with the template
+    :param mask_index: index of the DAPI-stained channel, which is foreground-masked (default -1: last channel)
+    :param bias: whether to bias-field correct the DAPI channel
+    :return: locations of the single-channel NIfTI files
+    """
+    if out_dir is None:
+        out_dir = nii_path.parent
+    nii = nib.load(str(nii_path))
+    nii_data = nii.get_data()
+    nii_header = nii.header
+
+    if mask_index == -1:
+        mask_index = nii_data.shape[2] - 1
+    paths = []
+
+    for i in range(0, nii_data.shape[2]):
+        out_path = out_dir / f'im_c{i+1}.nii'
+        channel_data = nii_data[:, :, i]
+
+        if flip:
+            channel_data = np.flip(channel_data, 1)
+
+        if i == mask_index:
+            channel_data = mask_foreground(channel_data)
+
+        new_header = nii_header
+        new_header['dim'][0] = 2
+        nii = nib.Nifti1Image(channel_data, np.eye(4), header=new_header)
+        nib.save(nii, str(out_path))
+
+        if i == mask_index and bias:
+            bias_field_correction(str(out_path), str(out_path))
+            corrected = nib.load(str(out_path))
+            corrected_data = corrected.get_data()
+            corrected_normalized = corrected_data / np.mean(corrected_data[corrected_data != 0])
+            nii_corrected = nib.Nifti1Image(corrected_normalized, corrected.affine, corrected.header)
+            nib.save(nii_corrected, str(out_path))
+        paths.append(out_path)
+    return paths
+
+
+def mask_foreground(raw_data):
+    """
+    Masks the foreground of an image, using an Otsu threshold and connected components to remove background noise.
+    """
+    raw_max = raw_data.max()
+    raw_data = raw_data / raw_max
+    blurred_data = ndimage.gaussian_filter(raw_data, 4)
+
+    threshold = filters.threshold_otsu(raw_data) / 2
+    threshold_data = blurred_data > threshold
+    connected_structure = ndimage.generate_binary_structure(2, 2)  # Connects adjacent and diagonal.
+    padded_comp, padded_nr = ndimage.label(threshold_data, structure=connected_structure)
+
+    comps, comps_count = np.unique(padded_comp, return_counts=True)
+    comps_count, comps = zip(*sorted(zip(comps_count, comps), reverse=True))
+
+    two_biggest_cc = ((comps[0], np.average(comps[0])), (comps[1], np.average(comps[1])))
+    biggest_cc = max(two_biggest_cc, key=lambda a: a[1])[0]
+
+    foreground_mask = np.where(padded_comp == biggest_cc, True, False)
+    closed = morphology.binary_closing(foreground_mask, selem=morphology.square(30))
+    raw_data = np.where(closed, raw_data, 0)
+    return raw_data * raw_max
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("file", help="Location of file to process")
+    parser.add_argument("dir", help="Directory for preprocessed files")
+    parser.add_argument("-s", "--series", type=int, help="Series to extract (.nd2 files only)")
+    parser.add_argument('-b', type=bool, default=True, help='Bias field correct image')
+    parser.add_argument('--pdim', type=float, help='Pixel dimensions (retrieved from the TIFF file if not set)')
+    args = parser.parse_args()
+
+    filetype = Path(args.file).suffix.lower()
+    if filetype in ('.tif', '.tiff'):
+        n_path = tiff_to_nii(Path(args.file), Path(args.dir), pixdim=args.pdim)
+    elif filetype == '.nd2':
+        n_path = nd2_to_nii(Path(args.file), Path(args.dir), args.series)
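
mask_foreground above does the heavy lifting for both preprocessing and the masked similarity metrics in tools.py: blur, threshold at half the Otsu value, keep the dominant connected component, then morphologically close it. Below is a self-contained sketch of that idea on synthetic data; it illustrates the approach rather than reproducing the repository's exact code path.

```python
# Synthetic demonstration of the foreground-masking approach used by mask_foreground().
import numpy as np
from scipy import ndimage
from skimage import filters, morphology

rng = np.random.default_rng(0)
image = rng.normal(0.05, 0.02, (200, 200))  # dim background noise
image[60:140, 50:150] += 0.8                # bright "tissue" block

blurred = ndimage.gaussian_filter(image, 4)
threshold = filters.threshold_otsu(image) / 2
mask = blurred > threshold

# Label connected components (8-connectivity) and keep the largest non-background one.
labels, _ = ndimage.label(mask, structure=ndimage.generate_binary_structure(2, 2))
values, counts = np.unique(labels[labels != 0], return_counts=True)
foreground = labels == values[np.argmax(counts)]
foreground = morphology.binary_closing(foreground, morphology.square(30))

print(foreground.sum(), 'foreground pixels')  # roughly the 80 x 100 block
```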

+ 62 - 0
registration.py

@@ -0,0 +1,62 @@
+from nipype.interfaces import ants
+
+
+# Full three-stage registration (Rigid, Affine, SyN)  --- SLOW ---
+def full_registration(fixed, moving, output='full_trans.nii', transform_pre=''):
+    ants.Registration(
+        dimension=2,
+        transforms=['Rigid', 'Affine', 'SyN'],
+        transform_parameters=[(0.15,), (0.15,), (0.3, 3, 0.5)],
+        metric=['CC', 'MI'] * 3,
+        convergence_threshold=[1.e-7],
+        number_of_iterations=[[1000, 1000, 1000, 1000], [1000, 1000, 1000, 1000],
+                              [1000, 1000, 1000, 1000, 1000]],
+        smoothing_sigmas=[[4, 4, 2, 2], [4, 4, 2, 2], [4, 2, 2, 2, 1]],
+        shrink_factors=[[32, 16, 8, 4], [32, 16, 8, 4], [16, 8, 4, 2, 1]],
+        initial_moving_transform_com=1,
+        interpolation='BSpline',
+        metric_weight=[0.5] * 6,
+        radius_or_number_of_bins=[4, 32] * 3,
+        sampling_strategy=['Regular'] * 3,
+        use_histogram_matching=True,
+        winsorize_lower_quantile=0.05,
+        winsorize_upper_quantile=0.95,
+        collapse_output_transforms=True,
+        output_transform_prefix=transform_pre,
+        fixed_image=[fixed],
+        moving_image=[moving],
+        output_warped_image=output,
+        num_threads=12,
+        verbose=False,
+        write_composite_transform=True,
+        output_inverse_warped_image=True
+    ).run()
+    return output
+
+
+# A faster rigid registration
+def rigid_registration(fixed, moving, output='affine_trans.nii', transform_pre=''):
+    ants.Registration(
+        dimension=2,
+        transforms=['Rigid'],
+        transform_parameters=[(0.15,)],
+        metric=['MI'],
+        smoothing_sigmas=[[4, 4, 2]],
+        convergence_threshold=[1.e-6],
+        number_of_iterations=[[1000, 1000, 1000]],
+        shrink_factors=[[32, 16, 8]],
+        initial_moving_transform_com=0,
+        interpolation='BSpline',
+        metric_weight=[1],
+        radius_or_number_of_bins=[32],
+        sampling_strategy=['Regular'],
+        use_histogram_matching=True,
+        winsorize_lower_quantile=0.05,
+        winsorize_upper_quantile=0.95,
+        terminal_output='none',
+        output_transform_prefix=transform_pre,
+        fixed_image=[fixed],
+        moving_image=[moving],
+        output_warped_image=output
+    ).run()
+    return output
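
full_registration above writes its transforms as composite files (write_composite_transform=True), so an output_transform_prefix such as 'Final-' yields 'Final-Composite.h5' and 'Final-InverseComposite.h5'. The sketch below, with placeholder file names, mirrors how apply_transform in auto_seg.py consumes the inverse composite to pull the template segmentation onto the experimental slice.

```python
# Sketch mirroring apply_transform() in auto_seg.py; file names are placeholders.
from nipype.interfaces import ants

import registration

# Register the experimental slice (moving) to the chosen template slice (fixed),
# writing Final-Composite.h5 and Final-InverseComposite.h5 under the given prefix.
registration.full_registration(
    fixed='template_slice.nii',   # template slice chosen by find_optimal_index
    moving='slice_dapi.nii',      # preprocessed DAPI channel of the experimental slice
    output='slice_on_template.nii',
    transform_pre='Final-',
)

# Applying the inverse composite to the template's segmentation slice, with the
# experimental slice as reference, yields a segmentation aligned with the slice.
ants.ApplyTransforms(
    input_image='template_segmentation_slice.nii',
    reference_image='slice_dapi.nii',
    transforms=['Final-InverseComposite.h5'],
    invert_transform_flags=[False],
    interpolation='MultiLabel',   # preserves integer label values
    dimension=2,
    output_image='segmentation_on_slice.nii',
).run()
```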

+ 94 - 0
tools.py

@@ -0,0 +1,94 @@
+import os
+import shutil
+
+import nibabel as nib
+from PIL import Image
+import numpy as np
+from sklearn.metrics import mutual_info_score
+
+from preprocess import mask_foreground
+from pprint import pprint
+
+
+def save_slice(template, filename, index, affine=np.eye(4), header=None):
+    _slice = nib.load(str(template))
+    _slice_data = _slice.get_data()
+
+    newimg = nib.Nifti1Image(_slice_data[:, index, :], affine, header)
+    if header is not None:
+        newimg.header['pixdim'] = header['pixdim']
+    else:
+        newimg.header['pixdim'][1:3] = _slice.header['pixdim'][1], _slice.header['pixdim'][3]
+    newimg.to_filename(str(filename))
+    return filename
+
+
+# Functions for managing the temporary directory
+def create_dir(name):
+    if not os.path.exists(name):
+        os.mkdir(name)
+
+
+def remove_dir(loc):
+    shutil.rmtree(loc)
+
+
+# Functions for calculating mutual information and overlap of two slices
+def remove_nan(input_array):
+    output_list = input_array.ravel()
+    output_list[np.isnan(output_list)] = 0
+    return output_list
+
+
+def mutual_info(slice1, slice2, bins=32):
+    slice1, slice2 = remove_nan(slice1), remove_nan(slice2)
+    hist = np.histogram2d(slice1, slice2, bins=bins)[0]
+    return mutual_info_score(None, None, contingency=hist)
+
+
+def mutual_info_mask(slice1, slice2, bins=32):
+    slice1_masked = mask_foreground(slice1)
+    slice2_masked = mask_foreground(slice2)
+    common = np.logical_and(slice1_masked, slice2_masked)
+    slice1, slice2 = np.where(common, slice1, 0), np.where(common, slice2, 0)
+    hist = np.histogram2d(slice1.ravel(), slice2.ravel(), bins=bins)[0]
+    return mutual_info_score(None, None, contingency=hist)
+
+
+def dice_coef(slice1, slice2):
+    mask_1 = mask_foreground(slice1).astype(bool)  # np.bool is removed in newer NumPy
+    mask_2 = mask_foreground(slice2).astype(bool)
+    intersection = np.logical_and(mask_1, mask_2)
+    return 2. * intersection.sum() / (mask_1.sum() + mask_2.sum())
+
+
+def resize_im(image_data, image_with_dimensions):
+    dimensions = (image_with_dimensions.shape[1], image_with_dimensions.shape[0])
+    resized = np.array(Image.fromarray(image_data).resize(dimensions))
+    return resized
+
+
+def slice_num_to_index(slice_num):
+    index_dict = {
+        1: 175,
+        2: 154,
+        3: 123,
+        4: 112,
+        5: 93,
+        6: 75,
+        7: 31,
+        8: 7,
+        9: 119,
+        10: 100,
+        11: 83,
+        12: 71
+    }
+    return index_dict[slice_num]
+
+
+def bregma_to_slice_index(bregma):
+    return round(27.908*bregma + 116.831)
+
+
+def slice_index_to_bregma(slice_index):
+    return round(0.03564*slice_index - 4.168)
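
The two bregma helpers above are approximately inverse linear maps between an anterior-posterior bregma coordinate (in millimetres) and a template slice index; note that slice_index_to_bregma rounds to a whole number. A quick round-trip check with an illustrative value:

```python
# Round-trip sanity check of the bregma <-> slice-index helpers (illustrative value).
from tools import bregma_to_slice_index, slice_index_to_bregma

index = bregma_to_slice_index(-2.0)   # round(27.908 * -2.0 + 116.831) = round(61.015) = 61
print(index)                          # 61
print(slice_index_to_bregma(index))   # round(0.03564 * 61 - 4.168) = round(-1.994) = -2
```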