pcaBasedReg.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import os
  2. from regmaxsn.core.swcFuncs import transSWC, getPCADetails, transSWC_rotAboutPoint
  3. from regmaxsn.core.SWCTransforms import objFun, SWCTranslate
  4. from regmaxsn.core.misc import parFileCheck
  5. import json
  6. import numpy as np
  7. import sys
  8. def pca_based(parFile):
  9. ch = raw_input('Using parameter File {}.\n Continue?(y/n)'.format(parFile))
  10. if ch != 'y':
  11. print('User Abort!')
  12. sys.exit()
  13. from regmaxsn.core.RegMaxSPars import pcaBasedParNames
  14. parsList = parFileCheck(parFile, pcaBasedParNames)
  15. for parInd, pars in enumerate(parsList):
  16. print('Current Parameters:')
  17. for parN, parV in pars.iteritems():
  18. print('{}: {}'.format(parN, parV))
  19. refSWC = pars['refSWC']
  20. testSWC = pars['testSWC']
  21. gridSizes = pars['gridSizes']
  22. resFile = pars['resFile']
  23. usePartsDir = pars['usePartsDir']
  24. resDir, expName = os.path.split(resFile[:-4])
  25. resSolFile = os.path.join(resDir, expName + 'bestSol.txt')
  26. refPts = np.loadtxt(refSWC)[:, 2:5]
  27. refMean = refPts.mean(axis=0)
  28. SWC2AlignPts = np.loadtxt(testSWC)[:, 2:5]
  29. SWC2AlignMean = SWC2AlignPts.mean(axis=0)
  30. refEvecs, refNStds = getPCADetails(refSWC)
  31. STAEvecs, STANStds = getPCADetails(testSWC)
  32. scales = [x / y for x, y in zip(refNStds, STANStds)]
  33. totalTransform = np.eye(4)
  34. totalTransform[:3, 3] = -SWC2AlignMean
  35. temp = np.eye(4)
  36. temp[:3, :3] = STAEvecs.T
  37. totalTransform = np.dot(temp, totalTransform)
  38. temp = np.eye(4)
  39. temp[:3, :3] = np.diag(scales)
  40. totalTransform = np.dot(temp, totalTransform)
  41. temp = np.eye(4)
  42. temp[:3, :3] = refEvecs
  43. totalTransform = np.dot(temp, totalTransform)
  44. totalTranslation = refMean
  45. totalTransform[:3, 3] += totalTranslation
  46. transSWC(testSWC, totalTransform[:3, :3], totalTransform[:3, 3], resFile)
  47. trans = SWCTranslate(refSWC, resFile, gridSizes[-1])
  48. bestVal = objFun(([0, 0, 0], trans))
  49. if usePartsDir:
  50. inPartsDir = testSWC[:-4]
  51. if os.path.isdir(inPartsDir):
  52. dirList = os.listdir(inPartsDir)
  53. dirList = [x for x in dirList if x.endswith('.swc')]
  54. outPartsDir = resFile[:-4]
  55. if not os.path.isdir(outPartsDir):
  56. os.mkdir(outPartsDir)
  57. for entry in dirList:
  58. transSWC_rotAboutPoint(os.path.join(inPartsDir, entry),
  59. totalTransform[:3, :3], totalTransform[:3, 3],
  60. os.path.join(outPartsDir, entry),
  61. [0, 0, 0]
  62. )
  63. else:
  64. print('Specified partsDir {} not found'.format(inPartsDir))
  65. with open(resSolFile, 'w') as fle:
  66. json.dump({'transMat': totalTransform.tolist(), 'bestVal': bestVal,
  67. 'refSWC': refSWC, 'testSWC': testSWC, 'gridSizes': gridSizes}, fle)
  68. # ----------------------------------------------------------------------------------------------------------------------
  69. if __name__ == '__main__':
  70. assert len(sys.argv) == 2, 'Improper usage! Please use as \'python pcaBasedReg.py parFile\''
  71. parFile = sys.argv[1]
  72. pca_based(parFile)