pixelwise.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import numpy as np
  2. from view.idl_translation_core.bleach_correction import fitlogdecay, model_func
  3. from itertools import product
  4. import multiprocessing as mp
  5. import logging
  6. import platform
  7. shared_arr_g = None
  8. patch_size_g = None
  9. def _init(shared_arr_, patch_size_, weights_, fit_pars_dict_):
  10. # The shared array pointer is a global variable so that it can be accessed by the
  11. # child processes. It is a tuple (pointer, dtype, shape).
  12. global shared_arr, patch_size, weights, fit_pars_dict
  13. shared_arr = shared_arr_
  14. patch_size = patch_size_
  15. weights = weights_
  16. fit_pars_dict = fit_pars_dict_
  17. def shared_to_numpy(shared_array_pars, copy=False):
  18. """Get a NumPy array from a shared memory buffer, with a given dtype and shape.
  19. No copy is involved, the array reflects the underlying shared buffer."""
  20. shared_arr, dtype, shape = shared_array_pars
  21. wrapped_arr = np.frombuffer(shared_arr, dtype=dtype).reshape(shape)
  22. if copy:
  23. copy_arr = np.empty_like(wrapped_arr)
  24. np.copyto(dst=copy_arr, src=wrapped_arr)
  25. return copy_arr
  26. else:
  27. return wrapped_arr
  28. def numpy2raw_array(ndarray):
  29. """
  30. Convert numpy array to raw array
  31. """
  32. dtype = ndarray.dtype
  33. shape = ndarray.shape
  34. # Get a ctype type from the NumPy dtype.
  35. cdtype = np.ctypeslib.as_ctypes_type(dtype)
  36. # Create the RawArray instance.
  37. shared_arr = mp.RawArray(cdtype, int(np.prod(shape)))
  38. # Wrap shared_arr as an numpy array so we can easily manipulates its data (here only to copy data into it)
  39. fake_numpy_arr = shared_to_numpy((shared_arr, dtype, shape))
  40. np.copyto(dst=fake_numpy_arr, src=ndarray)
  41. return shared_arr, dtype, shape
  42. def bleach_correct_pixelwise(movie: np.ndarray, weights, area, ncpu: int):
  43. assert movie.shape[:2] == area.shape, f"Area file specified has dimensions {area.shape} that does not match with" \
  44. f"data dimensions {movie.shape}"
  45. pixel_inds = [ind for ind, val in np.ndenumerate(area) if val]
  46. global shared_arr_g, weights_g
  47. # mmappickle could be used instead of shared memory. https://mmappickle.readthedocs.io/en/latest/
  48. shared_arr_g = numpy2raw_array(movie)
  49. weights_g = weights
  50. if ncpu > 1:
  51. assert platform.system() != "Windows", \
  52. "Pixelwise bleach correction currently does not work on Windows due to parallization issues. Sorry!"
  53. # apply bleach correction to each patch in parallel
  54. with mp.Pool(processes=ncpu) as p: # use all cores
  55. op_params_list = p.map(bleach_correct_pixelwise_worker, pixel_inds, chunksize=100)
  56. elif ncpu == 1:
  57. # apply bleach correction to each patch without parallelization
  58. op_params_list = []
  59. for pixel_ind_nr, pixel_ind in enumerate(pixel_inds):
  60. logging.getLogger("VIEW").debug(f"Doing pixel {pixel_ind_nr + 1}/{len(pixel_inds)}")
  61. op_params = bleach_correct_pixelwise_worker(pixel_ind)
  62. op_params_list.append(op_params)
  63. else:
  64. raise ValueError(f"Paramater ncpu has to be 1 or more ({ncpu} specified)")
  65. array2return = shared_to_numpy(shared_arr_g, copy=True)
  66. return array2return, {k: v for k, v in zip(pixel_inds, op_params_list)}
  67. def bleach_correct_pixelwise_worker(pixel_index: tuple):
  68. movie = shared_to_numpy(shared_arr_g)
  69. # reduce patch to curve
  70. curve = movie[pixel_index[0], pixel_index[1], :]
  71. # apply bleach correction to curve and return the parameters A, K and C
  72. fitted_curve, (A, K, C) = fitlogdecay(lineIn=curve, weights=weights_g, showresults=False)
  73. # sometimes A and/or K can be NAN, then don't bleach correct
  74. # adding the mean of the fitted curve ensures the average intensity value of every pixel
  75. # is not affected by the bleach correction applied
  76. if not np.isnan(A) and not np.isnan(K):
  77. movie[pixel_index[0], pixel_index[1], :] = curve - fitted_curve + fitted_curve.mean()
  78. return A, K, C