iterativeRegistration.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. import os
  2. import numpy as np
  3. from .swcFuncs import transSWC, transSWC_rotAboutPoint
  4. from .SWCTransforms import SWCTranslate, objFun
  5. import shutil
  6. import json
  7. import subprocess
  8. def transPreference(x, y):
  9. """
  10. Given two transforms x and y, returns if x is preferred over y. Preferences are scaling > translation,
  11. scaling > rotation, translation > rotation.
  12. :param x:
  13. :param y:
  14. :return:
  15. """
  16. if x == 'scale':
  17. return 0
  18. elif y == 'scale':
  19. return 1
  20. elif x == 'rot':
  21. return 0
  22. else:
  23. return 1
  24. def getRemainderScale(scale, oldScale):
  25. """
  26. Elementwise divides oldScale by scale, effectively removing scale from old scale. The first entry is bounded
  27. above by 1 and the second entry is bounded below by 1.
  28. :param scale: 3 member list of 2 member float lists.
  29. :param oldScale: 3 member list of 2 member float lists.
  30. :return: 3 member list of 2 member float lists.
  31. """
  32. toReturn = []
  33. for s, oldS in zip(scale, oldScale):
  34. toReturn.append([min(oldS[0] / s, 1), max(oldS[1] / s, 1)])
  35. return toReturn
  36. class IterativeRegistration(object):
  37. """
  38. This class is used to run basic Reg-MaxS algorithm.
  39. """
  40. def __init__(self, refSWC, gridSizes, rotBounds, transBounds,
  41. transMinRes, scaleMinRes, rotMinRes, nCPU):
  42. """
  43. Initialization
  44. :param refSWC: valid file path to a valid SWC file, reference SWC
  45. :param gridSizes: list of three floats, the voxel sizes over which estimations are run, in micrometer
  46. :param rotBounds: three member list of two member float lists, the bounds for rotation euler angles abour XYZ
  47. axes, in radians
  48. :param transBounds: three member list of two member float lists, the bounds for translations along XYZ axes.
  49. :param transMinRes: float, minimum resolution of exhaustive search for translation parameters in micrometer.
  50. :param scaleMinRes: float. minimum (multiplicative) resolution of exhuasitve search for scaling paramers.
  51. :param rotMinRes: float. minimum resolution of exhuastive search for rotation euler angle parameters in radians.
  52. :param nCPU: int, number of processes to use
  53. """
  54. super(IterativeRegistration, self).__init__()
  55. self.refSWC = refSWC
  56. self.gridSizes = gridSizes
  57. self.rotBounds = rotBounds
  58. self.transBounds = transBounds
  59. self.rotBounds = rotBounds
  60. self.transMinRes = transMinRes
  61. self.rotMinRes = rotMinRes
  62. self.scaleMinRes = scaleMinRes
  63. self.nCPU = nCPU
  64. self.allFuncs = {'trans': self.transOnce, 'rot': self.rotOnce, 'scale': self.scaleOnce}
  65. def rotOnce(self, SWC2Align, outFiles, ipParFile):
  66. """
  67. Runs exhaustive search to find the best rotation euler angles about XYZ axes that maximize the volume overlap
  68. between SWC2Align and self.refSWC. Results are written in the two file paths of outFiles.
  69. If solution found is no better than doing nothing or if the angles found are lower than minimum resolution,
  70. zero angles are returned with done as true
  71. :param SWC2Align: valid file path to a valid SWC file
  72. :param outFiles: list of two valid file paths. SWC2Align rotated with optimum parameters is written to
  73. outFiles[0], a log file of the process is written into ouFiles[1].
  74. :param ipParFile: valid file path. Temporary file used.
  75. :return: bestSol, bestVal, done
  76. bestSol: list of three floats, best Euler angles in radians
  77. bestVal: float, best value of dissimilarity between SWC2Align and refSWC at the lowest voxel size
  78. done: boolean, see above.
  79. """
  80. pars = [self.refSWC, SWC2Align, outFiles,
  81. self.gridSizes, self.rotBounds, self.rotMinRes, self.nCPU]
  82. with open(ipParFile, 'w') as fle:
  83. json.dump(pars, fle)
  84. f2call = os.path.join(os.path.split(__file__)[0], 'rotOnce.py')
  85. subprocess.call(['python', f2call, ipParFile])
  86. with open(outFiles[1], 'r') as fle:
  87. out = json.load(fle)
  88. bestSol = out['bestSol']
  89. done = out['done']
  90. bestVal = out['bestVal']
  91. print(bestSol, bestVal, done)
  92. return bestSol, bestVal, done
  93. def transOnce(self, SWC2Align, outFiles, ipParFile):
  94. """
  95. Runs exhaustive search to find the best translations along XYZ axes that maximize the volume overlap between
  96. SWC2Align and self.refSWC. Results are written in the two file paths of outFiles.
  97. If solution found is no better than doing nothing or if the translations found are lower than
  98. the minimum resolution, zero translations are returned with done set to true
  99. :param SWC2Align: valid file path to a valid SWC file
  100. :param outFiles: list of two valid file paths. SWC2Align translated with optimum parameters is written to
  101. outFiles[0], a log file of the process is written into ouFiles[1].
  102. :param ipParFile: valid file path. Temporary file used.
  103. :return: bestSol, bestVal, done
  104. bestSol: list of three floats, best translations in micrometer
  105. bestVal: float, best value of dissimilarity between SWC2Align and refSWC at the lowest voxel size
  106. done: boolean, see above.
  107. """
  108. pars = [self.refSWC, SWC2Align, outFiles,
  109. self.gridSizes, self.transBounds, self.transMinRes, self.nCPU]
  110. with open(ipParFile, 'w') as fle:
  111. json.dump(pars, fle)
  112. f2call = os.path.join(os.path.split(__file__)[0], 'transOnce.py')
  113. subprocess.call(['python', f2call, ipParFile])
  114. with open(outFiles[1], 'r') as fle:
  115. out = json.load(fle)
  116. bestSol = out['bestSol']
  117. done = out['done']
  118. bestVal = out['bestVal']
  119. print(bestSol, bestVal, done)
  120. return bestSol, bestVal, done
  121. def scaleOnce(self, SWC2Align, outFiles, ipParFile, scaleBounds):
  122. """
  123. Runs exhaustive search to find the best scaling parameters along XYZ axes that maximize the volume overlap
  124. between SWC2Align and self.refSWC. Results are written in the two file paths of outFiles.
  125. If solution found is no better than doing nothing or if the scaling parameters found are lower than
  126. the minimum resolution, unity scaling parameters are returned with done set to true
  127. :param SWC2Align: valid file path to a valid SWC file
  128. :param outFiles: list of two valid file paths. SWC2Align scaled with optimum parameters is written to
  129. outFiles[0], a log file of the process is written into ouFiles[1].
  130. :param ipParFile: valid file path. Temporary file used.
  131. :return: bestSol, bestVal, done
  132. bestSol: list of three floats, best scaling parameters
  133. bestVal: float, best value of dissimilarity between SWC2Align and refSWC at the lowest voxel size
  134. done: boolean, see above.
  135. """
  136. pars = [self.refSWC, SWC2Align, outFiles,
  137. self.gridSizes, scaleBounds, self.scaleMinRes, self.nCPU]
  138. with open(ipParFile, 'w') as fle:
  139. json.dump(pars, fle)
  140. f2call = os.path.join(os.path.split(__file__)[0], 'scaleOnce.py')
  141. subprocess.call(['python', f2call, ipParFile])
  142. with open(outFiles[1], 'r') as fle:
  143. out = json.load(fle)
  144. bestSol = out['bestSol']
  145. done = out['done']
  146. bestVal = out['bestVal']
  147. print(bestSol, bestVal, done)
  148. return bestSol, bestVal, done
  149. def compare(self, srts, SWC2Align, tempOutFiles, ipParFile, scaleBounds):
  150. """
  151. Runs the exhaustive searchs for the transforms in srts and returns some info about the searches
  152. :param srts: list of strings, valid entries are 'scale', 'trans' and 'rot'
  153. :param SWC2Align: valid file path of valid SWC file.
  154. :param tempOutFiles: list of two valid file paths, for temporary internal use.
  155. :param ipParFile: valid file path, for temporary internal use.
  156. :param scaleBounds: three member list of two member float lists, the bounds for scaling parameters
  157. along XYZ axes.
  158. :return: tempDones, presBestSol, presBestVal, presBestDone, presBestTrans
  159. tempDones: list of booleans, same size as srts, contains the value of 'done' of respective exhaustive
  160. searches
  161. presBestTrans: transform among srts leading to the lowest dissimilarity
  162. presBestSol: list of three floats, correspong transform parameters
  163. presBestDone: boolean, 'done' value of presBestTrans
  164. """
  165. presBestVal = 1e6
  166. presBestTrans = 'trans'
  167. presBestSol = [0, 0, 0]
  168. presBestDone = False
  169. tempDones = {}
  170. for g in srts:
  171. if g == 'scale':
  172. bestSol, bestVal, done = self.scaleOnce(SWC2Align, tempOutFiles[g], ipParFile, scaleBounds)
  173. elif g == 'rot':
  174. bestSol, bestVal, done = self.rotOnce(SWC2Align, tempOutFiles[g], ipParFile)
  175. elif g == 'trans':
  176. bestSol, bestVal, done = self.transOnce(SWC2Align, tempOutFiles[g], ipParFile)
  177. else:
  178. raise('Invalid transformation type ' + g)
  179. tempDones[g] = done
  180. if (bestVal == presBestVal and transPreference(presBestTrans, g)) or (bestVal < presBestVal):
  181. presBestTrans = g
  182. presBestVal = bestVal
  183. presBestSol = bestSol
  184. presBestDone = done
  185. return tempDones, presBestSol, presBestVal, presBestDone, presBestTrans
  186. def performReg(self, SWC2Align, resFile, scaleBounds,
  187. inPartsDir=None, outPartsDir=None,
  188. initGuessType='just_centroids',
  189. retainTempFiles=False):
  190. """
  191. Repeatedly applies translation, rotation and scaling transforms to SWC2Align to maximize its volume overlap
  192. with self.refSWC. See Reg-MaxS-N manuscript for more info.
  193. :param SWC2Align: valid file path of a valid SWC file, the SWC that is registered to self.refSWC
  194. :param resFile: valid file path, where SWC2Align registered to self.refSWC is written
  195. :param scaleBounds: three member list of two member float lists, the bounds for scaling parameters
  196. :param inPartsDir: valid directory path, any swc files with this will be transformed exactly same as
  197. SWC2Align and written in to outPartsDir
  198. :param outPartsDir: valid directory path
  199. :param initGuessType: string, valid values are 'just centroids' and 'nothing'. If 'just centroids', the
  200. centroids are initially matched, if 'nothing' they are not.
  201. :param retainTempFiles: boolean, whether to retain the intermediate files.
  202. :return: finalFile, finalSolFile
  203. finalFile: same as resFile
  204. finalSolFile: a file at <resFile name>Sol.txt where results of the process are logged.
  205. """
  206. resDir, expName = os.path.split(resFile[:-4])
  207. ipParFile = os.path.join(resDir, 'tmp.json')
  208. vals = ['trans', 'rot', 'scale']
  209. tempOutFiles = {}
  210. for val in vals:
  211. fle1 = os.path.join(resDir, val + '.swc')
  212. fle2 = os.path.join(resDir, val + 'bestSol.json')
  213. tempOutFiles[val] = [fle1, fle2]
  214. refMean = np.loadtxt(self.refSWC)[:, 2:5].mean(axis=0)
  215. iterationNo = 0
  216. tempOutPath = os.path.join(resDir, expName + 'trans')
  217. if not os.path.isdir(tempOutPath):
  218. os.mkdir(tempOutPath)
  219. SWC2AlignLocal = os.path.join(tempOutPath, str(iterationNo) + '.swc')
  220. SWC2AlignMean = np.loadtxt(SWC2Align)[:, 2:5].mean(axis=0)
  221. if initGuessType == 'just_centroids':
  222. transSWC(SWC2Align, np.eye(3), refMean - SWC2AlignMean, SWC2AlignLocal)
  223. totalTransform = np.eye(4)
  224. totalTransform[:3, 3] = -SWC2AlignMean
  225. totalTranslation = refMean
  226. elif initGuessType == 'nothing':
  227. shutil.copy(SWC2Align, SWC2AlignLocal)
  228. totalTransform = np.eye(4)
  229. totalTransform[:3, 3] = -SWC2AlignMean
  230. totalTranslation = SWC2AlignMean
  231. else:
  232. raise(ValueError('Unknown value for argument \'initGuessType\''))
  233. SWC2AlignT = SWC2AlignLocal
  234. scaleDone = False
  235. bestVals = {}
  236. while not scaleDone:
  237. done = False
  238. srts = ['rot', 'trans']
  239. while not done:
  240. tempDones, bestSol, bestVal, lDone, g = self.compare(srts, SWC2AlignT, tempOutFiles, ipParFile, None)
  241. outFile = os.path.join(tempOutPath, str(iterationNo) + g[0] + '.swc')
  242. outFileSol = os.path.join(tempOutPath, 'bestSol' + str(iterationNo) + g[0] + '.txt')
  243. shutil.copyfile(tempOutFiles[g][0], outFile)
  244. shutil.copyfile(tempOutFiles[g][1], outFileSol)
  245. with open(outFileSol, 'r') as fle:
  246. pars = json.load(fle)
  247. presTrans = np.array(pars['transMat'])
  248. if g == 'trans':
  249. totalTranslation += presTrans[:3, 3]
  250. else:
  251. totalTransform = np.dot(presTrans, totalTransform)
  252. print(str(iterationNo) + g)
  253. bestVals[bestVal] = {"outFile": outFile, "outFileSol": outFileSol,
  254. "totalTransform": totalTransform,
  255. "totalTranslation": totalTranslation,
  256. "iterationIndicator": str(iterationNo) + g
  257. }
  258. iterationNo += 1
  259. done = lDone
  260. SWC2AlignT = outFile
  261. bestSol, bestVal, sDone = self.scaleOnce(SWC2AlignT, tempOutFiles['scale'], ipParFile, scaleBounds)
  262. outFile = os.path.join(tempOutPath, str(iterationNo) + 's.swc')
  263. outFileSol = os.path.join(tempOutPath, 'bestSol' + str(iterationNo) + 's.txt')
  264. shutil.copyfile(tempOutFiles['scale'][0], outFile)
  265. shutil.copyfile(tempOutFiles['scale'][1], outFileSol)
  266. with open(outFileSol, 'r') as fle:
  267. pars = json.load(fle)
  268. presTrans = np.array(pars['transMat'])
  269. totalTransform = np.dot(presTrans, totalTransform)
  270. print(str(iterationNo) + 's')
  271. bestVals[bestVal] = {"outFile": outFile, "outFileSol": outFileSol,
  272. "totalTransform": totalTransform,
  273. "totalTranslation": totalTranslation,
  274. "iterationIndicator": str(iterationNo) + 's'
  275. }
  276. iterationNo += 1
  277. SWC2AlignT = outFile
  278. tempDones, bestSol, bestVal, lDone, g = self.compare(vals, SWC2AlignT, tempOutFiles, ipParFile, scaleBounds)
  279. scaleDone = all(tempDones.values())
  280. if not scaleDone:
  281. with open(tempOutFiles['rot'][1], 'r') as fle:
  282. pars = json.load(fle)
  283. rBestVal = pars['bestVal']
  284. with open(tempOutFiles['trans'][1], 'r') as fle:
  285. pars = json.load(fle)
  286. tBestVal = pars['bestVal']
  287. if rBestVal < tBestVal:
  288. g = 'rot'
  289. else:
  290. g = 'trans'
  291. outFile = os.path.join(tempOutPath, str(iterationNo) + g[0] + '.swc')
  292. outFileSol = os.path.join(tempOutPath, 'bestSol' + str(iterationNo) + g[0] + '.txt')
  293. shutil.copyfile(tempOutFiles[g][0], outFile)
  294. shutil.copyfile(tempOutFiles[g][1], outFileSol)
  295. with open(outFileSol, 'r') as fle:
  296. pars = json.load(fle)
  297. presTrans = np.array(pars['transMat'])
  298. if g == 'trans':
  299. totalTranslation += presTrans[:3, 3]
  300. else:
  301. totalTransform = np.dot(presTrans, totalTransform)
  302. print(str(iterationNo) + g)
  303. bestVals[bestVal] = {"outFile": outFile, "outFileSol": outFileSol,
  304. "totalTransform": totalTransform,
  305. "totalTranslation": totalTranslation,
  306. "iterationIndicator": str(iterationNo) + g
  307. }
  308. iterationNo += 1
  309. SWC2AlignT = outFile
  310. championBestVal = min(bestVals.keys())
  311. totalTransform = bestVals[championBestVal]["totalTransform"]
  312. totalTranslation = bestVals[championBestVal]["totalTranslation"]
  313. bestIterIndicator = bestVals[championBestVal]["iterationIndicator"]
  314. print("bestIter: {}, bestVal: {}".format(bestIterIndicator, championBestVal))
  315. totalTransform[:3, 3] += totalTranslation
  316. for g in vals:
  317. [os.remove(x) for x in tempOutFiles[g]]
  318. os.remove(ipParFile)
  319. if not retainTempFiles:
  320. shutil.rmtree(tempOutPath)
  321. finalFile = os.path.join(resDir, expName + '.swc')
  322. transSWC_rotAboutPoint(SWC2Align,
  323. totalTransform[:3, :3], totalTransform[:3, 3],
  324. finalFile,
  325. [0, 0, 0]
  326. )
  327. trans = SWCTranslate(self.refSWC, finalFile, self.gridSizes[-1])
  328. finalVal = objFun(([0, 0, 0], trans))
  329. finalSolFile = os.path.join(resDir, expName + 'Sol.txt')
  330. with open(finalSolFile, 'w') as fle:
  331. json.dump({'finalVal': finalVal, 'finalTransMat': totalTransform.tolist(), 'refSWC': self.refSWC,
  332. 'SWC2Align': SWC2Align, 'bestIteration': bestIterIndicator}, fle)
  333. if inPartsDir is not None:
  334. if os.path.isdir(inPartsDir):
  335. dirList = os.listdir(inPartsDir)
  336. dirList = [x for x in dirList if x.endswith('.swc')]
  337. if not os.path.isdir(outPartsDir):
  338. os.mkdir(outPartsDir)
  339. for entry in dirList:
  340. transSWC_rotAboutPoint(os.path.join(inPartsDir, entry),
  341. totalTransform[:3, :3], totalTransform[:3, 3],
  342. os.path.join(outPartsDir, entry),
  343. [0, 0, 0]
  344. )
  345. else:
  346. print('Specified partsDir {} not found'.format(inPartsDir))
  347. return finalFile, finalSolFile
  348. def composeRefSWC(alignedSWCs, newRefSWC, gridSize):
  349. """
  350. Given a list of SWCs, it constructs a fake SWC to represent the union of the volumes occupied by the SWCs.
  351. :param alignedSWCs: list of SWC files
  352. :param newRefSWC: valid file path, where the resulting SWC is written
  353. :param gridSize: float, the voxel size at which the volumes are discretized before forming the union.
  354. :return: dissim: float, 1 - (# of voxels in the intersection of volumes) / (# of voxels in the union of volumes)
  355. """
  356. indVoxs = []
  357. for aswc in alignedSWCs:
  358. aPts = np.loadtxt(aswc)[:, 2:5]
  359. aVox = np.array(np.round(aPts / gridSize), np.int32)
  360. aVoxSet = set(map(tuple, aVox))
  361. indVoxs.append(aVoxSet)
  362. # majority = [sum(x in y for y in indVoxs) >= 0.5 * len(indVoxs) for x in aUnion]
  363. # newRefVoxSet = [y for x, y in enumerate(aUnion) if majority[x]]
  364. aUnion = reduce(lambda x, y: x.union(y), indVoxs)
  365. aInt = reduce(lambda x, y: x.intersection(y), indVoxs)
  366. newRefVoxSet = aUnion
  367. newRefXYZ = np.array(list(newRefVoxSet)) * gridSize
  368. writeFakeSWC(newRefXYZ, newRefSWC)
  369. # print(len(aInt), len(aUnion))
  370. return 1 - len(aInt) / float(len(aUnion))
  371. def calcOverlap(refSWC, SWC2Align, gridSize):
  372. """
  373. Given two SWCs, it calculates a measure of dissimilarity between them using their discretized volumes.
  374. It's defined as 1 - size of intersection of the volumes / size of union of the volumes
  375. :param refSWC: valid file path to a valid SWC file
  376. :param SWC2Align: valid file path to a valid SWC file
  377. :param gridSize: float, the voxel size at which the volumes are discretized.
  378. :return: float, dissimilarity value
  379. """
  380. trans = SWCTranslate(refSWC, SWC2Align, gridSize)
  381. return objFun(([0, 0, 0], trans))
  382. def writeFakeSWC(data, fName, extraCol=None):
  383. """
  384. Forms a 7 column SWC data from the 3 column XYZ data in 'data' and writes it to a file at path fName adding
  385. a '!! Fake SWC !!' warning in the header.
  386. :param data: numpy.ndarray, 3 column XYZ data
  387. :param fName: valid file path to write the fake SWC file
  388. :param extraCol: iterable of the same size as the number of rows of data, will be added as the 8th column
  389. :return:
  390. """
  391. data = np.array(data)
  392. assert data.shape[1] == 3
  393. if extraCol is not None:
  394. extraCol = np.array(extraCol)
  395. assert extraCol.shape == (data.shape[0], ) or extraCol.shape == (data.shape[0], 1)
  396. toWrite = np.empty((data.shape[0], 8))
  397. toWrite[:, 7] = extraCol
  398. else:
  399. toWrite = np.empty((data.shape[0], 7))
  400. toWrite[:, 2:5] = data
  401. toWrite[:, 0] = range(1, data.shape[0] + 1)
  402. toWrite[:, 1] = 3
  403. toWrite[:, 5] = 1
  404. toWrite[:, 6] = -np.arange(1, data.shape[0] + 1)
  405. formatStr = '%d %d %0.6f %0.6f %0.6f %0.6f %d'
  406. if extraCol is not None:
  407. formatStr += ' %d'
  408. headr = '!! Fake SWC file !!'
  409. np.savetxt(fName, toWrite, fmt=formatStr, header=headr)