123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250 |
- import os
- import numpy as np
- from regmaxsn.core.iterativeRegistration import IterativeRegistration, composeRefSWC, calcOverlap, getRemainderScale
- import shutil
- import json
- import sys
- from regmaxsn.core.transforms import decompose_matrix
- from regmaxsn.core.swcFuncs import transSWC
- from regmaxsn.core.misc import parFileCheck
- from regmaxsn.core.occupancyBasedMeasure import occupancyEMD
def _refIterations(resDir):
    """Return, in ascending order, the N of every ``ref<N>.swc`` file in *resDir*."""
    import re
    # Match the exact reference-file pattern only. Other files in resDir may
    # also start with 'ref' (e.g. the per-iteration result 'ref_neuron0.swc'
    # of an input named 'ref_neuron.swc'); parsing those with int() would
    # raise ValueError.
    refPattern = re.compile(r'ref(-?\d+)\.swc')
    matches = (refPattern.fullmatch(fle) for fle in os.listdir(resDir))
    return sorted(int(m.group(1)) for m in matches if m is not None)


def normalizeFinally(ipFiles, resDir, opFiles, fnwrtName, maxIter):
    """Undo the accumulated transform of one chosen neuron on the final outputs.

    For every ``ref<N>.swc`` in *resDir* with ``N <= maxIter`` (ascending N),
    the ``finalTransMat`` stored in ``<fnwrtName><N>Sol.txt`` (when that file
    exists) is composed into one 4x4 transform. Its inverse is applied with
    ``transSWC`` to each ``ipFiles[i]``, writing ``opFiles[i]``. If an input
    ``<dir>/<name>.swc`` has a sibling directory ``<dir>/<name>``, every
    ``.swc`` file in it is transformed into a new directory named like the
    output file without its extension.

    Args:
        ipFiles: input SWC paths (ending in '.swc').
        resDir: result directory holding the ``ref<N>.swc`` and solution files.
        opFiles: output SWC paths, paired element-wise with *ipFiles*.
        fnwrtName: base name (no directory, no extension) of the neuron whose
            solution files define the transform to undo.
        maxIter: highest iteration index whose solution is included.

    Raises:
        OSError: a normalized parts directory already exists (from os.mkdir).
        numpy.linalg.LinAlgError: the composed transform is not invertible.
    """
    totalTrans = np.eye(4)
    for iterInd in _refIterations(resDir):
        if iterInd > maxIter:
            continue
        solFle = os.path.join(resDir, fnwrtName + str(iterInd) + 'Sol.txt')
        # runRegMaxSN deletes the solution file of a rejected iteration, so a
        # missing file simply contributes no transform.
        if not os.path.exists(solFle):
            continue
        with open(solFle, 'r') as f:
            solPars = json.load(f)
        # Later iterations act on already-transformed coordinates, so their
        # matrices multiply from the left.
        totalTrans = np.dot(solPars['finalTransMat'], totalTrans)

    iTrans = np.linalg.inv(totalTrans)
    linearPart, translation = iTrans[:3, :3], iTrans[:3, 3]
    for ipFile, opFile in zip(ipFiles, opFiles):
        transSWC(ipFile, linearPart, translation, opFile)
        ipDir, ipName = os.path.split(ipFile[:-4])
        partsDir = os.path.join(ipDir, ipName)
        if os.path.isdir(partsDir):
            normedPartsDir = opFile[:-4]
            os.mkdir(normedPartsDir)
            for swc in os.listdir(partsDir):
                if swc.endswith('.swc'):
                    transSWC(os.path.join(partsDir, swc), linearPart, translation,
                             os.path.join(normedPartsDir, swc))
def _iterationImproved(initVals, finalVals):
    """Return True if *finalVals* beats *initVals* in a lexicographic comparison.

    Values are compared pair by pair in grid-size order. The first pair that
    differs decides: a lower final value keeps the iteration, a higher one
    rejects it. If no pair differs, the iteration is rejected.
    """
    for initVal, finalVal in zip(initVals, finalVals):
        if finalVal < initVal:
            return True
        if finalVal > initVal:
            return False
    return False


def _prepareJob(pars):
    """Validate one job's parameters and create a fresh result directory for it.

    If ``pars['resDir']`` already exists, the user is asked whether to delete
    it; declining exits the program.

    Raises:
        IOError: the result directory cannot be created.
        AssertionError: a referenced SWC file is missing, a ``swcList`` entry
            does not end in '.swc', or ``finallyNormalizeWRT`` is not in
            ``swcList``.
        SystemExit: the user declines to delete an existing result directory.
    """
    resDir = pars['resDir']
    refSWC = pars['initRefSWC']
    swcList = pars['swcList']
    fnwrt = pars['finallyNormalizeWRT']

    if os.path.isdir(resDir):
        ch = input('Folder exists: ' + resDir + '\nDelete(y/n)?')
        if ch == 'y':
            shutil.rmtree(resDir)
        else:
            sys.exit()

    try:
        os.mkdir(resDir)
    except OSError as err:
        raise IOError('Could not create {}'.format(resDir)) from err

    assert os.path.isfile(refSWC), 'Could not find {}'.format(refSWC)

    for swc in swcList:
        assert os.path.isfile(swc), 'Could not find {}'.format(swc)
        assert swc.endswith('.swc'), 'Elements of swcList must be of SWC format with extension \'.swc\''

    assert fnwrt in swcList, 'The parameter finallyNormalizeWRT must be an element of the parameter swcList'


