preprocess.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. from pathlib import Path
  2. import argparse
  3. import numpy as np
  4. import nibabel as nib
  5. from PIL import Image, ImageSequence
  6. from scipy import ndimage
  7. from skimage import morphology, filters
  8. from nipype.interfaces.ants.segmentation import N4BiasFieldCorrection
  9. import warnings
  10. warnings.simplefilter(action='ignore', category=FutureWarning)
  11. def get_out_paths(out_dir, stem):
  12. out_path_scaled = Path(out_dir, f'{stem}.nii')
  13. out_path_unscaled = Path(out_dir, f'{stem}_us.nii')
  14. return out_path_scaled, out_path_unscaled
  15. def bias_field_correction(image, output, dim=2):
  16. print('Running N4BiasFieldCorrection on input slice.')
  17. N4BiasFieldCorrection(
  18. input_image=image,
  19. output_image=output,
  20. dimension=dim
  21. ).run()
  22. def image_data_to_nii(pixdim, image_data, shrink, out_dir, file_path, save_unscaled=False):
  23. image_dim = (image_data.shape[0], image_data.shape[1])
  24. scale = 1/shrink
  25. new_dim = (round(image_dim[1] * scale), round(image_dim[0] * scale))
  26. new_arr = np.ndarray(new_dim + (image_data.shape[2],))
  27. print(f'Pixel dimensions: {pixdim} mm')
  28. for i in range(0,image_data.shape[2]):
  29. cur_channel = image_data[:, :, i]
  30. resized = np.array(Image.fromarray(cur_channel).resize(new_dim)).transpose()
  31. new_arr[:, :, i] = resized
  32. path_scaled, path_unscaled = get_out_paths(out_dir, Path(file_path).stem)
  33. nii_scaled = nib.Nifti1Image(new_arr, np.eye(4))
  34. nii_scaled.header['pixdim'][1:3] = pixdim * shrink, pixdim * shrink
  35. nib.save(nii_scaled, str(path_scaled))
  36. if save_unscaled:
  37. nii_unscaled = nib.Nifti1Image(image_data, np.eye(4))
  38. nii_unscaled.header['xyzt_units'] = 3
  39. nii_unscaled.header['pixdim'][1:3] = pixdim, pixdim
  40. nib.save(nii_unscaled, str(path_unscaled))
  41. print(f'Preprocessed: {path_scaled}\n')
  42. return path_scaled, path_unscaled
  43. def tiff_to_nii(tif_path, out_dir, pixdim=None, shrink=10):
  44. Image.MAX_IMAGE_PIXELS = None
  45. tif_image = Image.open(tif_path)
  46. tif_header = dict(tif_image.tag)
  47. output = np.empty(np.array(tif_image).shape + (0,))
  48. if not pixdim:
  49. pixdim = 10e2/tif_header[282][0][0]
  50. for i, page in enumerate(ImageSequence.Iterator(tif_image)):
  51. page_data = np.expand_dims(np.array(page), 2)
  52. output = np.concatenate((output, page_data), 2)
  53. return image_data_to_nii(pixdim, output, shrink, out_dir, tif_path)
  54. def split_nii_channels(nii_path, out_dir=None, flip=False, mask_index=-1, bias=True):
  55. if out_dir is None:
  56. out_dir = nii_path.parent
  57. nii = nib.load(str(nii_path))
  58. nii_data = nii.get_fdata()
  59. nii_header = nii.header
  60. if mask_index == -1:
  61. mask_index = nii_data.shape[2] - 1
  62. paths = []
  63. for i in range(0, nii_data.shape[2]):
  64. out_path = out_dir / f'im_c{i+1}.nii'
  65. channel_data = nii_data[:, :, i]
  66. if flip:
  67. channel_data = np.flip(channel_data, 1)
  68. if i == mask_index:
  69. channel_data = mask_foreground(channel_data)
  70. new_header = nii_header
  71. new_header['dim'][0] = 2
  72. nii = nib.Nifti1Image(channel_data, np.eye(4), header=new_header)
  73. nib.save(nii, str(out_path))
  74. if i == mask_index and bias:
  75. bias_field_correction(str(out_path), str(out_path))
  76. corrected = nib.load(str(out_path))
  77. corrected_data = corrected.get_fdata()
  78. corrected_normalized = corrected_data / np.mean(corrected_data[corrected_data != 0])
  79. nii_corrected = nib.Nifti1Image(corrected_normalized, corrected.affine, corrected.header)
  80. nib.save(nii_corrected, str(out_path))
  81. paths.append(out_path)
  82. return paths
  83. def mask_foreground(raw_data):
  84. raw_max = raw_data.max()
  85. raw_data = raw_data / raw_max
  86. blurred_data = ndimage.gaussian_filter(raw_data, 4)
  87. threshold = filters.threshold_otsu(raw_data) / 2
  88. threshold_data = blurred_data > threshold
  89. connected_structure = ndimage.generate_binary_structure(2, 2) # Connects adjacent and diagonal.
  90. padded_comp, padded_nr = ndimage.label(threshold_data, structure=connected_structure)
  91. comps, comps_count = np.unique(padded_comp, return_counts=True)
  92. comps_count, comps = zip(*sorted(zip(comps_count, comps), reverse=True))
  93. two_biggest_cc = ((comps[0], np.average(comps[0])), (comps[1], np.average(comps[1])))
  94. biggest_cc = max(two_biggest_cc, key=lambda a: a[1])[0]
  95. foreground_mask = np.where(padded_comp == biggest_cc, True, False)
  96. closed = morphology.binary_closing(foreground_mask, selem=morphology.square(30))
  97. raw_data = np.where(closed, raw_data, 0)
  98. return raw_data * raw_max
  99. if __name__ == '__main__':
  100. parser = argparse.ArgumentParser()
  101. parser.add_argument("file", help="Location of file to process")
  102. parser.add_argument("dir", help="Directory for preprocessed Files")
  103. parser.add_argument("-s", "--series", type=int, help="Series to extract")
  104. parser.add_argument('-b', type=bool, default= True, help='Bias field correct image')
  105. parser.add_argument('--pdim', type=float, help='Pixel dimensions (Retrieved from Tiff file if not set)')
  106. args = parser.parse_args()
  107. filetype = Path(args.file).suffix.lower()
  108. if filetype == '.tif' or filetype == '.tiff':
  109. n_path = tiff_to_nii(Path(args.file), Path(args.dir))
  110. elif filetype == '.nd2':
  111. n_path = nd2_to_nii(Path(args.file), Path(args.dir), args.series)
  112. elif filetype == '.nii':
  113. n_path = Path(args.file)