123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- import os
- from regmaxsn.core.swcFuncs import transSWC, getPCADetails, transSWC_rotAboutPoint
- from regmaxsn.core.SWCTransforms import objFun, SWCTranslate
- from regmaxsn.core.misc import parFileCheck
- import json
- import numpy as np
- import sys
def _confirm_par_file(parFile):
    """Ask the user to confirm use of *parFile*; terminate the process otherwise.

    Any answer other than ``y`` (case-insensitive, surrounding whitespace
    ignored) prints 'User Abort!' and calls ``sys.exit()``.

    Raises
    ------
    SystemExit
        If the user does not confirm.
    """
    ch = input('Using parameter File {}.\n Continue?(y/n)'.format(parFile))
    if ch.strip().lower() != 'y':
        print('User Abort!')
        sys.exit()


def _pca_transform(refMean, testMean, refEvecs, testEvecs, scales):
    """Return the 4x4 homogeneous matrix mapping test points onto the reference.

    Mathematically this is
    ``x -> refEvecs @ diag(scales) @ testEvecs.T @ (x - testMean) + refMean``.
    The homogeneous matrices are composed one step at a time (translate,
    rotate by ``testEvecs.T``, scale, rotate by ``refEvecs``). This keeps the
    floating-point result identical to building and multiplying each step
    matrix individually.
    """
    totalTransform = np.eye(4)
    totalTransform[:3, 3] = -np.asarray(testMean)
    for linear in (np.asarray(testEvecs).T, np.diag(scales), np.asarray(refEvecs)):
        step = np.eye(4)
        step[:3, :3] = linear
        totalTransform = np.dot(step, totalTransform)
    totalTransform[:3, 3] += refMean
    return totalTransform


def _transform_parts_dir(inPartsDir, outPartsDir, totalTransform):
    """Apply *totalTransform* to every ``*.swc`` file found in *inPartsDir*.

    Each result is written under the same file name into *outPartsDir*. That
    directory is created, including missing parents, if it does not exist.
    A missing *inPartsDir* is reported on stdout and skipped, not raised.
    """
    if not os.path.isdir(inPartsDir):
        print('Specified partsDir {} not found'.format(inPartsDir))
        return
    partFiles = [x for x in os.listdir(inPartsDir) if x.endswith('.swc')]
    os.makedirs(outPartsDir, exist_ok=True)
    for entry in partFiles:
        # NOTE(review): rotation point [0, 0, 0] presumably means the matrix and
        # translation are applied as-is, matching transSWC for the main file —
        # confirm against transSWC_rotAboutPoint.
        transSWC_rotAboutPoint(os.path.join(inPartsDir, entry),
                               totalTransform[:3, :3], totalTransform[:3, 3],
                               os.path.join(outPartsDir, entry),
                               [0, 0, 0])


def pca_based(parFile):
    """Align each test SWC to its reference SWC using principal-axis matching.

    ``parFileCheck(parFile, pcaBasedParNames)`` returns a list of parameter
    sets. For each set, the test reconstruction is processed as follows:

    1. Centre it on its own centroid.
    2. Rotate it by the transpose of its ``getPCADetails`` eigenvectors.
    3. Scale each axis by ``refNStds / testNStds``.
    4. Rotate it by the reference eigenvectors.
    5. Translate it onto the reference centroid.

    ``transSWC`` applies the resulting transform to produce ``resFile``. If
    ``usePartsDir`` is truthy, the same transform is also applied to every
    ``*.swc`` in the directory ``testSWC[:-4]``, writing into ``resFile[:-4]``.

    A JSON summary is written to ``<resFile without extension>bestSol.txt``.
    It contains the keys 'transMat', 'bestVal', 'refSWC', 'testSWC' and
    'gridSizes'.

    Parameters
    ----------
    parFile : str
        Parameter file. Every parameter set must provide 'refSWC', 'testSWC',
        'gridSizes', 'resFile' and 'usePartsDir'.

    Raises
    ------
    SystemExit
        If the user does not confirm the parameter file.
    ValueError
        If ``getPCADetails`` reports a zero value for the test SWC, which
        would make the axis scale factor undefined.
    """
    _confirm_par_file(parFile)
    from regmaxsn.core.RegMaxSPars import pcaBasedParNames
    parsList = parFileCheck(parFile, pcaBasedParNames)
    for pars in parsList:
        print('Current Parameters:')
        for parN, parV in pars.items():
            print('{}: {}'.format(parN, parV))
        refSWC = pars['refSWC']
        testSWC = pars['testSWC']
        gridSizes = pars['gridSizes']
        resFile = pars['resFile']
        usePartsDir = pars['usePartsDir']

        # Output names drop a 4-character extension (e.g. '.swc') from resFile.
        resStem = resFile[:-4]
        resDir, expName = os.path.split(resStem)
        resSolFile = os.path.join(resDir, expName + 'bestSol.txt')

        # Centroids of node coordinates (columns 2-4, i.e. x, y, z in SWC layout).
        refMean = np.loadtxt(refSWC)[:, 2:5].mean(axis=0)
        testMean = np.loadtxt(testSWC)[:, 2:5].mean(axis=0)

        refEvecs, refNStds = getPCADetails(refSWC)
        testEvecs, testNStds = getPCADetails(testSWC)
        if any(y == 0 for y in testNStds):
            # Degenerate (e.g. flat) test reconstruction: the scale would be
            # inf/nan or raise ZeroDivisionError, so fail with a clear message.
            raise ValueError('Cannot compute PCA scales: getPCADetails returned a zero '
                             'value for testSWC {} ({})'.format(testSWC, list(testNStds)))
        scales = [x / y for x, y in zip(refNStds, testNStds)]

        totalTransform = _pca_transform(refMean, testMean, refEvecs, testEvecs, scales)
        transSWC(testSWC, totalTransform[:3, :3], totalTransform[:3, 3], resFile)

        # Objective value of the aligned result at zero additional translation,
        # evaluated with the last entry of gridSizes.
        trans = SWCTranslate(refSWC, resFile, gridSizes[-1])
        bestVal = objFun(([0, 0, 0], trans))

        if usePartsDir:
            _transform_parts_dir(testSWC[:-4], resStem, totalTransform)

        with open(resSolFile, 'w') as fle:
            json.dump({'transMat': totalTransform.tolist(), 'bestVal': bestVal,
                       'refSWC': refSWC, 'testSWC': testSWC, 'gridSizes': gridSizes}, fle)
- # ----------------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Explicit argument check: an `assert` here would be stripped under `python -O`.
    if len(sys.argv) != 2:
        sys.exit('Improper usage! Please use as \'python pcaBasedReg.py parFile\'')
    parFile = sys.argv[1]
    pca_based(parFile)
|