swcFuncs.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import numpy as np
  2. #***********************************************************************************************************************
  3. def readSWC_numpy(swcFile):
  4. '''
  5. Read the return the header and matrix data in a swcFile
  6. :param swcFile: filename
  7. :return: header (string), matrix data (ndarray)
  8. '''
  9. headr = ''
  10. with open(swcFile, 'r') as fle:
  11. lne = fle.readline()
  12. while lne[0] == '#':
  13. headr = headr + lne[1:]
  14. lne = fle.readline()
  15. headr = headr.rstrip('\n')
  16. swcData = np.loadtxt(swcFile)
  17. return headr, swcData
  18. #***********************************************************************************************************************
  19. def writeSWC_numpy(fName, swcData, headr=''):
  20. '''
  21. Write the SWC data in swcData to the file fName with the header headr
  22. :param fName: string
  23. :param swcData: 2D numpy.ndarray with 7 or 8 columns
  24. :param headr: string
  25. :return:
  26. '''
  27. swcData = np.array(swcData)
  28. assert swcData.shape[1] in [7, 8], 'Width given SWC Matrix data is incompatible.'
  29. formatStr = '%d %d %0.6f %0.6f %0.6f %0.6f %d'
  30. if swcData.shape[1] == 8:
  31. formatStr += ' %0.6f'
  32. np.savetxt(fName, swcData, formatStr, header=headr, comments='#')
  33. #***********************************************************************************************************************
  34. def transSWC(fName, A, b, destFle):
  35. '''
  36. Generate an SWC file at destFle with each point `x' of the morphology in fName transformed Affinely as Ax+b
  37. :param fName: string
  38. :param A: 2D numpy.ndarray of shape (3, 3)
  39. :param b: 3 member iterable
  40. :param destFle: string
  41. :return:
  42. '''
  43. headr = ''
  44. with open(fName, 'r') as fle:
  45. lne = fle.readline()
  46. while lne[0] == '#':
  47. headr = headr + lne[1:]
  48. lne = fle.readline()
  49. data = np.loadtxt(fName)
  50. data[:, 2:5] = np.dot(A, data[:, 2:5].T).T + np.array(b)
  51. if data.shape[1] == 7:
  52. formatStr = '%d %d %0.3f %0.3f %0.3f %0.3f %d'
  53. elif data.shape[1] == 8:
  54. formatStr = '%d %d %0.3f %0.3f %0.3f %0.3f %d %d'
  55. else:
  56. raise(TypeError('Data in the input file is of unknown format.'))
  57. np.savetxt(destFle, data, header=headr, fmt=formatStr)
  58. #***********************************************************************************************************************
  59. def transSWC_rotAboutPoint(fName, A, b, destFle, point):
  60. '''
  61. Generate an SWC file at destFle with each point `x' of the morphology in fName transformed Affinely as A(x-mu)+b
  62. where mu is a specified point.
  63. Essentially, the morphology is centered at a specified point before being Affinely transformed.
  64. :param fName: string
  65. :param A: 2D numpy.ndarray of shape (3, 3)
  66. :param b: 3 member iterable
  67. :param destFle: string
  68. :param point: 3 member iterable
  69. :return:
  70. '''
  71. headr = ''
  72. with open(fName, 'r') as fle:
  73. lne = fle.readline()
  74. while lne[0] == '#':
  75. headr = headr + lne[1:]
  76. lne = fle.readline()
  77. data = np.loadtxt(fName)
  78. pts = data[:, 2:5]
  79. rotAbout = np.array(point)
  80. ptsCentered = pts - rotAbout
  81. data[:, 2:5] = np.dot(A, ptsCentered.T).T + np.array(b) + rotAbout
  82. if data.shape[1] == 7:
  83. formatStr = '%d %d %0.3f %0.3f %0.3f %0.3f %d'
  84. elif data.shape[1] == 8:
  85. formatStr = '%d %d %0.3f %0.3f %0.3f %0.3f %d %d'
  86. else:
  87. raise(TypeError('Data in the input file is of unknown format.'))
  88. np.savetxt(destFle, data, header=headr, fmt=formatStr)
  89. #***********************************************************************************************************************
  90. def getPCADetails(swcFileName, center=True, data=None):
  91. '''
  92. Returns the principal components and standard deviations along the principal components.
  93. Ref: http://arxiv.org/pdf/1404.1100.pdf
  94. :param swcFileName: sting, input SWC file name
  95. :param center: Boolean, if True, data is centered before calculating PCA
  96. :param data: numpy ndarray of shape [<>, 3]
  97. :return: PC, STDs
  98. PC: numpy.ndarray with the prinicial components of the data in its columns
  99. STDs: 3 member float iterable containing the standard variances of the data along the prinicipal components.
  100. '''
  101. if data is None:
  102. data = np.loadtxt(swcFileName)[:, 2:5]
  103. if center:
  104. mu = np.mean(data, axis=0)
  105. data = data - mu
  106. U, S, V = np.linalg.svd(data.T)
  107. dataProj = np.dot(U, data.T).T
  108. newStds = np.std(dataProj, axis=0)
  109. return U.T, newStds
  110. #***********************************************************************************************************************
  111. # **********************************************************************************************************************
  112. def digitizeSWCXYZ(swcXYZ, gridUnitSizes):
  113. digSWCXYZ = np.empty_like(swcXYZ, dtype=np.intp)
  114. for ind in range(3):
  115. digSWCXYZ[:, ind] = np.array(np.around((swcXYZ[:, ind]) / gridUnitSizes[ind]), np.int)
  116. return digSWCXYZ
  117. # **********************************************************************************************************************
  118. def resampleSWC(swcFile, resampleLength, mask=None, swcData=None):
  119. '''
  120. Resample the SWC points to place points at every resamplelength along the central line of every segment. Radii are interpolated.
  121. :param swcData: nx4 swc point data
  122. :param resampleLength: length at with resampling is done.
  123. :return: totlLen, ndarray of shape (#pts, 4) with each row containing XYZR values
  124. '''
  125. if swcData is None:
  126. swcData = np.loadtxt(swcFile)
  127. inds = swcData[:, 0].tolist()
  128. if mask is None:
  129. mask = [True] * swcData.shape[0]
  130. else:
  131. assert len(mask) == swcData.shape[0], 'Supplied mask is invalid for ' + swcFile
  132. resampledSWCData = []
  133. getSegLen = lambda a, b: np.linalg.norm(a - b)
  134. totalLen = 0
  135. for pt in swcData:
  136. if pt[6] < 0:
  137. if mask[inds.index(int(pt[0]))]:
  138. resampledSWCData.append(pt[2:6])
  139. if (pt[6] > 0) and (int(pt[6]) in inds):
  140. if mask[inds.index(int(pt[0]))]:
  141. resampledSWCData.append(pt[2:6])
  142. parentPt = swcData[inds.index(pt[6]), :]
  143. segLen = getSegLen(pt[2:5], parentPt[2:5])
  144. totalLen += segLen
  145. if segLen > resampleLength:
  146. temp = parentPt[2:5] - pt[2:5]
  147. distTemp = np.linalg.norm(temp)
  148. unitDirection = temp / distTemp
  149. radGrad = (pt[5] - parentPt[5]) / distTemp
  150. for newPtsInd in range(1, int(np.floor(segLen / resampleLength)) + 1):
  151. temp = (pt[2:5] + newPtsInd * resampleLength * unitDirection).tolist()
  152. temp.append(pt[5] + newPtsInd * radGrad * resampleLength)
  153. resampledSWCData.append(temp)
  154. return totalLen, np.array(resampledSWCData)
  155. # **********************************************************************************************************************