SWCTransforms.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import os
  2. import numpy as np
  3. from transforms import compose_matrix
  4. from swcFuncs import readSWC_numpy, writeSWC_numpy
  5. def three32BitInt2complexList(arr):
  6. assert arr.dtype == np.int32, 'Only 32 bit ints allowed'
  7. assert arr.shape[1] == 3, 'only three ints per row allowed'
  8. temp = np.zeros((arr.shape[0], 4), dtype=np.int32)
  9. temp[:, 1:] = arr
  10. temp1 = temp.view(np.int64)
  11. j = complex(0, 1)
  12. return (temp1[:, 0] + j * temp1[:, 1]).tolist()
  13. def objFun(y):
  14. x, data = y
  15. nrnPtsTrans = data.transform(x, data.SWC2AlignPts)
  16. nrnVox = np.array(np.round(nrnPtsTrans / data.gridSize), np.int32)
  17. nrnVoxSet = set(three32BitInt2complexList(nrnVox))
  18. nInter = len(nrnVoxSet.intersection(data.refVoxSet))
  19. nUnion = len(nrnVoxSet) + len(data.refVoxSet) - nInter
  20. sumOfDiffs = 1 - (nInter / float(nUnion))
  21. return sumOfDiffs
  22. class BaseSWCTranform(object):
  23. def __init__(self, refSWC, SWC2Align, gridSize):
  24. super(BaseSWCTranform, self).__init__()
  25. if os.path.isfile(refSWC) and refSWC.endswith('.swc'):
  26. self.refSWCPts = np.loadtxt(refSWC)[:, 2:5]
  27. elif type(refSWC) == np.ndarray:
  28. self.refSWCPts = refSWC[:, 2:5]
  29. else:
  30. raise(ValueError('Unknown data in SWC2Align'))
  31. self.refCenter = self.refSWCPts.mean(axis=0)
  32. refVox = np.array(np.round(self.refSWCPts / gridSize), np.int32)
  33. self.gridSize = gridSize
  34. self.refVoxSet = set(three32BitInt2complexList(refVox))
  35. if os.path.isfile(SWC2Align) and SWC2Align.endswith('.swc'):
  36. self.headr, self.SWC2AlignFull = readSWC_numpy(SWC2Align)
  37. elif type(SWC2Align) == np.ndarray:
  38. assert type(SWC2Align['data']) == np.ndarray, 'Unknown data in SWC2Align'
  39. self.SWC2AlignFull = SWC2Align
  40. self.headr = ''
  41. else:
  42. raise(ValueError('Unknown data in SWC2Align'))
  43. self.SWC2AlignPts = self.SWC2AlignFull[:, 2:5]
  44. self.center = self.SWC2AlignPts.mean(axis=0)
  45. def transform(self, pars, data):
  46. return data
  47. def writeSolution(self, outFile, bestSol, inFile=None):
  48. if inFile:
  49. headr, data = readSWC_numpy(inFile)
  50. else:
  51. headr, data = self.headr, self.SWC2AlignFull
  52. data[:, 2:5] = self.transform(bestSol, data[:, 2:5])
  53. writeSWC_numpy(outFile, data, headr)
  54. class SWCTranslate(BaseSWCTranform):
  55. def __init__(self, refSWC, SWC2Align, gridSize):
  56. super(SWCTranslate, self).__init__(refSWC, SWC2Align, gridSize)
  57. def transform(self, pars, data):
  58. return data + pars
  59. class SWCRotate(BaseSWCTranform):
  60. def __init__(self, refSWC, SWC2Align, gridSize):
  61. super(SWCRotate, self).__init__(refSWC, SWC2Align, gridSize)
  62. def transform(self, pars, data):
  63. rotMat = compose_matrix(angles=pars)
  64. dataCentered = data - self.center
  65. return np.dot(rotMat[:3, :3], dataCentered.T).T + self.center
  66. class SWCScale(object):
  67. def __init__(self, refSWC, SWC2Align, gridSize):
  68. super(SWCScale, self).__init__()
  69. self.gridSize = gridSize
  70. if os.path.isfile(refSWC) and refSWC.endswith('.swc'):
  71. self.refSWCPts = np.loadtxt(refSWC)[:, 2:5]
  72. elif type(refSWC) == np.ndarray:
  73. self.refSWCPts = refSWC[:, 2:5]
  74. else:
  75. raise(ValueError('Unknown data in SWC2Align'))
  76. refCenter = self.refSWCPts.mean(axis=0)
  77. refSWCPtsCentered = self.refSWCPts - refCenter
  78. refVox = np.array(np.round(refSWCPtsCentered / gridSize), np.int32)
  79. self.refVoxSet = set(three32BitInt2complexList(refVox))
  80. if os.path.isfile(SWC2Align) and SWC2Align.endswith('.swc'):
  81. self.headr, self.SWC2AlignFull = readSWC_numpy(SWC2Align)
  82. elif type(SWC2Align) == np.ndarray:
  83. assert type(SWC2Align['data']) == np.ndarray, 'Unknown data in SWC2Align'
  84. self.SWC2AlignFull = SWC2Align
  85. self.headr = ''
  86. else:
  87. raise(ValueError('Unknown data in SWC2Align'))
  88. self.SWC2AlignPts = self.SWC2AlignFull[:, 2:5].copy()
  89. self.center = self.SWC2AlignPts.mean(axis=0)
  90. self.SWC2AlignPts -= self.center
  91. def transform(self, pars, data):
  92. rotMat = compose_matrix(scale=pars)
  93. dataCentered = data
  94. return np.dot(rotMat[:3, :3], dataCentered.T).T
  95. def writeSolution(self, outFile, bestSol, inFile=None):
  96. if inFile:
  97. headr, data = readSWC_numpy(inFile)
  98. else:
  99. headr, data = self.headr, self.SWC2AlignFull
  100. data[:, 2:5] = self.transform(bestSol, data[:, 2:5] - self.center) + self.center
  101. writeSWC_numpy(outFile, data, headr)
  102. class ArgGenIterator:
  103. def __init__(self, arg1, arg2):
  104. self.arg1 = arg1
  105. self.arg2 = arg2
  106. self.pointsDone = 0
  107. def __iter__(self):
  108. self.pointsDone = 0
  109. return self
  110. def next(self):
  111. if self.pointsDone < len(self.arg1):
  112. toReturn = (self.arg1[self.pointsDone], self.arg2)
  113. self.pointsDone += 1
  114. return toReturn
  115. else:
  116. raise StopIteration