123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212 |
- import os
- import numpy as np
- from regmaxsn.core.transforms import compose_matrix
- from regmaxsn.core.swcFuncs import readSWC_numpy, writeSWC_numpy
def three32BitInt2complexList(arr):
    """Pack each row of an ``(N, 3)`` int32 array into one complex number.

    Each row ``(a, b, c)`` is prefixed with a zero word, giving four int32
    words (16 bytes) per row.  Those bytes are reinterpreted as two int64
    values, which become the real and imaginary parts of a complex number.
    The resulting hashable scalars are convenient for building voxel sets.

    Note: int64 -> complex128 conversion goes through float64.  It is exact
    only while each int64 stays below 2**53 in magnitude.  On little-endian
    hosts that roughly means ``|c| < 2**21``.  Beyond that, distinct rows may
    map to the same value.

    Parameters
    ----------
    arr : numpy.ndarray
        Array of dtype int32 with exactly three columns.

    Returns
    -------
    list of complex
        One value per row of ``arr``; identical rows give identical values.

    Raises
    ------
    AssertionError
        If ``arr`` is not int32 or does not have three columns.
    """
    assert arr.dtype == np.int32, 'Only 32 bit ints allowed'
    assert arr.shape[1] == 3, 'only three ints per row allowed'
    # Leading zero column -> 4 x int32 per row == 2 x int64 per row.
    zeroCol = np.zeros((arr.shape[0], 1), dtype=np.int32)
    padded = np.concatenate((zeroCol, arr), axis=1)
    words = padded.view(np.int64)
    return (words[:, 0] + 1j * words[:, 1]).tolist()
def objFun(y):
    """Return the Jaccard distance between aligned and reference voxel sets.

    Parameters
    ----------
    y : tuple
        ``(x, data)`` pair, e.g. as yielded by ``ArgGenIterator``.
        ``x`` is the parameter vector passed to ``data.transform``.
        ``data`` must expose ``transform``, ``SWC2AlignPts``, ``gridSize``
        and ``refVoxSet``.

    Returns
    -------
    float
        ``1 - len(A & B) / len(A | B)``.  Here ``A`` is the set of voxels
        occupied by the transformed ``SWC2AlignPts`` and ``B`` is
        ``data.refVoxSet``.  0 means identical voxel sets; 1 means no overlap.

    Raises
    ------
    ZeroDivisionError
        If both voxel sets are empty.
    """
    pars, aligner = y
    movedPts = aligner.transform(pars, aligner.SWC2AlignPts)
    # Snap transformed points to the same voxel grid as the reference.
    movedVox = np.array(np.round(movedPts / aligner.gridSize), np.int32)
    movedVoxSet = set(three32BitInt2complexList(movedVox))
    refVoxSet = aligner.refVoxSet
    shared = len(movedVoxSet.intersection(refVoxSet))
    # |A u B| = |A| + |B| - |A n B|
    combined = len(movedVoxSet) + len(refVoxSet) - shared
    return 1 - (shared / float(combined))
