calcD2Metrics.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # Based on the D2 measure of the Peng lab, which first resamples all the SWCs, then, for each point of the reference SWC,
  2. # finds the closest point in each test SWC. The distances between the reference point and the closest test points are
  3. # used to quantify how well a method has worked.
  4. import os
  5. from core.swcFuncs import resampleSWC
  6. from scipy.spatial import cKDTree
  7. import numpy as np
  8. from multiprocessing import cpu_count
  9. from matplotlib import pyplot as plt
  10. import pandas as pd
  11. import seaborn as sns
  12. mplPars = {'text.usetex': True,
  13. 'axes.labelsize': 'large',
  14. 'font.family': 'sans-serif',
  15. 'font.sans-serif': 'computer modern roman',
  16. 'font.size': 42,
  17. 'font.weight': 'black',
  18. 'xtick.labelsize': 36,
  19. 'ytick.labelsize': 36,
  20. 'legend.fontsize': 36,
  21. }
  22. homeFolder = os.path.expanduser('~')
  23. plt.ion()
  24. sns.set(rc=mplPars)
  25. class DataSet(object):
  26. def __init__(self, label, refSWC, testSWCSets):
  27. self.label = label
  28. self.refSWC = refSWC
  29. self.testSWCSets = testSWCSets
  30. def calcMinDists(self, minLen):
  31. allMinDists = {}
  32. thrash, resamRefPts = resampleSWC(self.refSWC, minLen)
  33. for testSWCSetName, testSWCSet in self.testSWCSets.iteritems():
  34. minDists = np.empty((resamRefPts.shape[0], len(testSWCSet)))
  35. for testInd, testSWC in enumerate(testSWCSet):
  36. thrash, resampTestPts = resampleSWC(testSWC, minLen)
  37. testKDTree = cKDTree(resampTestPts, compact_nodes=True, leafsize=100)
  38. minDists[:, testInd] = testKDTree.query(resamRefPts, n_jobs=cpu_count() - 1)[0]
  39. allMinDists[testSWCSetName] = minDists
  40. return allMinDists
  41. dataSets = {}
  42. # ----------------------------------------------------------------------------------------------------------------------
  43. refPath = os.path.join(homeFolder, 'DataAndResults', 'morphology', 'Registered', 'chiangOMB')
  44. refExpName = 'VGlut-F-500085.CNG'
  45. refSWC = os.path.join(refPath, refExpName + '.swc')
  46. testSWCSets = {}
  47. dataSetName = 'ALPN'
  48. testExpNames = [
  49. # 'VGlut-F-500085_registered',
  50. 'VGlut-F-700500.CNG',
  51. 'VGlut-F-700567.CNG',
  52. 'VGlut-F-500471.CNG',
  53. 'Cha-F-000353.CNG',
  54. 'VGlut-F-600253.CNG',
  55. 'VGlut-F-400434.CNG',
  56. 'VGlut-F-600379.CNG',
  57. 'VGlut-F-700558.CNG',
  58. 'VGlut-F-500183.CNG',
  59. 'VGlut-F-300628.CNG',
  60. 'VGlut-F-500085.CNG',
  61. 'VGlut-F-500031.CNG',
  62. 'VGlut-F-500852.CNG',
  63. 'VGlut-F-600366.CNG'
  64. ]
  65. # -------------------------------------------------------------------------------------------------
  66. testPath = os.path.join(homeFolder, 'DataAndResults', 'morphology', 'directPixelBased', 'chiangOMB')
  67. label = 'Reg-MaxS-N'
  68. testSWCs = [os.path.join(testPath, x + '_norm.swc') for x in testExpNames]
  69. testSWCSets[label] = testSWCs
  70. # -------------------------------------------------------------------------------------------------
  71. testPath = os.path.join(homeFolder, 'DataAndResults', 'morphology', 'RefPCA', 'chiangOMB')
  72. label = 'PCA-Based'
  73. testSWCs = [os.path.join(testPath, x + '.swc') for x in testExpNames]
  74. testSWCSets[label] = testSWCs
  75. # -------------------------------------------------------------------------------------------------
  76. dataSets[dataSetName] = DataSet(dataSetName, refSWC, testSWCSets)
  77. # ----------------------------------------------------------------------------------------------------------------------
  78. refPath = os.path.join(homeFolder, 'DataAndResults', 'morphology', 'Registered', 'chiangLLC')
  79. refExpName = 'Gad1-F-000062.CNG'
  80. refSWC = os.path.join(refPath, refExpName + '.swc')
  81. testSWCSets = {}
  82. dataSetName = 'LCInt'
  83. testExpNames = [
  84. 'Gad1-F-000062.CNG',
  85. 'Cha-F-000012.CNG',
  86. 'Cha-F-300331.CNG',
  87. 'Gad1-F-600000.CNG',
  88. 'Cha-F-000018.CNG',
  89. 'Cha-F-300051.CNG',
  90. 'Cha-F-400051.CNG',
  91. 'Cha-F-200000.CNG'
  92. ]
  93. # -------------------------------------------------------------------------------------------------
  94. testPath = os.path.join(homeFolder, 'DataAndResults', 'morphology', 'directPixelBased', 'chiangLLC')
  95. label = 'Reg-MaxS-N'
  96. testSWCs = [os.path.join(testPath, x + '_norm.swc') for x in testExpNames]
  97. testSWCSets[label] = testSWCs
  98. # -------------------------------------------------------------------------------------------------
  99. testPath = os.path.join(homeFolder, 'DataAndResults', 'morphology', 'RefPCA', 'chiangLLC')
  100. label = 'PCA-Based'
  101. testSWCs = [os.path.join(testPath, x + '.swc') for x in testExpNames]
  102. testSWCSets[label] = testSWCs
  103. # -------------------------------------------------------------------------------------------------
  104. dataSets[dataSetName] = DataSet(dataSetName, refSWC, testSWCSets)
  105. # ----------------------------------------------------------------------------------------------------------------------
  106. refPath = os.path.join(homeFolder, 'DataAndResults', 'morphology', 'Registered', 'chiangOPSInt')
  107. refExpName = 'Trh-F-000047.CNG'
  108. testSWCSets = {}
  109. testExpNames = [
  110. 'Trh-F-000047.CNG',
  111. 'Trh-M-000143.CNG',
  112. 'Trh-F-000092.CNG',
  113. 'Trh-F-700009.CNG',
  114. 'Trh-M-000013.CNG',
  115. 'Trh-M-000146.CNG',
  116. # 'Trh-M-100009.CNG',
  117. 'Trh-F-000019.CNG',
  118. 'Trh-M-000081.CNG',
  119. 'Trh-M-900003.CNG',
  120. 'Trh-F-200035.CNG',
  121. 'Trh-F-200015.CNG',
  122. 'Trh-M-000040.CNG',
  123. 'Trh-M-600023.CNG',
  124. 'Trh-M-100048.CNG',
  125. 'Trh-M-700019.CNG',
  126. 'Trh-F-100009.CNG',
  127. 'Trh-M-400000.CNG',
  128. 'Trh-M-000067.CNG',
  129. 'Trh-M-000114.CNG',
  130. 'Trh-M-100018.CNG',
  131. 'Trh-M-000141.CNG',
  132. 'Trh-M-900019.CNG',
  133. 'Trh-M-800002.CNG'
  134. ]
  135. dataSetName = 'OPInt'
  136. refSWC = os.path.join(refPath, refExpName + '.swc')
  137. # -------------------------------------------------------------------------------------------------
  138. testPath = os.path.join(homeFolder, 'DataAndResults', 'morphology', 'directPixelBased', 'chiangOPSInt')
  139. label = 'Reg-MaxS-N'
  140. testSWCs = [os.path.join(testPath, x + '_norm.swc') for x in testExpNames]
  141. testSWCSets[label] = testSWCs
  142. # -------------------------------------------------------------------------------------------------
  143. testPath = os.path.join(homeFolder, 'DataAndResults', 'morphology', 'RefPCA', 'chiangOPSInt')
  144. label = 'PCA-Based'
  145. testSWCs = [os.path.join(testPath, x + '.swc') for x in testExpNames]
  146. testSWCSets[label] = testSWCs
  147. # -------------------------------------------------------------------------------------------------
  148. dataSets[dataSetName] = DataSet(dataSetName, refSWC, testSWCSets)
  149. # ----------------------------------------------------------------------------------------------------------------------
  150. minDists = {}
  151. allMinDistStats = {}
  152. fig1, ax1 = plt.subplots(figsize=(14, 11.2))
  153. fig2, ax2 = plt.subplots(figsize=(14, 11.2))
  154. for dataSetName, dataSet in dataSets.iteritems():
  155. temp = dataSet.calcMinDists(0.1)
  156. minDists[dataSetName] = temp
  157. dataSetMinDists = []
  158. for methodName, testSetMinDists in temp.iteritems():
  159. methodMinDistStats = pd.DataFrame(data=None,
  160. columns=['Mean of minimum distances(um)',
  161. 'Standard Deviation of \nminimum distances(um)',
  162. 'Method',
  163. 'Group Name'])
  164. methodMinDistStats.loc[:, 'Mean of minimum distances(um)'] = testSetMinDists.mean(axis=1)
  165. methodMinDistStats.loc[:, 'Standard Deviation of \nminimum distances(um)'] = testSetMinDists.std(axis=1)
  166. methodMinDistStats.loc[:, 'Method'] = methodName
  167. methodMinDistStats.loc[:, 'Group Name'] = dataSetName
  168. dataSetMinDists.append(methodMinDistStats)
  169. allMinDistStats[dataSetName] = pd.concat(dataSetMinDists)
  170. minDistsStats1DF = pd.concat(allMinDistStats.itervalues())
  171. sns.boxplot(x='Group Name', y='Mean of minimum distances(um)', hue='Method', data=minDistsStats1DF,
  172. ax=ax1, whis='range')
  173. sns.boxplot(x='Group Name', y='Standard Deviation of \nminimum distances(um)', hue='Method', data=minDistsStats1DF,
  174. ax=ax2, whis='range')
  175. for fig in [fig1, fig2]:
  176. fig.tight_layout()
  177. fig.canvas.draw()
  178. # print('Reasampling Ref')
  179. # refRPts = resampleSWC(refSWC, 0.1)[1][:, :3]
  180. # print('Resampling tests')
  181. # testRPts = [resampleSWC(x, 0.1)[1][:, :3] for x in testSWCs]
  182. # print('Creating testKDtrees')
  183. # testKDTrees = [cKDTree(x, compact_nodes=True, leafsize=100) for x in testRPts]
  184. #
  185. # minDists = np.empty((refRPts.shape[0], len(testRPts)))
  186. #
  187. # for testInd, testKDTree in enumerate(testKDTrees):
  188. # print('Calculating minDists for ' + testExpNames[testInd])
  189. # minDists[:, testInd] = testKDTree.query(refRPts, n_jobs=cpu_count() - 1)[0]
  190. #
  191. # # minDists[minDists == np.inf]
  192. #
  193. # meanMinDists = minDists.mean()
  194. # meanStdMinDists = minDists.std(axis=1).mean()
  195. #
  196. #
  197. # print(dataSetName + part, label + ' Results:')
  198. # print('Mean of all minDists: ' + str(meanMinDists))
  199. # print('Mean Std of minDists: ' + str(meanStdMinDists))