scaleOnce.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. from SWCTransforms import SWCScale, SWCTranslate, ArgGenIterator, objFun
  2. import multiprocessing as mp
  3. import numpy as np
  4. import json
  5. import sys
  6. from itertools import product
  7. from regmaxsn.core.transforms import compose_matrix
  8. debugging = False
  9. # debugging = True
  10. assert len(sys.argv) == 2, 'Only one argument, the path of the swcfile expected, ' + str(len(sys.argv)) + 'found'
  11. parFile = sys.argv[1]
  12. with open(parFile, 'r') as fle:
  13. pars = json.load(fle)
  14. refSWC, SWC2Align, outFiles, gridSizes, bounds, minStepSize, nCPU = pars
  15. initBounds = bounds
  16. boundL = lambda x, y: max(y[0], min(y[1], x))
  17. data = np.loadtxt(SWC2Align)[:, 2:5]
  18. dataCentered = data - data.mean(axis=0)
  19. maxDist = max(np.linalg.norm(dataCentered, axis=1).max(), gridSizes[0] * 1.01)
  20. pool = mp.Pool(processes=nCPU)
  21. SWCDatas = [SWCScale(refSWC, SWC2Align, x) for x in gridSizes]
  22. bestSol = [1.0, 1.0, 1.0]
  23. stepSizes = [max(minStepSize, min(2.0, (maxDist / (maxDist - g)))) for g in gridSizes]
  24. if debugging:
  25. print(maxDist, [(maxDist / (maxDist - g)) for g in gridSizes])
  26. overestimationError = lambda d, g: (d + g) / d
  27. underestimationError = lambda d, g: ((d + 1.5 * g) * d) / ((d - 0.5 * g) * (d + g))
  28. for gridInd, gridSize in enumerate(gridSizes):
  29. stepSize = stepSizes[gridInd]
  30. bounds = np.array(bounds)
  31. boundsExponents = np.log([x / y for x, y in zip(bounds, bestSol)]) / np.log(stepSize)
  32. boundsExponentsRoundedDown = np.sign(boundsExponents) * np.ceil(np.abs(boundsExponents))
  33. possiblePts1D = [(bestSol[x] * (stepSize ** np.arange(int(y[0]), int(y[1]) + 1)))
  34. for x, y in enumerate(boundsExponentsRoundedDown)]
  35. if debugging:
  36. print(stepSize)
  37. print('Gridsize:' + str(gridSize))
  38. print(bounds)
  39. print(map(len, possiblePts1D))
  40. print([bestSol[x] * (stepSize ** y) for x, y in enumerate(boundsExponentsRoundedDown)])
  41. possiblePts3D = np.round(list(product(*possiblePts1D)), 6).tolist()
  42. argGen = ArgGenIterator(possiblePts3D, SWCDatas[gridInd])
  43. funcVals = pool.map_async(objFun, argGen).get(1800)
  44. minimum = min(funcVals)
  45. minimzers = [y for x, y in enumerate(possiblePts3D) if funcVals[x] == minimum]
  46. if not gridInd:
  47. distFrom0 = [np.mean([max(x, 1 / x) for x in y]) for y in minimzers]
  48. bestSol = minimzers[np.argmin(distFrom0)]
  49. else:
  50. prevVals = [objFun((x, SWCDatas[gridInd - 1])) for x in minimzers]
  51. bestSol = minimzers[np.argmin(prevVals)]
  52. bounds = [[boundL(x / overestimationError(maxDist, gridSize), iB),
  53. boundL(x * underestimationError(maxDist, gridSize), iB)]
  54. for x, iB in zip(bestSol, initBounds)]
  55. if debugging:
  56. print(bestSol)
  57. if stepSizes[-1] > minStepSize:
  58. stepSize = minStepSize
  59. bounds = np.array(bounds)
  60. boundsExponents = np.log([x / y for x, y in zip(bounds, bestSol)]) / np.log(stepSize)
  61. boundsExponentsRoundedDown = np.sign(boundsExponents) * np.ceil(np.abs(boundsExponents))
  62. possiblePts1D = [(bestSol[x] * (stepSize ** np.arange(int(y[0]), int(y[1]) + 1)))
  63. for x, y in enumerate(boundsExponentsRoundedDown)]
  64. if debugging:
  65. print(stepSize)
  66. print(bounds)
  67. print(map(len, possiblePts1D))
  68. print([bestSol[x] * (stepSize ** y) for x, y in enumerate(boundsExponentsRoundedDown)])
  69. possiblePts3D = np.round(list(product(*possiblePts1D)), 6).tolist()
  70. argGen = ArgGenIterator(possiblePts3D, SWCDatas[-1])
  71. funcVals = pool.map_async(objFun, argGen).get(1800)
  72. minimum = min(funcVals)
  73. minimzers = [y for x, y in enumerate(possiblePts3D) if funcVals[x] == minimum]
  74. prevVals = [objFun((x, SWCDatas[-2])) for x in minimzers]
  75. bestSol = minimzers[np.argmin(prevVals)]
  76. if debugging:
  77. print(bestSol, min(funcVals))
  78. bestVal = objFun((bestSol, SWCDatas[-1]))
  79. nochange = objFun(([1, 1, 1], SWCDatas[-1]))
  80. if debugging:
  81. bestVals = [objFun((bestSol, x)) for x in SWCDatas]
  82. print(bestSol, nochange, bestVal)
  83. done = False
  84. # all values are worse than doing nothing
  85. if bestVal > nochange:
  86. done = True
  87. bestSol = [1, 1, 1]
  88. bestVal = nochange
  89. # best solution and no change are equally worse
  90. elif bestVal == nochange:
  91. # the solution step is very close to zero
  92. if np.abs(bestSol).max() <= min(minStepSize, stepSizes[-1]):
  93. done = True
  94. bestSol = [1, 1, 1]
  95. bestVal = nochange
  96. SWCDatas[-1].writeSolution(outFiles[0], bestSol)
  97. temp = SWCTranslate(refSWC, outFiles[0], gridSizes[-1])
  98. bestVal = objFun(([0, 0, 0], temp))
  99. matrix = compose_matrix(scale=bestSol).tolist()
  100. with open(outFiles[1], 'w') as fle:
  101. json.dump({'type': 'XYZ Scales', 'bestSol': bestSol,
  102. 'transMat': matrix, 'done': done, 'bestVal': bestVal}, fle)