import os
import shutil
from pprint import pprint

import nibabel as nib
import numpy as np
from PIL import Image
from sklearn.metrics import mutual_info_score

from preprocess import mask_foreground


# Lookup table used by slice_num_to_index (built once, not on every call).
_SLICE_NUM_TO_INDEX = {
    1: 175, 2: 154, 3: 123, 4: 112, 5: 93, 6: 75,
    7: 31, 8: 7, 9: 119, 10: 100, 11: 83, 12: 71,
}


def save_slice(template, filename, index, affine=np.eye(4), header=None):
    """Extract one slice along axis 1 of a NIfTI volume and save it to disk.

    Parameters
    ----------
    template : str or path-like
        Path of the NIfTI volume to slice.
    filename : str or path-like
        Destination path for the extracted 2-D image.
    index : int
        Position along the volume's second axis (axis 1) to extract.
    affine : ndarray, optional
        Affine passed to the new ``Nifti1Image`` (identity by default).
    header : nibabel header, optional
        Header passed to the new image. When given, its ``pixdim`` is copied
        onto the new image. When omitted, the new image's ``pixdim[1:3]`` is
        filled from the template's ``pixdim[1]`` and ``pixdim[3]`` -- the
        entries for the two axes (0 and 2) that remain after slicing.

    Returns
    -------
    The ``filename`` argument, unchanged.
    """
    volume = nib.load(str(template))
    # get_data() is deprecated and removed in nibabel 5; dataobj gives the
    # same (scaled, native-dtype) array.
    volume_data = np.asanyarray(volume.dataobj)
    newimg = nib.Nifti1Image(volume_data[:, index, :], affine, header)
    # Building the image from an affine resets the voxel sizes, so restore
    # them explicitly. Previously header['pixdim'] was read before the None
    # check, which crashed whenever header was omitted.
    if header is None:
        newimg.header['pixdim'][1:3] = (volume.header['pixdim'][1],
                                        volume.header['pixdim'][3])
    else:
        newimg.header['pixdim'] = header['pixdim']
    newimg.to_filename(str(filename))
    return filename


# Functions for temporary directory

def create_dir(name):
    """Create the directory ``name`` unless something already exists there.

    Parent directories are not created (plain ``os.mkdir`` semantics).
    """
    try:
        os.mkdir(name)
    except FileExistsError:
        # Already present (possibly created concurrently) -- nothing to do.
        pass


def remove_dir(loc):
    """Recursively delete the directory tree at ``loc``."""
    shutil.rmtree(loc)


# Functions for calculating mutual info of two slices

def remove_nan(input_array):
    """Return a flattened copy of ``input_array`` with NaN entries set to 0.

    ``flatten`` always copies (unlike ``ravel``, which may return a view),
    so the caller's array is never modified.
    """
    output_list = np.asanyarray(input_array).flatten()
    output_list[np.isnan(output_list)] = 0
    return output_list


def mutual_info(slice1, slice2, bins=32):
    """Mutual information of two images from their joint intensity histogram.

    NaNs in either input are treated as 0. ``bins`` is forwarded to
    ``np.histogram2d``; the resulting contingency table is scored with
    ``sklearn.metrics.mutual_info_score``.
    """
    slice1, slice2 = remove_nan(slice1), remove_nan(slice2)
    hist = np.histogram2d(slice1, slice2, bins=bins)[0]
    return mutual_info_score(None, None, contingency=hist)


def mutual_info_mask(slice1, slice2, bins=32):
    """Mutual information restricted to the common foreground of two images.

    Pixels outside the intersection of both foreground masks are set to 0
    before the joint histogram is built. Assumes ``mask_foreground`` returns
    a boolean-like mask broadcastable to the input -- TODO confirm.
    """
    slice1_masked = mask_foreground(slice1)
    slice2_masked = mask_foreground(slice2)
    common = np.logical_and(slice1_masked, slice2_masked)
    slice1, slice2 = np.where(common, slice1, 0), np.where(common, slice2, 0)
    hist = np.histogram2d(slice1.ravel(), slice2.ravel(), bins=bins)[0]
    return mutual_info_score(None, None, contingency=hist)


def dice_coef(slice1, slice2):
    """Dice overlap, 2|A n B| / (|A| + |B|), of the two foreground masks.

    If both masks are empty the ratio is 0/0 and the result is NaN.
    """
    # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype.
    mask_1 = mask_foreground(slice1).astype(bool)
    mask_2 = mask_foreground(slice2).astype(bool)
    intersection = np.logical_and(mask_1, mask_2)
    return 2. * intersection.sum() / (mask_1.sum() + mask_2.sum())


def resize_im(image_data, image_with_dimensions):
    """Resize ``image_data`` to the (rows, cols) shape of ``image_with_dimensions``.

    PIL's ``resize`` expects (width, height), hence the swapped shape order.
    """
    dimensions = (image_with_dimensions.shape[1], image_with_dimensions.shape[0])
    resized = np.array(Image.fromarray(image_data).resize(dimensions))
    return resized


def slice_num_to_index(slice_num):
    """Map a slice number (1-12) to its fixed slice index.

    Raises ``KeyError`` for slice numbers outside the table.
    """
    return _SLICE_NUM_TO_INDEX[slice_num]


def bregma_to_slice_index(bregma):
    """Convert a bregma coordinate to the nearest integer slice index (linear fit)."""
    return round(27.908*bregma + 116.831)


def slice_index_to_bregma(slice_index, ndigits=None):
    """Convert a slice index to a bregma coordinate (linear fit).

    With the default ``ndigits=None`` the result is rounded to a whole
    number (original behaviour); pass e.g. ``ndigits=3`` to keep the
    fractional part. Note the coefficients are not the exact algebraic
    inverse of ``bregma_to_slice_index`` (that would be roughly
    0.035832*i - 4.1863).
    """
    return round(0.03564*slice_index - 4.168, ndigits)