def _runJob(pars):
    """Run the iterative Reg-MaxS-N registration for one validated job.

    Files written into ``pars['resDir']``:

    * ``ref<N>.swc``: the reference built after iteration N (N = -1 is the
      initial reference);
    * ``<expName><N>.swc``: each neuron after iteration N, plus the
      ``<expName><N>Sol.txt`` solution of every accepted iteration;
    * ``finalRef.swc``: copy of the reference of the best iteration;
    * ``<expName>.swc``: each neuron of the best iteration, passed through
      normalizeFinally;
    * ``bestIterInd.json``: the best iteration and its occupancy measure.

    Raises:
        ValueError: no iteration was run (``maxIter`` is 0).
    """
    resDir = pars['resDir']
    refSWC = pars['initRefSWC']
    swcList = pars['swcList']
    fnwrt = pars['finallyNormalizeWRT']
    usePartsDir = pars['usePartsDir']
    nIter = pars['maxIter']
    gridSizes = pars['gridSizes']
    rotBounds = pars['rotBounds']
    transBounds = pars['transBounds']
    scaleBounds = pars['scaleBounds']
    transMinRes = pars['transMinRes']
    minScaleStepSize = pars['minScaleStepSize']
    rotMinRes = pars['rotMinRes']
    nCPU = pars['nCPU']

    # Iteration -1 holds the untouched inputs: the initial reference and,
    # optionally, each neuron's parts directory.
    shutil.copyfile(refSWC, os.path.join(resDir, 'ref' + str(-1) + '.swc'))
    if usePartsDir:
        for swc in swcList:
            dirPath, expName = os.path.split(swc[:-4])
            partsDirO = os.path.join(dirPath, expName)
            if os.path.isdir(partsDirO):
                shutil.copytree(partsDirO, os.path.join(resDir, expName + str(-1)))

    prevAlignedSWCs = swcList
    occupancyMeasureLargestGridSize = []
    # Per-neuron copy of the scale bounds, updated through getRemainderScale
    # after every accepted iteration of that neuron.
    nrnScaleBounds = {swc: scaleBounds[:] for swc in swcList}

    for iterInd in range(nIter):
        iterReg = IterativeRegistration(refSWC, gridSizes, rotBounds, transBounds,
                                        transMinRes, minScaleStepSize, rotMinRes, nCPU)
        presAlignedSWCs = []
        dones = []
        for swcInd, swc in enumerate(swcList):
            expName = os.path.split(swc[:-4])[1]
            print('Doing Iter ' + str(iterInd) + ' : ' + expName)

            SWC2Align = prevAlignedSWCs[swcInd]
            # Only the very first iteration starts from a centroid-matching guess.
            if iterInd > 0:
                initGuessTypeT = 'nothing'
            else:
                initGuessTypeT = 'just_centroids'

            initVals = [calcOverlap(refSWC, SWC2Align, g) for g in gridSizes]

            if usePartsDir:
                inPartsDir = os.path.join(resDir, expName + str(iterInd - 1))
                outPartsDir = os.path.join(resDir, expName + str(iterInd))
            else:
                inPartsDir = None
                outPartsDir = None

            resFile = os.path.join(resDir, expName + str(iterInd) + '.swc')
            resSWC, resSol = iterReg.performReg(SWC2Align, resFile,
                                                scaleBounds=nrnScaleBounds[swc],
                                                inPartsDir=inPartsDir,
                                                outPartsDir=outPartsDir,
                                                initGuessType=initGuessTypeT,
                                                retainTempFiles=True)

            finalVals = [calcOverlap(refSWC, resSWC, g) for g in gridSizes]

            if not _iterationImproved(initVals, finalVals):
                # Reject: restore the pre-iteration neuron (and parts), and drop
                # this iteration's retained 'trans' directory and solution file.
                shutil.copy(SWC2Align, resSWC)
                shutil.rmtree(os.path.join(resDir, expName + str(iterInd) + 'trans'))
                if usePartsDir and os.path.exists(outPartsDir):
                    shutil.rmtree(outPartsDir)
                    shutil.copytree(inPartsDir, outPartsDir)
                os.remove(resSol)
                print('finalVal (' + str(finalVals) + ') >= initVal (' + str(initVals) + '). Doing Nothing!')
                done = True
            else:
                print('finalVal (' + str(finalVals) + ') < initVal (' + str(initVals) + '). Keeping the iteration!')
                # Named solPars so the job-parameter dict of the caller is not shadowed.
                with open(resSol, 'r') as fle:
                    solPars = json.load(fle)
                totalTrans = np.array(solPars['finalTransMat'])
                # The neuron counts as done once the linear part of the accepted
                # transform is within 1e-3 of the identity.
                done = np.allclose(np.eye(3), totalTrans[:3, :3], atol=1e-3)
                scale, _shear, _angles, _trans, _persp = decompose_matrix(totalTrans)
                nrnScaleBounds[swc] = getRemainderScale(scale, nrnScaleBounds[swc])

            dones.append(done)
            print('Finished ' + expName + ' : ' + str(done))
            print('Remainder scale: ' + str(nrnScaleBounds[swc]))
            presAlignedSWCs.append(resSWC)

        newRefSWC = os.path.join(resDir, 'ref' + str(iterInd) + '.swc')
        composeRefSWC(presAlignedSWCs, newRefSWC, gridSizes[-1])
        occupancyMeasureLargestGridSize.append(occupancyEMD(presAlignedSWCs, gridSizes[-1]))
        refSWC = newRefSWC
        prevAlignedSWCs = presAlignedSWCs

        if all(dones):
            break

    if not occupancyMeasureLargestGridSize:
        raise ValueError('No Reg-MaxS-N iteration was run (maxIter = {}); '
                         'cannot choose a best iteration.'.format(nIter))

    # Best iteration: lowest occupancy EMD computed at the last entry of gridSizes.
    bestIterInd = int(np.argmin(occupancyMeasureLargestGridSize))
    bestMeasure = min(occupancyMeasureLargestGridSize)
    shutil.copy(os.path.join(resDir, 'ref' + str(bestIterInd) + '.swc'),
                os.path.join(resDir, 'finalRef.swc'))

    fnwrtName = os.path.split(fnwrt[:-4])[1]
    ipFiles = []
    opFiles = []
    for swc in swcList:
        expName = os.path.split(swc[:-4])[1]
        ipFiles.append(os.path.join(resDir, '{}{}.swc'.format(expName, bestIterInd)))
        opFiles.append(os.path.join(resDir, '{}.swc'.format(expName)))
    normalizeFinally(ipFiles, resDir, opFiles, fnwrtName, bestIterInd)

    with open(os.path.join(resDir, "bestIterInd.json"), 'w') as fle:
        # float() so a numpy scalar of any precision stays JSON-serializable.
        json.dump({'finalVal': float(bestMeasure),
                   'bestIteration': bestIterInd}, fle)


