iterativeRegistration.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. import os
  2. import numpy as np
  3. from .swcFuncs import transSWC, transSWC_rotAboutPoint
  4. from .SWCTransforms import SWCTranslate, objFun
  5. import shutil
  6. import json
  7. import subprocess
  8. from functools import reduce
  9. import pathlib as pl
  10. def transPreference(x, y):
  11. """
  12. Given two transforms x and y, returns if x is preferred over y. Preferences are scaling > translation,
  13. scaling > rotation, translation > rotation.
  14. :param x:
  15. :param y:
  16. :return:
  17. """
  18. if x == 'scale':
  19. return 0
  20. elif y == 'scale':
  21. return 1
  22. elif x == 'rot':
  23. return 0
  24. else:
  25. return 1
  26. def getRemainderScale(scale, oldScale):
  27. """
  28. Elementwise divides oldScale by scale, effectively removing scale from old scale. The first entry is bounded
  29. above by 1 and the second entry is bounded below by 1.
  30. :param scale: 3 member list of 2 member float lists.
  31. :param oldScale: 3 member list of 2 member float lists.
  32. :return: 3 member list of 2 member float lists.
  33. """
  34. toReturn = []
  35. for s, oldS in zip(scale, oldScale):
  36. toReturn.append([min(oldS[0] / s, 1), max(oldS[1] / s, 1)])
  37. return toReturn
  38. class IterativeRegistration(object):
  39. """
  40. This class is used to run basic Reg-MaxS algorithm.
  41. """
  42. def __init__(self, refSWC, gridSizes, rotBounds, transBounds,
  43. transMinRes, scaleMinRes, rotMinRes, nCPU):
  44. """
  45. Initialization
  46. :param refSWC: valid file path to a valid SWC file, reference SWC
  47. :param gridSizes: list of three floats, the voxel sizes over which estimations are run, in micrometer
  48. :param rotBounds: three member list of two member float lists, the bounds for rotation euler angles abour XYZ
  49. axes, in radians
  50. :param transBounds: three member list of two member float lists, the bounds for translations along XYZ axes.
  51. :param transMinRes: float, minimum resolution of exhaustive search for translation parameters in micrometer.
  52. :param scaleMinRes: float. minimum (multiplicative) resolution of exhuasitve search for scaling paramers.
  53. :param rotMinRes: float. minimum resolution of exhuastive search for rotation euler angle parameters in radians.
  54. :param nCPU: int, number of processes to use
  55. """
  56. super(IterativeRegistration, self).__init__()
  57. self.refSWC = refSWC
  58. self.gridSizes = gridSizes
  59. self.rotBounds = rotBounds
  60. self.transBounds = transBounds
  61. self.rotBounds = rotBounds
  62. self.transMinRes = transMinRes
  63. self.rotMinRes = rotMinRes
  64. self.scaleMinRes = scaleMinRes
  65. self.nCPU = nCPU
  66. self.allFuncs = {'trans': self.transOnce, 'rot': self.rotOnce, 'scale': self.scaleOnce}
  67. def rotOnce(self, SWC2Align, outFiles, ipParFile):
  68. """
  69. Runs exhaustive search to find the best rotation euler angles about XYZ axes that maximize the volume overlap
  70. between SWC2Align and self.refSWC. Results are written in the two file paths of outFiles.
  71. If solution found is no better than doing nothing or if the angles found are lower than minimum resolution,
  72. zero angles are returned with done as true
  73. :param SWC2Align: valid file path to a valid SWC file
  74. :param outFiles: list of two valid file paths. SWC2Align rotated with optimum parameters is written to
  75. outFiles[0], a log file of the process is written into ouFiles[1].
  76. :param ipParFile: valid file path. Temporary file used.
  77. :return: bestSol, bestVal, done
  78. bestSol: list of three floats, best Euler angles in radians
  79. bestVal: float, best value of dissimilarity between SWC2Align and refSWC at the lowest voxel size
  80. done: boolean, see above.
  81. """
  82. pars = [self.refSWC, SWC2Align, outFiles,
  83. self.gridSizes, self.rotBounds, self.rotMinRes, self.nCPU]
  84. with open(ipParFile, 'w') as fle:
  85. json.dump(pars, fle)
  86. f2call = os.path.join(os.path.split(__file__)[0], 'rotOnce.py')
  87. subprocess.call(['python', f2call, ipParFile])
  88. with open(outFiles[1], 'r') as fle:
  89. out = json.load(fle)
  90. bestSol = out['bestSol']
  91. done = out['done']
  92. bestVal = out['bestVal']
  93. print((bestSol, bestVal, done))
  94. return bestSol, bestVal, done
  95. def transOnce(self, SWC2Align, outFiles, ipParFile):
  96. """
  97. Runs exhaustive search to find the best translations along XYZ axes that maximize the volume overlap between
  98. SWC2Align and self.refSWC. Results are written in the two file paths of outFiles.
  99. If solution found is no better than doing nothing or if the translations found are lower than
  100. the minimum resolution, zero translations are returned with done set to true
  101. :param SWC2Align: valid file path to a valid SWC file
  102. :param outFiles: list of two valid file paths. SWC2Align translated with optimum parameters is written to
  103. outFiles[0], a log file of the process is written into ouFiles[1].
  104. :param ipParFile: valid file path. Temporary file used.
  105. :return: bestSol, bestVal, done
  106. bestSol: list of three floats, best translations in micrometer
  107. bestVal: float, best value of dissimilarity between SWC2Align and refSWC at the lowest voxel size
  108. done: boolean, see above.
  109. """
  110. pars = [self.refSWC, SWC2Align, outFiles,
  111. self.gridSizes, self.transBounds, self.transMinRes, self.nCPU]
  112. with open(ipParFile, 'w') as fle:
  113. json.dump(pars, fle)
  114. f2call = os.path.join(os.path.split(__file__)[0], 'transOnce.py')
  115. subprocess.call(['python', f2call, ipParFile])
  116. with open(outFiles[1], 'r') as fle:
  117. out = json.load(fle)
  118. bestSol = out['bestSol']
  119. done = out['done']
  120. bestVal = out['bestVal']
  121. print((bestSol, bestVal, done))
  122. return bestSol, bestVal, done
  123. def scaleOnce(self, SWC2Align, outFiles, ipParFile, scaleBounds):
  124. """
  125. Runs exhaustive search to find the best scaling parameters along XYZ axes that maximize the volume overlap
  126. between SWC2Align and self.refSWC. Results are written in the two file paths of outFiles.
  127. If solution found is no better than doing nothing or if the scaling parameters found are lower than
  128. the minimum resolution, unity scaling parameters are returned with done set to true
  129. :param SWC2Align: valid file path to a valid SWC file
  130. :param outFiles: list of two valid file paths. SWC2Align scaled with optimum parameters is written to
  131. outFiles[0], a log file of the process is written into ouFiles[1].
  132. :param ipParFile: valid file path. Temporary file used.
  133. :return: bestSol, bestVal, done
  134. bestSol: list of three floats, best scaling parameters
  135. bestVal: float, best value of dissimilarity between SWC2Align and refSWC at the lowest voxel size
  136. done: boolean, see above.
  137. """
  138. pars = [self.refSWC, SWC2Align, outFiles,
  139. self.gridSizes, scaleBounds, self.scaleMinRes, self.nCPU]
  140. with open(ipParFile, 'w') as fle:
  141. json.dump(pars, fle)
  142. f2call = os.path.join(os.path.split(__file__)[0], 'scaleOnce.py')
  143. subprocess.call(['python', f2call, ipParFile])
  144. with open(outFiles[1], 'r') as fle:
  145. out = json.load(fle)
  146. bestSol = out['bestSol']
  147. done = out['done']
  148. bestVal = out['bestVal']
  149. print((bestSol, bestVal, done))
  150. return bestSol, bestVal, done
  151. def compare(self, srts, SWC2Align, tempOutFiles, ipParFile, scaleBounds):
  152. """
  153. Runs the exhaustive searchs for the transforms in srts and returns some info about the searches
  154. :param srts: list of strings, valid entries are 'scale', 'trans' and 'rot'
  155. :param SWC2Align: valid file path of valid SWC file.
  156. :param tempOutFiles: list of two valid file paths, for temporary internal use.
  157. :param ipParFile: valid file path, for temporary internal use.
  158. :param scaleBounds: three member list of two member float lists, the bounds for scaling parameters
  159. along XYZ axes.
  160. :return: tempDones, presBestSol, presBestVal, presBestDone, presBestTrans
  161. tempDones: list of booleans, same size as srts, contains the value of 'done' of respective exhaustive
  162. searches
  163. presBestTrans: transform among srts leading to the lowest dissimilarity
  164. presBestSol: list of three floats, correspong transform parameters
  165. presBestDone: boolean, 'done' value of presBestTrans
  166. """
  167. presBestVal = 1e6
  168. presBestTrans = 'trans'
  169. presBestSol = [0, 0, 0]
  170. presBestDone = False
  171. tempDones = {}
  172. for g in srts:
  173. if g == 'scale':
  174. bestSol, bestVal, done = self.scaleOnce(SWC2Align, tempOutFiles[g], ipParFile, scaleBounds)
  175. elif g == 'rot':
  176. bestSol, bestVal, done = self.rotOnce(SWC2Align, tempOutFiles[g], ipParFile)
  177. elif g == 'trans':
  178. bestSol, bestVal, done = self.transOnce(SWC2Align, tempOutFiles[g], ipParFile)
  179. else:
  180. raise ValueError(f'Invalid transformation type {g}')
  181. tempDones[g] = done
  182. if (bestVal == presBestVal and transPreference(presBestTrans, g)) or (bestVal < presBestVal):
  183. presBestTrans = g
  184. presBestVal = bestVal
  185. presBestSol = bestSol
  186. presBestDone = done
  187. return tempDones, presBestSol, presBestVal, presBestDone, presBestTrans
  188. def performReg(self, SWC2Align, resFile, scaleBounds,
  189. inPartsDir=None, outPartsDir=None,
  190. initGuessType='just_centroids',
  191. retainTempFiles=False):
  192. """
  193. Repeatedly applies translation, rotation and scaling transforms to SWC2Align to maximize its volume overlap
  194. with self.refSWC. See Reg-MaxS-N manuscript for more info.
  195. :param SWC2Align: valid file path of a valid SWC file, the SWC that is registered to self.refSWC
  196. :param resFile: valid file path, where SWC2Align registered to self.refSWC is written
  197. :param scaleBounds: three member list of two member float lists, the bounds for scaling parameters
  198. :param inPartsDir: valid directory path, any swc files with this will be transformed exactly same as
  199. SWC2Align and written in to outPartsDir
  200. :param outPartsDir: valid directory path
  201. :param initGuessType: string, valid values are 'just centroids' and 'nothing'. If 'just centroids', the
  202. centroids are initially matched, if 'nothing' they are not.
  203. :param retainTempFiles: boolean, whether to retain the intermediate files.
  204. :return: finalFile, finalSolFile
  205. finalFile: same as resFile
  206. finalSolFile: a file at <resFile name>Sol.txt where results of the process are logged.
  207. """
  208. resDir, expName = os.path.split(resFile[:-4])
  209. pl.Path(resDir).mkdir(exist_ok=True)
  210. ipParFile = os.path.join(resDir, 'tmp.json')
  211. vals = ['trans', 'rot', 'scale']
  212. tempOutFiles = {}
  213. for val in vals:
  214. fle1 = os.path.join(resDir, val + '.swc')
  215. fle2 = os.path.join(resDir, val + 'bestSol.json')
  216. tempOutFiles[val] = [fle1, fle2]
  217. refMean = np.loadtxt(self.refSWC)[:, 2:5].mean(axis=0)
  218. iterationNo = 0
  219. tempOutPath = os.path.join(resDir, expName + 'trans')
  220. if not os.path.isdir(tempOutPath):
  221. os.mkdir(tempOutPath)
  222. SWC2AlignLocal = os.path.join(tempOutPath, str(iterationNo) + '.swc')
  223. SWC2AlignMean = np.loadtxt(SWC2Align)[:, 2:5].mean(axis=0)
  224. if initGuessType == 'just_centroids':
  225. transSWC(SWC2Align, np.eye(3), refMean - SWC2AlignMean, SWC2AlignLocal)
  226. totalTransform = np.eye(4)
  227. totalTransform[:3, 3] = -SWC2AlignMean
  228. totalTranslation = refMean
  229. elif initGuessType == 'nothing':
  230. shutil.copy(SWC2Align, SWC2AlignLocal)
  231. totalTransform = np.eye(4)
  232. totalTransform[:3, 3] = -SWC2AlignMean
  233. totalTranslation = SWC2AlignMean
  234. else:
  235. raise ValueError('Unknown value for argument \'initGuessType\'')
  236. SWC2AlignT = SWC2AlignLocal
  237. scaleDone = False
  238. bestVals = {}
  239. while not scaleDone:
  240. done = False
  241. srts = ['rot', 'trans']
  242. while not done:
  243. tempDones, bestSol, bestVal, lDone, g = self.compare(srts, SWC2AlignT, tempOutFiles, ipParFile, None)
  244. outFile = os.path.join(tempOutPath, str(iterationNo) + g[0] + '.swc')
  245. outFileSol = os.path.join(tempOutPath, 'bestSol' + str(iterationNo) + g[0] + '.txt')
  246. shutil.copyfile(tempOutFiles[g][0], outFile)
  247. shutil.copyfile(tempOutFiles[g][1], outFileSol)
  248. with open(outFileSol, 'r') as fle:
  249. pars = json.load(fle)
  250. presTrans = np.array(pars['transMat'])
  251. if g == 'trans':
  252. totalTranslation += presTrans[:3, 3]
  253. else:
  254. totalTransform = np.dot(presTrans, totalTransform)
  255. print((str(iterationNo) + g))
  256. bestVals[bestVal] = {"outFile": outFile, "outFileSol": outFileSol,
  257. "totalTransform": totalTransform,
  258. "totalTranslation": totalTranslation,
  259. "iterationIndicator": str(iterationNo) + g
  260. }
  261. iterationNo += 1
  262. done = lDone
  263. SWC2AlignT = outFile
  264. bestSol, bestVal, sDone = self.scaleOnce(SWC2AlignT, tempOutFiles['scale'], ipParFile, scaleBounds)
  265. outFile = os.path.join(tempOutPath, str(iterationNo) + 's.swc')
  266. outFileSol = os.path.join(tempOutPath, 'bestSol' + str(iterationNo) + 's.txt')
  267. shutil.copyfile(tempOutFiles['scale'][0], outFile)
  268. shutil.copyfile(tempOutFiles['scale'][1], outFileSol)
  269. with open(outFileSol, 'r') as fle:
  270. pars = json.load(fle)
  271. presTrans = np.array(pars['transMat'])
  272. totalTransform = np.dot(presTrans, totalTransform)
  273. print((str(iterationNo) + 's'))
  274. bestVals[bestVal] = {"outFile": outFile, "outFileSol": outFileSol,
  275. "totalTransform": totalTransform,
  276. "totalTranslation": totalTranslation,
  277. "iterationIndicator": str(iterationNo) + 's'
  278. }
  279. iterationNo += 1
  280. SWC2AlignT = outFile
  281. tempDones, bestSol, bestVal, lDone, g = self.compare(vals, SWC2AlignT, tempOutFiles, ipParFile, scaleBounds)
  282. scaleDone = all(tempDones.values())
  283. if not scaleDone:
  284. with open(tempOutFiles['rot'][1], 'r') as fle:
  285. pars = json.load(fle)
  286. rBestVal = pars['bestVal']
  287. with open(tempOutFiles['trans'][1], 'r') as fle:
  288. pars = json.load(fle)
  289. tBestVal = pars['bestVal']
  290. if rBestVal < tBestVal:
  291. g = 'rot'
  292. else:
  293. g = 'trans'
  294. outFile = os.path.join(tempOutPath, str(iterationNo) + g[0] + '.swc')
  295. outFileSol = os.path.join(tempOutPath, 'bestSol' + str(iterationNo) + g[0] + '.txt')
  296. shutil.copyfile(tempOutFiles[g][0], outFile)
  297. shutil.copyfile(tempOutFiles[g][1], outFileSol)
  298. with open(outFileSol, 'r') as fle:
  299. pars = json.load(fle)
  300. presTrans = np.array(pars['transMat'])
  301. if g == 'trans':
  302. totalTranslation += presTrans[:3, 3]
  303. else:
  304. totalTransform = np.dot(presTrans, totalTransform)
  305. print((str(iterationNo) + g))
  306. bestVals[bestVal] = {"outFile": outFile, "outFileSol": outFileSol,
  307. "totalTransform": totalTransform,
  308. "totalTranslation": totalTranslation,
  309. "iterationIndicator": str(iterationNo) + g
  310. }
  311. iterationNo += 1
  312. SWC2AlignT = outFile
  313. championBestVal = min(bestVals.keys())
  314. totalTransform = bestVals[championBestVal]["totalTransform"]
  315. totalTranslation = bestVals[championBestVal]["totalTranslation"]
  316. bestIterIndicator = bestVals[championBestVal]["iterationIndicator"]
  317. print(("bestIter: {}, bestVal: {}".format(bestIterIndicator, championBestVal)))
  318. totalTransform[:3, 3] += totalTranslation
  319. for g in vals:
  320. [os.remove(x) for x in tempOutFiles[g]]
  321. os.remove(ipParFile)
  322. if not retainTempFiles:
  323. shutil.rmtree(tempOutPath)
  324. finalFile = os.path.join(resDir, expName + '.swc')
  325. transSWC_rotAboutPoint(SWC2Align,
  326. totalTransform[:3, :3], totalTransform[:3, 3],
  327. finalFile,
  328. [0, 0, 0]
  329. )
  330. trans = SWCTranslate(self.refSWC, finalFile, self.gridSizes[-1])
  331. finalVal = objFun(([0, 0, 0], trans))
  332. finalSolFile = os.path.join(resDir, expName + 'Sol.txt')
  333. with open(finalSolFile, 'w') as fle:
  334. json.dump({'finalVal': finalVal, 'finalTransMat': totalTransform.tolist(), 'refSWC': self.refSWC,
  335. 'SWC2Align': SWC2Align, 'bestIteration': bestIterIndicator}, fle)
  336. if inPartsDir is not None:
  337. if os.path.isdir(inPartsDir):
  338. dirList = os.listdir(inPartsDir)
  339. dirList = [x for x in dirList if x.endswith('.swc')]
  340. if not os.path.isdir(outPartsDir):
  341. os.mkdir(outPartsDir)
  342. for entry in dirList:
  343. transSWC_rotAboutPoint(os.path.join(inPartsDir, entry),
  344. totalTransform[:3, :3], totalTransform[:3, 3],
  345. os.path.join(outPartsDir, entry),
  346. [0, 0, 0]
  347. )
  348. else:
  349. print(('Specified partsDir {} not found'.format(inPartsDir)))
  350. return finalFile, finalSolFile
  351. def composeRefSWC(alignedSWCs, newRefSWC, gridSize):
  352. """
  353. Given a list of SWCs, it constructs a fake SWC to represent the union of the volumes occupied by the SWCs.
  354. :param alignedSWCs: list of SWC files
  355. :param newRefSWC: valid file path, where the resulting SWC is written
  356. :param gridSize: float, the voxel size at which the volumes are discretized before forming the union.
  357. :return: dissim: float, 1 - (# of voxels in the intersection of volumes) / (# of voxels in the union of volumes)
  358. """
  359. indVoxs = []
  360. for aswc in alignedSWCs:
  361. aPts = np.loadtxt(aswc)[:, 2:5]
  362. aVox = np.array(np.round(aPts / gridSize), np.int32)
  363. aVoxSet = set(map(tuple, aVox))
  364. indVoxs.append(aVoxSet)
  365. # majority = [sum(x in y for y in indVoxs) >= 0.5 * len(indVoxs) for x in aUnion]
  366. # newRefVoxSet = [y for x, y in enumerate(aUnion) if majority[x]]
  367. aUnion = reduce(lambda x, y: x.union(y), indVoxs)
  368. aInt = reduce(lambda x, y: x.intersection(y), indVoxs)
  369. newRefVoxSet = aUnion
  370. newRefXYZ = np.array(list(newRefVoxSet)) * gridSize
  371. writeFakeSWC(newRefXYZ, newRefSWC)
  372. # print(len(aInt), len(aUnion))
  373. return 1 - len(aInt) / float(len(aUnion))
  374. def calcOverlap(refSWC, SWC2Align, gridSize):
  375. """
  376. Given two SWCs, it calculates a measure of dissimilarity between them using their discretized volumes.
  377. It's defined as 1 - size of intersection of the volumes / size of union of the volumes
  378. :param refSWC: valid file path to a valid SWC file
  379. :param SWC2Align: valid file path to a valid SWC file
  380. :param gridSize: float, the voxel size at which the volumes are discretized.
  381. :return: float, dissimilarity value
  382. """
  383. trans = SWCTranslate(refSWC, SWC2Align, gridSize)
  384. return objFun(([0, 0, 0], trans))
  385. def writeFakeSWC(data, fName, extraCol=None):
  386. """
  387. Forms a 7 column SWC data from the 3 column XYZ data in 'data' and writes it to a file at path fName adding
  388. a '!! Fake SWC !!' warning in the header.
  389. :param data: numpy.ndarray, 3 column XYZ data
  390. :param fName: valid file path to write the fake SWC file
  391. :param extraCol: iterable of the same size as the number of rows of data, will be added as the 8th column
  392. :return:
  393. """
  394. data = np.array(data)
  395. assert data.shape[1] == 3
  396. if extraCol is not None:
  397. extraCol = np.array(extraCol)
  398. assert extraCol.shape == (data.shape[0], ) or extraCol.shape == (data.shape[0], 1)
  399. toWrite = np.empty((data.shape[0], 8))
  400. toWrite[:, 7] = extraCol
  401. else:
  402. toWrite = np.empty((data.shape[0], 7))
  403. toWrite[:, 2:5] = data
  404. toWrite[:, 0] = list(range(1, data.shape[0] + 1))
  405. toWrite[:, 1] = 3
  406. toWrite[:, 5] = 1
  407. toWrite[:, 6] = -np.arange(1, data.shape[0] + 1)
  408. formatStr = '%d %d %0.6f %0.6f %0.6f %0.6f %d'
  409. if extraCol is not None:
  410. formatStr += ' %d'
  411. headr = '!! Fake SWC file !!'
  412. np.savetxt(fName, toWrite, fmt=formatStr, header=headr)