preprocess.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. from pathlib import Path
  2. import argparse
  3. import numpy as np
  4. import nibabel as nib
  5. from PIL import Image, ImageSequence
  6. from scipy import ndimage
  7. from skimage import morphology, filters
  8. from nipype.interfaces.ants.segmentation import N4BiasFieldCorrection
  9. import warnings
  10. warnings.simplefilter(action='ignore', category=FutureWarning)
  11. def get_out_paths(out_dir, stem):
  12. out_path_scaled = Path(out_dir, f'{stem}.nii')
  13. out_path_unscaled = Path(out_dir, f'{stem}_us.nii')
  14. return out_path_scaled, out_path_unscaled
  15. def bias_field_correction(image, output, dim=2):
  16. N4BiasFieldCorrection(
  17. input_image=image,
  18. output_image=output,
  19. dimension=dim
  20. ).run()
  21. def image_data_to_nii(pixdim, image_data, shrink, out_dir, file_path, save_unscaled=False):
  22. image_dim = (image_data.shape[0], image_data.shape[1])
  23. scale = 1/shrink
  24. new_dim = (round(image_dim[1] * scale), round(image_dim[0] * scale))
  25. new_arr = np.ndarray(new_dim + (image_data.shape[2],))
  26. for i in range(0,image_data.shape[2]):
  27. cur_channel = image_data[:, :, i]
  28. resized = np.array(Image.fromarray(cur_channel).resize(new_dim)).transpose()
  29. new_arr[:, :, i] = resized
  30. path_scaled, path_unscaled = get_out_paths(out_dir, Path(file_path).stem)
  31. nii_scaled = nib.Nifti1Image(new_arr, np.eye(4))
  32. nii_scaled.header['xyzt_units'] = 3
  33. nii_scaled.header['pixdim'][1:3] = pixdim * shrink, pixdim * shrink
  34. nib.save(nii_scaled, str(path_scaled))
  35. if save_unscaled:
  36. nii_unscaled = nib.Nifti1Image(image_data, np.eye(4))
  37. nii_unscaled.header['xyzt_units'] = 3
  38. nii_unscaled.header['pixdim'][1:3] = pixdim, pixdim
  39. nib.save(nii_unscaled, str(path_unscaled))
  40. print(f'Preprocessed: {path_scaled}\n')
  41. return path_scaled, path_unscaled
  42. def tiff_to_nii(tif_path, out_dir, pixdim=None, shrink=10):
  43. Image.MAX_IMAGE_PIXELS = None
  44. tif_image = Image.open(tif_path)
  45. tif_header = dict(tif_image.tag)
  46. output = np.empty(np.array(tif_image).shape + (0,))
  47. if not pixdim:
  48. pixdim = 10e6/tif_header[282][0][0]
  49. for i, page in enumerate(ImageSequence.Iterator(tif_image)):
  50. page_data = np.expand_dims(np.array(page), 2)
  51. output = np.concatenate((output, page_data), 2)
  52. return image_data_to_nii(pixdim, output, shrink, out_dir, tif_path)
  53. def split_nii_channels(nii_path, out_dir=None, flip=False, mask_index=-1, bias=False):
  54. if out_dir is None:
  55. out_dir = nii_path.parent
  56. nii = nib.load(str(nii_path))
  57. nii_data = nii.get_fdata()
  58. nii_header = nii.header
  59. if mask_index == -1:
  60. mask_index = nii_data.shape[2] - 1
  61. paths = []
  62. for i in range(0, nii_data.shape[2]):
  63. out_path = out_dir / f'im_c{i+1}.nii'
  64. channel_data = nii_data[:, :, i]
  65. if flip:
  66. channel_data = np.flip(channel_data, 1)
  67. if i == mask_index:
  68. channel_data = mask_foreground(channel_data)
  69. new_header = nii_header
  70. new_header['dim'][0] = 2
  71. nii = nib.Nifti1Image(channel_data, np.eye(4), header=new_header)
  72. nib.save(nii, str(out_path))
  73. if i == mask_index and bias:
  74. bias_field_correction(str(out_path), str(out_path))
  75. corrected = nib.load(str(out_path))
  76. corrected_data = corrected.get_fdata()
  77. corrected_normalized = corrected_data / np.mean(corrected_data[corrected_data != 0])
  78. nii_corrected = nib.Nifti1Image(corrected_normalized, corrected.affine, corrected.header)
  79. nib.save(nii_corrected, str(out_path))
  80. paths.append(out_path)
  81. return paths
  82. def mask_foreground(raw_data):
  83. raw_max = raw_data.max()
  84. raw_data = raw_data / raw_max
  85. blurred_data = ndimage.gaussian_filter(raw_data, 4)
  86. threshold = filters.threshold_otsu(raw_data) / 2
  87. threshold_data = blurred_data > threshold
  88. connected_structure = ndimage.generate_binary_structure(2, 2) # Connects adjacent and diagonal.
  89. padded_comp, padded_nr = ndimage.label(threshold_data, structure=connected_structure)
  90. comps, comps_count = np.unique(padded_comp, return_counts=True)
  91. comps_count, comps = zip(*sorted(zip(comps_count, comps), reverse=True))
  92. two_biggest_cc = ((comps[0], np.average(comps[0])), (comps[1], np.average(comps[1])))
  93. biggest_cc = max(two_biggest_cc, key=lambda a: a[1])[0]
  94. foreground_mask = np.where(padded_comp == biggest_cc, True, False)
  95. closed = morphology.binary_closing(foreground_mask, selem=morphology.square(30))
  96. raw_data = np.where(closed, raw_data, 0)
  97. return raw_data * raw_max
  98. if __name__ == '__main__':
  99. parser = argparse.ArgumentParser()
  100. parser.add_argument("file", help="Location of file to process")
  101. parser.add_argument("dir", help="Directory for preprocessed Files")
  102. parser.add_argument("-s", "--series", type=int, help="Series to extract")
  103. parser.add_argument('-b', type=bool, default= True, help='Bias field correct image')
  104. parser.add_argument('--pdim', type=float, help='Pixel dimensions (Retrieved from Tiff file if not set)')
  105. args = parser.parse_args()
  106. filetype = Path(args.file).suffix.lower()
  107. if filetype == '.tif' or filetype == '.tiff':
  108. n_path = tiff_to_nii(Path(args.file), Path(args.dir))
  109. elif filetype == '.nd2':
  110. n_path = nd2_to_nii(Path(args.file), Path(args.dir), args.series)
  111. elif filetype == '.nii':
  112. n_path = Path(args.file)