transOnce.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from SWCTransforms import SWCTranslate, ArgGenIterator, objFun
  2. import multiprocessing as mp
  3. import numpy as np
  4. import json
  5. import sys
  6. from itertools import product
  7. from regmaxsn.core.transforms import compose_matrix
  # Script setup: read the run parameters from a JSON file, build one
  # SWCTranslate object per grid size and start the worker pool.

  # Set to True to print diagnostics for every search level.
  debugging = False
  # debugging = True

  # sys.argv[0] is the script name, so exactly one user-supplied argument
  # (the JSON parameter file) means len(sys.argv) == 2.
  assert len(sys.argv) == 2, \
      'Only one argument, the path of the JSON parameter file, expected, ' \
      + str(len(sys.argv) - 1) + ' found'
  parFile = sys.argv[1]

  with open(parFile, 'r') as fle:
      pars = json.load(fle)

  # The parameter file holds a 7-element sequence in exactly this order.
  refSWC, SWC2Align, outFiles, gridSizes, bounds, minRes, nCPU = pars

  # One SWCTranslate instance per entry of gridSizes, indexed like gridSizes.
  SWCDatas = [SWCTranslate(refSWC, SWC2Align, x) for x in gridSizes]

  # Worker processes used to evaluate objFun over the candidate translations.
  pool = mp.Pool(processes=nCPU)

  # Start the search from the zero translation.
  bestSol = [0, 0, 0]
  # Grid search over translations: one pass per entry of gridSizes, each pass
  # confined to a window around the best solution of the previous pass.
  # NOTE(review): gridSizes is presumably ordered coarse to fine - confirm.
  for gridInd, gridSize in enumerate(gridSizes):

      # Re-express the per-axis [lower, upper] bounds as offsets from bestSol.
      bounds = (np.array(bounds).T - np.array(bestSol)).T

      # Grow every offset away from zero to the next multiple of gridSize so
      # the grid covers the whole requested window.
      boundsRoundedUp = np.sign(bounds) * np.ceil(np.abs(bounds) / gridSize) * gridSize

      # Candidate coordinates along each axis, then their Cartesian product.
      possiblePts1D = [
          (bestSol[ind] + np.arange(x[0], x[1] + gridSize, gridSize)).tolist()
          for ind, x in enumerate(boundsRoundedUp)
      ]
      possiblePts3D = list(product(*possiblePts1D))

      if debugging:
          print('Gridsize:' + str(gridSize))
          print(bounds)
          # A list, not map(): map() is a lazy iterator on Python 3 and would
          # print as "<map object ...>" instead of the lengths.
          print([len(pts) for pts in possiblePts1D])
          print([bestSol[ind] + x for ind, x in enumerate(boundsRoundedUp)])

      # Evaluate objFun on every candidate in the worker pool; .get(1800)
      # gives up (multiprocessing.TimeoutError) after 1800 seconds.
      argGen = ArgGenIterator(possiblePts3D, SWCDatas[gridInd])
      funcVals = pool.map_async(objFun, argGen).get(1800)

      # All candidates that attain the minimum objective value.
      minimum = min(funcVals)
      minimzers = [y for x, y in enumerate(possiblePts3D) if funcVals[x] == minimum]

      if not gridInd:
          # First level: break ties by taking the smallest translation.
          distFrom0 = np.linalg.norm(minimzers, axis=1)
          bestSol = minimzers[np.argmin(distFrom0)]
      else:
          # Later levels: break ties with the objective of the previous level.
          prevVals = [objFun((x, SWCDatas[gridInd - 1])) for x in minimzers]
          bestSol = minimzers[np.argmin(prevVals)]

      # Next search window: +/- one grid step around the new best solution.
      # Built as a real list because np.array() cannot consume the iterator
      # that map() returns on Python 3.
      bounds = [[x - gridSize, x + gridSize] for x in bestSol]

      if debugging:
          bestVal = objFun((bestSol, SWCDatas[gridInd]))
          print(bestSol, bestVal)
  # Optional last pass: when the requested resolution minRes is smaller than
  # the last grid size, search once more with step minRes around bestSol,
  # still using the SWCTranslate object of the last grid size.
  if minRes < gridSizes[-1]:

      # list() makes this work whether bounds is a list, an ndarray or the
      # lazy iterator that map() returns on Python 3.
      bounds = (np.array(list(bounds)).T - np.array(bestSol)).T

      # Grow every offset away from zero to the next multiple of minRes.
      boundsRoundedUp = np.sign(bounds) * np.ceil(np.abs(bounds) / minRes) * minRes

      # Candidate coordinates along each axis, then their Cartesian product.
      possiblePts1D = [
          (bestSol[ind] + np.arange(x[0], x[1] + minRes, minRes)).tolist()
          for ind, x in enumerate(boundsRoundedUp)
      ]
      possiblePts3D = list(product(*possiblePts1D))

      if debugging:
          print('StepSize:' + str(minRes))
          print(bounds)
          # A list, not map(), so the lengths are printed on Python 3 as well.
          print([len(pts) for pts in possiblePts1D])
          print([bestSol[ind] + x for ind, x in enumerate(boundsRoundedUp)])

      # Evaluate objFun on every candidate in the worker pool; .get(1800)
      # gives up (multiprocessing.TimeoutError) after 1800 seconds.
      argGen = ArgGenIterator(possiblePts3D, SWCDatas[-1])
      funcVals = pool.map_async(objFun, argGen).get(1800)

      # All candidates that attain the minimum objective value.
      minimum = min(funcVals)
      minimzers = [y for x, y in enumerate(possiblePts3D) if funcVals[x] == minimum]

      if len(SWCDatas) > 1:
          # Break ties with the objective of the second-to-last grid level.
          prevVals = [objFun((x, SWCDatas[-2])) for x in minimzers]
          bestSol = minimzers[np.argmin(prevVals)]
      else:
          # Only one grid level exists, so SWCDatas[-2] would raise
          # IndexError; fall back to the first-level rule of preferring the
          # smallest translation.
          distFrom0 = np.linalg.norm(minimzers, axis=1)
          bestSol = minimzers[np.argmin(distFrom0)]
  # Final step: compare the best translation found against "no translation",
  # decide whether the search is finished, and write the results.

  # No more work is submitted past this point; shut the workers down so the
  # processes are released instead of lingering until interpreter exit.
  pool.close()
  pool.join()

  # Objective of the best translation and of the zero translation, both
  # evaluated with the SWCTranslate object of the last grid size.
  bestVal = objFun((bestSol, SWCDatas[-1]))
  nochange = objFun(([0, 0, 0], SWCDatas[-1]))

  if debugging:
      bestVals = [objFun((bestSol, x)) for x in SWCDatas]
      print(bestSol, bestVals)

  done = False

  if bestVal > nochange:
      # Every candidate is worse than doing nothing: keep the zero translation.
      done = True
      bestSol = [0, 0, 0]
      bestVal = nochange
  elif bestVal == nochange:
      # The best candidate and the zero translation are equally good ...
      if np.abs(bestSol).max() <= min(minRes, gridSizes[-1]):
          # ... and the proposed shift is within one step of zero on every
          # axis, so treat it as no shift at all.
          done = True
          bestSol = [0, 0, 0]
          bestVal = nochange

  # outFiles[0] is handed to writeSolution together with the translation;
  # outFiles[1] receives a JSON summary of the result.
  SWCDatas[-1].writeSolution(outFiles[0], bestSol)
  matrix = compose_matrix(translate=bestSol).tolist()
  with open(outFiles[1], 'w') as fle:
      json.dump({'type': 'XYZ Translations in um', 'bestSol': bestSol,
                 'transMat': matrix, 'done': done, 'bestVal': bestVal}, fle)