def runRegMaxSN(parFile, parNames):
    """Run every Reg-MaxS-N job described in a parameter file.

    After the user confirms the parameter file, all jobs are validated and
    their result directories created before any registration starts. The jobs
    are then run one after another.

    Args:
        parFile: path to the parameter file.
        parNames: parameter names forwarded to ``parFileCheck``; presumably the
            names every job must define (verify against parFileCheck).

    Raises:
        AssertionError: *parFile* is missing, or a job's SWC files or
            ``finallyNormalizeWRT`` parameter are invalid.
        IOError: a job's result directory cannot be created.
        ValueError: a job runs no iteration (``maxIter`` is 0).
        SystemExit: the user declines to continue or to delete an existing
            result directory.
    """
    assert os.path.isfile(parFile), "{} not found".format(parFile)

    ch = input('Using parameter File {}.\n Continue?(y/n)'.format(parFile))
    if ch != 'y':
        print('User Abort!')
        sys.exit()

    parsList = parFileCheck(parFile, parNames)

    # Validate all jobs up front so a bad later job is reported before any
    # (potentially long) registration work begins.
    for pars in parsList:
        _prepareJob(pars)

    print('All parameters are acceptable. Starting the Reg-MaxS-N jobs...')

    for parInd, pars in enumerate(parsList):
        print('Starting Job # {}'.format(parInd + 1))
        print('Current Parameters:')
        for parN, parV in pars.items():
            print('{}: {}'.format(parN, parV))
        _runJob(pars)
        print('Finished Job # {}'.format(parInd + 1))
if __name__ == '__main__':
    # Parameter names are only needed when run as a script, so import them here.
    from regmaxsn.core.RegMaxSPars import RegMaxSNParNames

    # Explicit check instead of assert: asserts are stripped under `python -O`,
    # which would turn a usage error into an IndexError on sys.argv[1].
    if len(sys.argv) != 2:
        sys.exit('Improper usage! Please use as \'python RegMaxSN.py parFile\'')
    parFile = sys.argv[1]
    runRegMaxSN(parFile, RegMaxSNParNames)
|