# rotOnce.py 4.0 KB
  1. from SWCTransforms import SWCRotate, ArgGenIterator, objFun
  2. import multiprocessing as mp
  3. import numpy as np
  4. import json
  5. import sys
  6. from itertools import product
  7. from regmaxsn.core.transforms import compose_matrix
  8. debugging = False
  9. # debugging = True
  10. assert len(sys.argv) == 2, 'Only one argument, the path of the swcfile expected, ' + str(len(sys.argv)) + 'found'
  11. parFile = sys.argv[1]
  12. with open(parFile, 'r') as fle:
  13. pars = json.load(fle)
  14. refSWC, SWC2Align, outFiles, gridSizes, bounds, minRes, nCPU = pars
  15. data = np.loadtxt(SWC2Align)[:, 2:5]
  16. dataCentered = data - data.mean(axis=0)
  17. maxDist = np.linalg.norm(data, axis=1).max()
  18. pool = mp.Pool(processes=nCPU)
  19. SWCDatas = [SWCRotate(refSWC, SWC2Align, x) for x in gridSizes]
  20. stepSizes = [max(x / maxDist, minRes) for x in gridSizes]
  21. bestSol = [0, 0, 0]
  22. for gridInd, gridSize in enumerate(gridSizes):
  23. if debugging:
  24. print('Gridsize:' + str(gridSize))
  25. stepSize = stepSizes[gridInd]
  26. if debugging:
  27. print('Stepsize: ' + str(np.rad2deg(stepSize)))
  28. bounds = (np.array(bounds).T - np.array(bestSol)).T
  29. boundsRoundedUp = np.sign(bounds) * np.ceil(np.abs(bounds) / stepSize) * stepSize
  30. possiblePts1D = [np.round(bestSol[ind] + np.arange(x[0], x[1] + stepSize, stepSize), 3).tolist()
  31. for ind, x in enumerate(boundsRoundedUp)]
  32. if debugging:
  33. print(np.rad2deg([bestSol[ind] + x for ind, x in enumerate(boundsRoundedUp)]))
  34. possiblePts3D = np.round(list(product(*possiblePts1D)), 6).tolist()
  35. argGen = ArgGenIterator(possiblePts3D, SWCDatas[gridInd])
  36. funcVals = pool.map_async(objFun, argGen).get(1800)
  37. minimum = min(funcVals)
  38. minimzers = [y for x, y in enumerate(possiblePts3D) if funcVals[x] == minimum]
  39. if not gridInd:
  40. distFrom0 = np.linalg.norm(minimzers, axis=1)
  41. bestSol = minimzers[np.argmin(distFrom0)]
  42. else:
  43. prevVals = [objFun((x, SWCDatas[gridInd - 1])) for x in minimzers]
  44. bestSol = minimzers[np.argmin(prevVals)]
  45. bounds = map(lambda x: [x - np.sqrt(2) * stepSize, x + np.sqrt(2) * stepSize], bestSol)
  46. if debugging:
  47. bestVal = objFun((bestSol, SWCDatas[gridInd]))
  48. print(np.rad2deg(bestSol), bestVal)
  49. if minRes < stepSizes[-1]:
  50. if debugging:
  51. print('Stepsize: ' + str(np.rad2deg(minRes)))
  52. bounds = (np.array(bounds).T - np.array(bestSol)).T
  53. boundsRoundedUp = np.sign(bounds) * np.ceil(np.abs(bounds) / minRes) * minRes
  54. if debugging:
  55. print(np.rad2deg([bestSol[ind] + x for ind, x in enumerate(boundsRoundedUp)]))
  56. possiblePts1D = [np.round(bestSol[ind] + np.arange(x[0], x[1] + minRes, minRes), 3).tolist()
  57. for ind, x in enumerate(boundsRoundedUp)]
  58. possiblePts3D = np.round(list(product(*possiblePts1D)), 6).tolist()
  59. argGen = ArgGenIterator(possiblePts3D, SWCDatas[-1])
  60. funcVals = pool.map_async(objFun, argGen).get(1800)
  61. minimum = min(funcVals)
  62. minimzers = [y for x, y in enumerate(possiblePts3D) if funcVals[x] == minimum]
  63. prevVals = [objFun((x, SWCDatas[-2])) for x in minimzers]
  64. bestSol = minimzers[np.argmin(prevVals)]
  65. if debugging:
  66. bestVal = objFun((bestSol, SWCDatas[-1]))
  67. print(np.rad2deg(bestSol), bestVal)
  68. bestVal = objFun((bestSol, SWCDatas[-1]))
  69. nochange = objFun(([0, 0, 0], SWCDatas[-1]))
  70. if debugging:
  71. print(np.rad2deg(bestSol), bestVal, nochange)
  72. done = False
  73. # all values are worse than doing nothing
  74. if bestVal > nochange:
  75. done = True
  76. bestSol = [0, 0, 0]
  77. bestVal = nochange
  78. # best solution and no change are equally worse
  79. elif bestVal == nochange:
  80. # the solution step is very close to zero
  81. if np.abs(bestSol).max() <= min(minRes, stepSizes[-1]):
  82. done = True
  83. bestSol = [0, 0, 0]
  84. bestVal = nochange
  85. SWCDatas[-1].writeSolution(outFiles[0], bestSol)
  86. matrix = compose_matrix(angles=bestSol).tolist()
  87. with open(outFiles[1], 'w') as fle:
  88. json.dump({'type': 'XYZ Euler Angles in radians','bestSol': bestSol,
  89. 'transMat': matrix, 'done': done, 'bestVal': bestVal}, fle)