saveAverageDensity.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import os
  2. from regmaxsn.core.plotDensities import DensityVizualizations, writeTIFF
  3. from regmaxsn.core.transforms import compose_matrix
  4. import numpy as np
  5. from matplotlib import pyplot as plt
  6. from regmaxsn.core.RegMaxSPars import DensitySaveParNames
  7. from regmaxsn.core.misc import parFileCheck
  8. homeFolder = os.path.expanduser('~')
  9. import sys
  10. def saveAverageDensity(regMaxSParFile, refSWC, outFile, gridUnitSize, sigma, reflections, rotations,
  11. onlyTips=False):
  12. tempRotMat = compose_matrix(angles=np.deg2rad(rotations))
  13. initTrans = np.dot(np.diag(reflections), tempRotMat[:3, :3])
  14. resampleLen = 1
  15. gridUnitSize = [gridUnitSize] * 3
  16. sigmas = [sigma] * 3
  17. parsList = parFileCheck(regMaxSParFile, DensitySaveParNames)
  18. swcFiles = []
  19. for pars in parsList:
  20. resFile = pars['resFile']
  21. swcFiles.append(resFile)
  22. if onlyTips:
  23. masks = []
  24. for swcFile in swcFiles:
  25. data = np.loadtxt(swcFile)
  26. mask = map(lambda ptInd: ptInd not in data[:, 6], data[:, 0])
  27. masks.append(mask)
  28. densityViz = DensityVizualizations(swcFiles, gridUnitSize, resampleLen, masks=masks,
  29. pcaView=True, refSWC=refSWC, initTrans=initTrans)
  30. density, bins = densityViz.calculateDensity(swcFiles, sigmas)
  31. else:
  32. # densityViz = DensityVizualizations(swcFiles, gridUnitSize, resampleLen,
  33. # pcaView='closestPCMatch', refSWC=refSWC, initTrans=initTrans)
  34. densityViz = DensityVizualizations(swcFiles, gridUnitSize, resampleLen,
  35. pcaView='assumeRegistered', refSWC=refSWC,
  36. initTrans=initTrans)
  37. density, bins = densityViz.calculateDensity(swcFiles, sigmas)
  38. # density, bins = calcMorphDensity(swcFiles, sigmas, gridUnitSize, resampleLen, pcaView=False, refSWC=refSWC)
  39. # # writeTIFF(density, outFile)
  40. np.savez_compressed(outFile, density=density, bins=bins, expNames=swcFiles)
  41. def savePlotsTogether(densityDir, outDir):
  42. plt.rcParams["backend"] = 'agg'
  43. compressedFiles = [os.path.join(densityDir, x) for x in os.listdir(densityDir) if x.endswith('.npz')]
  44. mins = np.empty((len(compressedFiles), 3))
  45. maxs = np.empty((len(compressedFiles), 3))
  46. for comInd, comFile in enumerate(compressedFiles):
  47. compressedData = np.load(comFile)
  48. bins = compressedData['bins']
  49. mins[comInd, :] = [x.min() for x in bins]
  50. maxs[comInd, :] = [x.max() for x in bins]
  51. allMaxs = maxs.max(axis=0)
  52. allMins = mins.min(axis=0)
  53. del maxs, mins, bins
  54. for comInd, comFile in enumerate(compressedFiles):
  55. label = os.path.split(comFile)[1][:-4]
  56. print("Doing {}".format(label))
  57. compressedData = np.load(comFile)
  58. density = compressedData['density']
  59. bins = compressedData['bins']
  60. binWidths = [x[1] - x[0] for x in bins]
  61. finalExtents = [int((x - y) / z) for x, y, z in zip(allMaxs, allMins, binWidths)]
  62. tempDensity = np.zeros(finalExtents)
  63. binInds = [(int((binAxis[0] - minAxis) / binWidth),
  64. int((binAxis[-1] - minAxis) / binWidth))
  65. for binAxis, minAxis, binWidth in zip(bins, allMins, binWidths)]
  66. tempDensity[binInds[0][0]: binInds[0][1],
  67. binInds[1][0]: binInds[1][1],
  68. binInds[2][0]: binInds[2][1]] = density
  69. # fig1, ax1 = plt.subplots(figsize=(10, 8))
  70. density01 = np.max(tempDensity, axis=2)
  71. # im1 = ax1.imshow(density01, interpolation='none',
  72. # cmap=plt.cm.jet, vmin=0, vmax=1, aspect='equal')
  73. #
  74. #
  75. # fig1.colorbar(im1, ax=ax1, use_gridspec=True)
  76. # ax1.set_ylabel('Axis 1')
  77. # ax1.set_xlabel('Axis 2')
  78. # ax1.set_xticklabels([str(x * gridUnitSize[1]) for x in ax1.get_xticks()])
  79. # ax1.set_yticklabels([str(x * gridUnitSize[0]) for x in ax1.get_yticks()])
  80. #
  81. #
  82. #
  83. # fig2, ax2 = plt.subplots(figsize=(10, 8))
  84. density02 = np.max(tempDensity, axis=1)
  85. # im2 = ax2.imshow(density02, interpolation='none',
  86. # cmap=plt.cm.jet, vmin=0, vmax=1, aspect='equal')
  87. #
  88. #
  89. # fig2.colorbar(im2, ax=ax2, use_gridspec=True)
  90. # ax2.set_xlabel('Axis 3')
  91. # ax2.set_ylabel('Axis 1')
  92. # ax2.set_xticklabels([str(x * gridUnitSize[2]) for x in ax2.get_xticks()])
  93. # ax2.set_yticklabels([str(x * gridUnitSize[0]) for x in ax2.get_yticks()])
  94. #
  95. # for fig in [fig1, fig2]:
  96. # fig.tight_layout()
  97. # fig.canvas.draw()
  98. # scaleBar = ScaleBar(gridUnitSize[1] * 1e-6)
  99. # ax1.add_artist(scaleBar)
  100. # scaleBar = ScaleBar(gridUnitSize[1] * 1e-6)
  101. # ax2.add_artist(scaleBar)
  102. outFile = os.path.join(outDir, label)
  103. # fig1, ax1 = plt.subplots(figsize=np.array(density01.shape) / 300.)
  104. # ax1.imshow(density01, cmap=plt.cm.jet, vmax=1, vmin=0, interpolation='none')
  105. # ax1.axis('off')
  106. # # fig1.tight_layout()
  107. # fig1.savefig(outFile + '12' + '.ps', dpi=300, bbox_inches='tight', pad_inches=0, frameon=False)
  108. plt.imsave(outFile + '12.png', density01, cmap=plt.cm.jet, format='png', vmin=0, vmax=1)
  109. # fig2, ax2 = plt.subplots(figsize=np.array(density02.shape) / 300.)
  110. # ax2.imshow(density02, cmap=plt.cm.jet, vmax=1, vmin=0, interpolation='none')
  111. # ax2.axis('off')
  112. # # fig2.tight_layout()
  113. # fig2.savefig(outFile + '13' + '.ps', dpi=300, bbox_inches='tight', pad_inches=0, frameon=False)
  114. plt.imsave(outFile + '13.png', density02, cmap=plt.cm.jet, format='png', vmin=0, vmax=1)
  115. del tempDensity
  116. def savePlotsSingle(npCompressedFile, label, outDir):
  117. plt.rcParams["backend"] = 'agg'
  118. compressedData = np.load(npCompressedFile)
  119. density = compressedData['density']
  120. # fig1, ax1 = plt.subplots(figsize=(10, 8))
  121. density01 = np.max(density, axis=2)
  122. # im1 = ax1.imshow(density01, interpolation='none',
  123. # cmap=plt.cm.jet, vmin=0, vmax=1, aspect='equal')
  124. #
  125. #
  126. # fig1.colorbar(im1, ax=ax1, use_gridspec=True)
  127. # ax1.set_ylabel('Axis 1')
  128. # ax1.set_xlabel('Axis 2')
  129. # ax1.set_xticklabels([str(x * gridUnitSize[1]) for x in ax1.get_xticks()])
  130. # ax1.set_yticklabels([str(x * gridUnitSize[0]) for x in ax1.get_yticks()])
  131. #
  132. #
  133. #
  134. # fig2, ax2 = plt.subplots(figsize=(10, 8))
  135. density02 = np.max(density, axis=1)
  136. # im2 = ax2.imshow(density02, interpolation='none',
  137. # cmap=plt.cm.jet, vmin=0, vmax=1, aspect='equal')
  138. #
  139. #
  140. # fig2.colorbar(im2, ax=ax2, use_gridspec=True)
  141. # ax2.set_xlabel('Axis 3')
  142. # ax2.set_ylabel('Axis 1')
  143. # ax2.set_xticklabels([str(x * gridUnitSize[2]) for x in ax2.get_xticks()])
  144. # ax2.set_yticklabels([str(x * gridUnitSize[0]) for x in ax2.get_yticks()])
  145. #
  146. # for fig in [fig1, fig2]:
  147. # fig.tight_layout()
  148. # fig.canvas.draw()
  149. # scaleBar = ScaleBar(gridUnitSize[1] * 1e-6)
  150. # ax1.add_artist(scaleBar)
  151. # scaleBar = ScaleBar(gridUnitSize[1] * 1e-6)
  152. # ax2.add_artist(scaleBar)
  153. outFile = os.path.join(outDir, label)
  154. # fig1, ax1 = plt.subplots(figsize=np.array(density01.shape) / 300.)
  155. # ax1.imshow(density01, cmap=plt.cm.jet, vmax=1, vmin=0, interpolation='none')
  156. # ax1.axis('off')
  157. # # fig1.tight_layout()
  158. # fig1.savefig(outFile + '12' + '.ps', dpi=300, bbox_inches='tight', pad_inches=0, frameon=False)
  159. plt.imsave(outFile + '12.png', density01, cmap=plt.cm.jet, format='png', vmin=0, vmax=1)
  160. # fig2, ax2 = plt.subplots(figsize=np.array(density02.shape) / 300.)
  161. # ax2.imshow(density02, cmap=plt.cm.jet, vmax=1, vmin=0, interpolation='none')
  162. # ax2.axis('off')
  163. # # fig2.tight_layout()
  164. # fig2.savefig(outFile + '13' + '.ps', dpi=300, bbox_inches='tight', pad_inches=0, frameon=False)
  165. plt.imsave(outFile + '13.png', density02, cmap=plt.cm.jet, format='png', vmin=0, vmax=1)
  166. # densityViz.generateDensityColoredSSWC(swcFiles, [os.path.join(densityDir, x + '_density.sswc') for x in expNames],
  167. # density)
  168. if __name__ == "__main__":
  169. if len(sys.argv) == 9 and sys.argv[1] == "saveData":
  170. parFile = sys.argv[2]
  171. refSWC = sys.argv[3]
  172. outFile = sys.argv[4]
  173. gridUnitSize = float(sys.argv[5])
  174. sigma = float(sys.argv[6])
  175. reflections = np.array([int(x) for x in sys.argv[7].split()])
  176. rotations = np.array([float(x) for x in sys.argv[8].split()])
  177. saveAverageDensity(parFile, refSWC, outFile, gridUnitSize, sigma, reflections, rotations)
  178. elif len(sys.argv) == 5 and sys.argv[1] == "savePlotsSingle":
  179. npCompressedFile = sys.argv[2]
  180. label = sys.argv[3]
  181. outDir = sys.argv[4]
  182. savePlotsSingle(npCompressedFile, label, outDir)
  183. elif len(sys.argv) == 4 and sys.argv[1] == 'savePlotsTogether':
  184. densityDir = sys.argv[2]
  185. outDir = sys.argv[3]
  186. savePlotsTogether(densityDir, outDir)
  187. else:
  188. raise(ValueError("Improper Usage! Please use as:\n"
  189. "python {fle} saveData <RegMaxSParFile> <outFile> "
  190. "<spatial discretization size> <Gaussian smoothing sigma>\n"
  191. "python {fle} savePlotsSingle <compressed Data file> <label> <output directory>\n"
  192. "python {fle} savePlotsTogether <density directory> <output directory>".format(fle=sys.argv[0])))