class BaseSWCTranform(object):
    """Base class for aligning one SWC morphology onto a reference morphology.

    The reference points are rounded onto a voxel grid of edge ``gridSize``
    and stored as a set (``refVoxSet``) that ``objFun`` compares against.
    Subclasses override :meth:`transform`; the base implementation is the
    identity.

    Parameters
    ----------
    refSWC : str or numpy.ndarray
        Path to an existing ``.swc`` file, or an array whose columns 2:5 hold
        the point coordinates (presumably x, y, z per the SWC column layout).
    SWC2Align : str or numpy.ndarray
        Morphology to be aligned: either a ``.swc`` path (read with
        ``readSWC_numpy``) or a full SWC-style array (used as-is, with an
        empty header).
    gridSize : float
        Voxel edge length; coordinates are divided by it before rounding.

    Raises
    ------
    ValueError
        If either input is neither an ndarray nor a path to an existing
        ``.swc`` file.
    """

    def __init__(self, refSWC, SWC2Align, gridSize):
        super(BaseSWCTranform, self).__init__()
        self.refSWCPts = self._loadRefPts(refSWC)
        self.refCenter = self.refSWCPts.mean(axis=0)
        refVox = np.array(np.round(self.refSWCPts / gridSize), np.int32)
        self.gridSize = gridSize
        self.refVoxSet = set(three32BitInt2complexList(refVox))
        self.headr, self.SWC2AlignFull = self._loadSWC2Align(SWC2Align)
        # View into SWC2AlignFull (not a copy).
        self.SWC2AlignPts = self.SWC2AlignFull[:, 2:5]
        self.center = self.SWC2AlignPts.mean(axis=0)

    @staticmethod
    def _isSWCFile(candidate):
        """Return True if *candidate* is a str naming an existing .swc file."""
        return (isinstance(candidate, str) and candidate.endswith('.swc')
                and os.path.isfile(candidate))

    @classmethod
    def _loadRefPts(cls, refSWC):
        """Return columns 2:5 of the reference morphology as a 2-D array."""
        # Arrays must be recognised before any path check: os.path.isfile()
        # raises TypeError (not False) when handed an ndarray.
        if isinstance(refSWC, np.ndarray):
            return refSWC[:, 2:5]
        if cls._isSWCFile(refSWC):
            # ndmin=2 keeps a single-node file 2-D so the column slice works.
            return np.loadtxt(refSWC, ndmin=2)[:, 2:5]
        raise ValueError('Unknown data in refSWC')

    @classmethod
    def _loadSWC2Align(cls, SWC2Align):
        """Return ``(header, fullArray)`` for the morphology to be aligned."""
        if isinstance(SWC2Align, np.ndarray):
            return '', SWC2Align
        if cls._isSWCFile(SWC2Align):
            return readSWC_numpy(SWC2Align)
        raise ValueError('Unknown data in SWC2Align')

    def transform(self, pars, data):
        """Identity transform; subclasses map ``(pars, points) -> points``."""
        return data

    def writeSolution(self, outFile, bestSol, inFile=None):
        """Write a morphology transformed by *bestSol* to *outFile*.

        Parameters
        ----------
        outFile : str
            Destination passed to ``writeSWC_numpy``.
        bestSol : array-like
            Parameters forwarded to :meth:`transform`.
        inFile : str, optional
            If given, this file is read and transformed instead of the
            stored ``SWC2Align`` morphology.

        The stored morphology is never modified, so repeated calls give the
        same output.
        """
        if inFile:
            headr, data = readSWC_numpy(inFile)
        else:
            # Copy: writing into SWC2AlignFull would also shift SWC2AlignPts
            # (a view of it), compounding the transform on every call.
            headr, data = self.headr, self.SWC2AlignFull.copy()
        data[:, 2:5] = self.transform(bestSol, data[:, 2:5])
        writeSWC_numpy(outFile, data, headr)
class SWCTranslate(BaseSWCTranform):
    """Alignment whose parameter vector is a translation offset."""

    def __init__(self, refSWC, SWC2Align, gridSize):
        # No translation-specific state; loading is handled by the base class.
        super(SWCTranslate, self).__init__(refSWC, SWC2Align, gridSize)

    def transform(self, pars, data):
        """Shift every point in *data* by *pars* (broadcast across rows)."""
        shifted = data + pars
        return shifted
class SWCRotate(BaseSWCTranform):
    """Alignment by rotation about the centroid of the points being aligned."""

    def __init__(self, refSWC, SWC2Align, gridSize):
        # No rotation-specific state; loading is handled by the base class.
        super(SWCRotate, self).__init__(refSWC, SWC2Align, gridSize)

    def transform(self, pars, data):
        """Rotate *data* about ``self.center``.

        *pars* are the angles handed to ``compose_matrix(angles=...)``; only
        the upper-left 3x3 block of the resulting matrix is applied.
        """
        rot3 = compose_matrix(angles=pars)[:3, :3]
        offsets = data - self.center
        rotated = np.dot(rot3, offsets.T).T
        return rotated + self.center
