tools.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import os
  2. import shutil
  3. import nibabel as nib
  4. from PIL import Image
  5. import numpy as np
  6. from sklearn.metrics import mutual_info_score
  7. from preprocess import mask_foreground
  8. from pprint import pprint
  def save_slice(template, filename, index, affine=None, header=None):
      """Save the 2-D slice ``data[:, index, :]`` of a NIfTI volume as a new NIfTI file.

      Args:
          template: Path (anything ``str()``-able) of the source volume, loaded
              with ``nibabel.load``.
          filename: Output path; the new image is written with ``to_filename``.
          index: Position along the volume's second axis (axis 1) to extract.
          affine: Affine passed to ``nib.Nifti1Image``. Defaults to a fresh
              4x4 identity matrix.
          header: Optional header passed to ``nib.Nifti1Image``. When given, its
              ``pixdim`` is copied onto the new header. When omitted, the new
              header's ``pixdim[1:3]`` is set from the template's ``pixdim[1]``
              and ``pixdim[3]``, the spacings of the two axes kept after slicing
              axis 1.

      Returns:
          ``filename``, unchanged.
      """
      if affine is None:
          # Build a new matrix per call instead of sharing one default array.
          affine = np.eye(4)
      volume = nib.load(str(template))
      # ``get_data()`` is deprecated and removed in nibabel 5;
      # ``np.asanyarray(img.dataobj)`` is its documented drop-in replacement.
      volume_data = np.asanyarray(volume.dataobj)
      newimg = nib.Nifti1Image(volume_data[:, index, :], affine, header)
      if header is not None:
          newimg.header['pixdim'] = header['pixdim']
      else:
          # No header supplied (the default), so it cannot be indexed; take the
          # in-plane voxel sizes from the template volume instead.
          newimg.header['pixdim'][1:3] = (
              volume.header['pixdim'][1],
              volume.header['pixdim'][3],
          )
      newimg.to_filename(str(filename))
      return filename
  18. # Functions for temporary directory
  def create_dir(name):
      """Create the directory *name* unless something already exists at that path.

      Only one level is created (``os.mkdir``), so the parent must already exist.
      Returns ``None``.
      """
      if os.path.exists(name):
          return
      os.mkdir(name)
  def remove_dir(loc):
      """Recursively delete the directory tree rooted at *loc*.

      Any error raised by ``shutil.rmtree`` (e.g. for a missing path) propagates
      to the caller. Returns ``None``.
      """
      shutil.rmtree(path=loc)


  # Functions for calculating mutual info of two slices
  def remove_nan(input_array):
      """Return a flattened copy of *input_array* with every NaN replaced by 0.

      The input is never modified. Copying first matters: ``ndarray.ravel``
      returns a view whenever it can, so zeroing NaNs in that view would write
      through to the caller's array (e.g. the slices handed to ``mutual_info``).

      Args:
          input_array: Array-like of numeric values; plain lists are accepted.

      Returns:
          1-D ``numpy.ndarray`` with the dtype of ``np.array(input_array)``.
      """
      flat = np.array(input_array).ravel()  # np.array copies, so edits stay local
      flat[np.isnan(flat)] = 0
      return flat
  def mutual_info(slice1, slice2, bins=32):
      """Mutual information between the pixel values of two slices.

      Both slices are flattened with NaNs zeroed (``remove_nan``). A
      ``bins``-by-``bins`` joint histogram of the paired values is then used as
      the contingency table for ``sklearn.metrics.mutual_info_score``.
      """
      flat1 = remove_nan(slice1)
      flat2 = remove_nan(slice2)
      joint_counts, _, _ = np.histogram2d(flat1, flat2, bins=bins)
      return mutual_info_score(None, None, contingency=joint_counts)
  def mutual_info_mask(slice1, slice2, bins=32):
      """Mutual information of two slices restricted to their shared foreground.

      Pixels that are not foreground in *both* slices (per ``mask_foreground``)
      are set to 0 in both. They still enter the joint histogram, as one mass
      at (0, 0), rather than being dropped.

      NaNs are zeroed with ``remove_nan`` before histogramming, matching
      ``mutual_info``. ``np.histogram2d`` raises ``ValueError`` when its
      auto-detected range is not finite. Without this step, any NaN left
      inside the common foreground would abort the call. Whether that can
      happen depends on how ``mask_foreground`` treats NaN pixels.

      NOTE(review): keeping the zeroed background in the histogram biases the
      score toward the (0, 0) bin; confirm that is intended rather than
      histogramming only ``slice[common]``.
      """
      common = np.logical_and(mask_foreground(slice1), mask_foreground(slice2))
      masked1 = np.where(common, slice1, 0)
      masked2 = np.where(common, slice2, 0)
      hist = np.histogram2d(remove_nan(masked1), remove_nan(masked2), bins=bins)[0]
      return mutual_info_score(None, None, contingency=hist)
  def dice_coef(slice1, slice2):
      """Dice coefficient ``2|A & B| / (|A| + |B|)`` of the foreground masks of two slices.

      Masks come from ``mask_foreground`` and are cast with the builtin
      ``bool``. The ``np.bool`` alias was removed in NumPy 1.24 and now raises
      ``AttributeError``.

      Returns:
          A float in [0, 1]. If neither slice has any foreground pixel the ratio
          is undefined, and ``nan`` is returned explicitly rather than through a
          0/0 division that also emits a RuntimeWarning.
      """
      mask_1 = mask_foreground(slice1).astype(bool)
      mask_2 = mask_foreground(slice2).astype(bool)
      total = mask_1.sum() + mask_2.sum()
      if total == 0:
          return float('nan')
      intersection = np.logical_and(mask_1, mask_2)
      return 2. * intersection.sum() / total
  def resize_im(image_data, image_with_dimensions):
      """Resample *image_data* to the height and width of *image_with_dimensions*.

      NumPy shapes are (rows, cols) but PIL sizes are (width, height), hence the
      swap. Resampling uses PIL's default filter for ``Image.resize``.

      NOTE(review): ``Image.fromarray`` accepts only some dtypes (float64, for
      example, is rejected); confirm what callers pass in.
      """
      rows = image_with_dimensions.shape[0]
      cols = image_with_dimensions.shape[1]
      pil_image = Image.fromarray(image_data)
      return np.array(pil_image.resize((cols, rows)))
  # Hard-coded table used by slice_num_to_index: slice number -> index.
  _SLICE_NUM_TO_INDEX = {
      1: 175,
      2: 154,
      3: 123,
      4: 112,
      5: 93,
      6: 75,
      7: 31,
      8: 7,
      9: 119,
      10: 100,
      11: 83,
      12: 71,
  }


  def slice_num_to_index(slice_num):
      """Return the hard-coded index associated with slice number *slice_num* (1-12).

      Raises:
          KeyError: if *slice_num* is not one of the tabulated numbers.
      """
      return _SLICE_NUM_TO_INDEX[slice_num]
  def bregma_to_slice_index(bregma):
      """Convert a bregma coordinate to the nearest slice index.

      Applies the fixed linear map ``index = 27.908 * bregma + 116.831`` and
      rounds with the builtin ``round`` (ties to even), so a float input yields
      an ``int``.
      """
      slope = 27.908
      intercept = 116.831
      return round(slope * bregma + intercept)
  def slice_index_to_bregma(slice_index, ndigits=None):
      """Convert a slice index to a bregma coordinate via ``0.03564 * index - 4.168``.

      One index step changes the result by only ~0.036. The default
      ``ndigits=None`` rounds to an integer, which keeps the historical behaviour
      but maps many neighbouring slices to the same value. Pass ``ndigits`` (for
      example ``2``) to keep fractional coordinates.

      These coefficients are not the exact algebraic inverse of
      ``bregma_to_slice_index`` (1 / 27.908 is about 0.03583), so a round trip
      need not return the starting value.

      Args:
          slice_index: Slice index to convert.
          ndigits: Passed to ``round``. ``None`` returns an ``int``; an integer
              returns a float rounded to that many decimals.
      """
      bregma = 0.03564 * slice_index - 4.168
      return round(bregma, ndigits)