plotDensities.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. from .swcFuncs import resampleSWC, digitizeSWCXYZ, getPCADetails, readSWC_numpy, writeSWC_numpy
  2. import numpy as np
  3. import os
  4. from scipy.ndimage import gaussian_filter
  5. import tifffile
  6. class DensityVizualizations(object):
  7. def __init__(self, swcSet, gridUnitSizes, resampleLen,
  8. masks=None, pcaView=None, refSWC=None, initTrans=None):
  9. if initTrans is None:
  10. initTrans = np.eye(3)
  11. self.gridUnitSizes = np.array(gridUnitSizes)
  12. if masks is None:
  13. masks = [None] * len(swcSet)
  14. else:
  15. assert len(masks) == len(swcSet), 'Improper value for masks1'
  16. minXYZs = []
  17. maxXYZs = []
  18. self.transMat = np.eye(4)
  19. datas = {}
  20. for swcInd, swcFile in enumerate(swcSet):
  21. print('Resamping ' + swcFile)
  22. totalLen, data = resampleSWC(swcFile, resampleLen, mask=masks[swcInd])
  23. dataT = np.dot(initTrans, data[:, :3].T).T
  24. datas[swcFile] = dataT
  25. self.transMat[:3, :3] = initTrans
  26. allData = np.concatenate(tuple(datas.itervalues()), axis=0)
  27. self.allDataMean = allData.mean(axis=0)
  28. if pcaView == 'closestPCMatch':
  29. evecs, newStds = getPCADetails(None, center=True, data=allData)
  30. mean2Use = self.allDataMean
  31. if refSWC:
  32. refEvecs, thrash = getPCADetails(refSWC, center=True)
  33. fEvecs = np.empty_like(refEvecs)
  34. coreff = np.dot(refEvecs.T, evecs)
  35. possInds = range(refEvecs.shape[1])
  36. for rowInd in range(refEvecs.shape[1]):
  37. bestCorrInd = np.argmax(np.abs(coreff[rowInd, possInds]))
  38. fEvecs[:, rowInd] = np.sign(coreff[rowInd, possInds[bestCorrInd]]) * evecs[:, possInds[bestCorrInd]]
  39. possInds.pop(int(bestCorrInd))
  40. else:
  41. fEvecs = evecs
  42. elif pcaView == 'assumeRegistered':
  43. if refSWC:
  44. refEvecs, thrash = getPCADetails(refSWC, center=True)
  45. fEvecs = refEvecs
  46. mean2Use = np.loadtxt(refSWC)[:, 2:5].mean(axis=0)
  47. else:
  48. raise(ValueError('RefSWC must be specified when pcaView == \'assumeRegistered\''))
  49. else:
  50. fEvecs = np.eye(3)
  51. mean2Use = self.allDataMean
  52. self.digDatas = {}
  53. for swcFile, data in datas.iteritems():
  54. print('Digitizing ' + swcFile)
  55. data -= mean2Use
  56. data = np.dot(fEvecs.T, data.T).T
  57. digData = digitizeSWCXYZ(data + mean2Use, gridUnitSizes)
  58. self.digDatas[swcFile] = digData
  59. minXYZs.append(digData[:, :3].min(axis=0))
  60. maxXYZs.append(digData[:, :3].max(axis=0))
  61. temp = np.eye(4)
  62. temp[:3, 3] = -mean2Use
  63. self.transMat = np.dot(temp, self.transMat)
  64. temp = np.eye(4)
  65. temp[:3, :3] = fEvecs.T
  66. self.transMat = np.dot(temp, self.transMat)
  67. self.transMat[:3, 3] += mean2Use
  68. self.minXYZ = np.array(minXYZs).min(axis=0) - 20
  69. self.maxXYZ = np.array(maxXYZs).max(axis=0) + 20
  70. self.bins = [np.arange(x, y + 1) * z for x, y, z in zip(self.minXYZ, self.maxXYZ, self.gridUnitSizes)]
  71. def calculateDensity(self, swcFiles, sigmas):
  72. assert all(np.greater_equal(sigmas, self.gridUnitSizes)), 'sigma along every dimenstion must be larger than gridUnitSize'
  73. digSigs = np.around(np.array(sigmas) / self.gridUnitSizes)
  74. densityMatSum = np.zeros(tuple((self.maxXYZ - self.minXYZ).tolist()))
  75. for swcFile in swcFiles:
  76. print('Calculating Density for ' + swcFile)
  77. densityMat = np.zeros_like(densityMatSum)
  78. print('Doing ' + os.path.split(swcFile)[1])
  79. if swcFile in self.digDatas:
  80. digDataTranslated = self.digDatas[swcFile][:, :3] - self.minXYZ
  81. else:
  82. raise(ValueError(swcFile + ' not initialized in constructing DensityVizualizations object'))
  83. densityMat[digDataTranslated[:, 0], digDataTranslated[:, 1], digDataTranslated[:, 2]] = 1
  84. densityMatSum += densityMat
  85. del densityMat
  86. densityMatSum /= float(len(swcFiles))
  87. smoothDensityMat = gaussian_filter(densityMatSum, sigma=digSigs, truncate=3)
  88. del densityMatSum
  89. smoothDensityMat *= (2 ** 1.5) * digSigs.prod()
  90. smoothDensityMat[smoothDensityMat > 1] = 1
  91. smoothDensityMat[smoothDensityMat < 0] = 0
  92. return smoothDensityMat, self.bins
  93. def generateDensityColoredSSWC(self, swcFiles, outFiles, density=None, sigmas=None):
  94. if density is None:
  95. density, bins = self.calculateDensity(swcFiles, sigmas)
  96. for swcInd, swcFile in enumerate(swcFiles):
  97. headr, data = readSWC_numpy(swcFile)
  98. dataXYZ = data[:, 2:5]
  99. rotData = np.dot(self.transMat[:3, :3], dataXYZ.T).T + self.transMat[:3, 3]
  100. digData = digitizeSWCXYZ(rotData, self.gridUnitSizes)
  101. digDataTranslated = digData - self.minXYZ
  102. colorInds = [density[x[0], x[1], x[2]] for x in digDataTranslated]
  103. colorInds = np.reshape(colorInds, (len(colorInds), 1))
  104. colorInds[colorInds > 1] = 1
  105. colorInds[colorInds < 0] = 0
  106. # data[:, 2:5] = rotData
  107. outData = np.concatenate((data, colorInds), axis=1)
  108. writeSWC_numpy(outFiles[swcInd], outData, headr)
  109. def writeTIFF(density, outFile):
  110. assert type(density) is np.ndarray, 'density must be a numpy ndarray'
  111. assert len(density.shape) == 3, 'density must a 3d numpy array'
  112. densityUInt8 = np.array(density * 255, dtype=np.uint8)
  113. tiffArgs = {
  114. 'compress': 0,
  115. # 'planarconfig': 'planar',
  116. 'photometric': 'minisblack'
  117. }
  118. tifffile.imsave(outFile + '.tiff', densityUInt8, **tiffArgs)