class SWCScale(object):
    """Align an SWC morphology to a reference by scaling about its centroid.

    Unlike ``BaseSWCTranform``, both point sets are centred on their own
    centroids before voxelisation.  ``refVoxSet`` is built from the centred
    reference points, and ``SWC2AlignPts`` holds centred coordinates.  So
    ``transform`` operates on centred data, and ``writeSolution`` shifts the
    result back by ``self.center``.

    Parameters
    ----------
    refSWC : str or numpy.ndarray
        Path to an existing ``.swc`` file, or an array whose columns 2:5 hold
        the point coordinates (presumably x, y, z per the SWC column layout).
    SWC2Align : str or numpy.ndarray
        Morphology to be scaled: either a ``.swc`` path (read with
        ``readSWC_numpy``) or a full SWC-style array (used as-is, with an
        empty header).
    gridSize : float
        Voxel edge length; coordinates are divided by it before rounding.

    Raises
    ------
    ValueError
        If either input is neither an ndarray nor a path to an existing
        ``.swc`` file.
    """

    def __init__(self, refSWC, SWC2Align, gridSize):
        super(SWCScale, self).__init__()
        self.gridSize = gridSize
        self.refSWCPts = self._loadRefPts(refSWC)
        # Stored as an attribute for parity with BaseSWCTranform.
        self.refCenter = self.refSWCPts.mean(axis=0)
        refSWCPtsCentered = self.refSWCPts - self.refCenter
        refVox = np.array(np.round(refSWCPtsCentered / gridSize), np.int32)
        self.refVoxSet = set(three32BitInt2complexList(refVox))
        self.headr, self.SWC2AlignFull = self._loadSWC2Align(SWC2Align)
        # Copy so the in-place centring below leaves SWC2AlignFull untouched.
        self.SWC2AlignPts = self.SWC2AlignFull[:, 2:5].copy()
        self.center = self.SWC2AlignPts.mean(axis=0)
        self.SWC2AlignPts -= self.center

    @staticmethod
    def _isSWCFile(candidate):
        """Return True if *candidate* is a str naming an existing .swc file."""
        return (isinstance(candidate, str) and candidate.endswith('.swc')
                and os.path.isfile(candidate))

    @classmethod
    def _loadRefPts(cls, refSWC):
        """Return columns 2:5 of the reference morphology as a 2-D array."""
        # Arrays must be recognised before any path check: os.path.isfile()
        # raises TypeError (not False) when handed an ndarray.
        if isinstance(refSWC, np.ndarray):
            return refSWC[:, 2:5]
        if cls._isSWCFile(refSWC):
            # ndmin=2 keeps a single-node file 2-D so the column slice works.
            return np.loadtxt(refSWC, ndmin=2)[:, 2:5]
        raise ValueError('Unknown data in refSWC')

    @classmethod
    def _loadSWC2Align(cls, SWC2Align):
        """Return ``(header, fullArray)`` for the morphology to be scaled."""
        if isinstance(SWC2Align, np.ndarray):
            return '', SWC2Align
        if cls._isSWCFile(SWC2Align):
            return readSWC_numpy(SWC2Align)
        raise ValueError('Unknown data in SWC2Align')

    def transform(self, pars, data):
        """Scale already-centred points *data*.

        *pars* are the factors handed to ``compose_matrix(scale=...)``; only
        the upper-left 3x3 block of the resulting matrix is applied.
        """
        scaleMat = compose_matrix(scale=pars)[:3, :3]
        return np.dot(scaleMat, data.T).T

    def writeSolution(self, outFile, bestSol, inFile=None):
        """Write a morphology scaled by *bestSol* about ``self.center``.

        Parameters
        ----------
        outFile : str
            Destination passed to ``writeSWC_numpy``.
        bestSol : array-like
            Parameters forwarded to :meth:`transform`.
        inFile : str, optional
            If given, this file is read and scaled (still about
            ``self.center``) instead of the stored ``SWC2Align`` morphology.

        The stored morphology is never modified, so repeated calls give the
        same output.
        """
        if inFile:
            headr, data = readSWC_numpy(inFile)
        else:
            # Copy: writing into SWC2AlignFull would compound the scaling on
            # every subsequent call (and mutate a caller-supplied array).
            headr, data = self.headr, self.SWC2AlignFull.copy()
        centred = data[:, 2:5] - self.center
        data[:, 2:5] = self.transform(bestSol, centred) + self.center
        writeSWC_numpy(outFile, data, headr)
class ArgGenIterator:
    """Iterate over ``(arg1[i], arg2)`` pairs for every index ``i`` of arg1.

    Each element of *arg1* is paired with the same shared *arg2*, e.g. to
    feed ``(parameters, data)`` tuples to ``objFun``.  Calling ``iter()``
    on the object rewinds it to the first element.
    """

    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2
        self.pointsDone = 0  # index of the next element of arg1 to yield

    def __iter__(self):
        # Starting a new iteration always restarts from the beginning.
        self.pointsDone = 0
        return self

    def __next__(self):
        idx = self.pointsDone
        # len() is re-read on each call, so changes to arg1 are honoured.
        if idx >= len(self.arg1):
            raise StopIteration
        item = self.arg1[idx]
        self.pointsDone = idx + 1
        return item, self.arg2
|