# plotPairwiseDistanceNN.py

  1. import os
  2. import json
  3. import numpy as np
  4. import seaborn as sns
  5. import matplotlib.pyplot as plt
  6. import pandas as pd
  7. from multiprocessing import cpu_count
  8. from scipy.spatial import cKDTree
  9. homeFolder = os.path.expanduser('~')
  10. mplPars = {
  11. 'text.usetex' : True,
  12. 'axes.labelsize' : 'large',
  13. 'font.family' : 'sans-serif',
  14. 'font.sans-serif' : 'Computer Modern Sans serif',
  15. 'font.size' : 48,
  16. 'font.weight' : 'black',
  17. 'xtick.labelsize' : 36,
  18. 'ytick.labelsize' : 36,
  19. }
  20. def plotPairwiseDistancesNN(parFile):
  21. plt.ion()
  22. sns.set(rc=mplPars)
  23. with open(parFile) as fle:
  24. parsList = json.load(fle)
  25. transErrs = pd.DataFrame(None, columns=['Exp. Name', 'Pairwise Distance in $\mu$m'])
  26. for par in parsList:
  27. refSWC = par['refSWC']
  28. resFile = par['resFile']
  29. testName = resFile[:-4]
  30. thresh = par['gridSizes'][-1]
  31. print(('Doing ' + repr((refSWC, resFile))))
  32. refPts = np.loadtxt(refSWC)[:, 2:5]
  33. testPts = np.loadtxt(resFile)[:, 2:5]
  34. refKDTree = cKDTree(refPts, compact_nodes=True, leafsize=100)
  35. minDists = refKDTree.query(testPts, n_jobs=cpu_count() - 1)[0]
  36. minDists[minDists == np.inf] = 1000
  37. transErrs = transErrs.append(pd.DataFrame({'Pairwise Distance in $\mu$m': minDists,
  38. 'Exp. Name': testName}),
  39. ignore_index=True)
  40. transErrsGr = transErrs.groupby(by='Exp. Name')
  41. regErrs = transErrsGr['Pairwise Distance in $\mu$m'].agg({'\% of points closer than\n lowest grid size':
  42. lambda x: 100 * ((x <= thresh).sum()) / float(len(x))})
  43. fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14, 11.2))
  44. fig1, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(14, 11.2))
  45. with sns.axes_style('darkgrid'):
  46. sns.boxplot(x='Exp. Name', y='Pairwise Distance in $\mu$m',
  47. ax=ax, data=transErrs, color=sns.color_palette()[0], whis='range')
  48. ax1.plot(list(range(regErrs.size)), regErrs['\% of points closer than\n lowest grid size'],
  49. color=sns.color_palette()[0], marker='o', linestyle='-', ms=10)
  50. ax.set_xlim(-1, len(regErrs))
  51. ax.set_ylim(0, 40)
  52. ax.set_xticklabels(['job {}'.format(x) for x in range(len(parsList))], rotation=90)
  53. ax.set_xlabel('')
  54. ax1.set_xlim(-1, len(regErrs))
  55. ax1.set_ylim(-10, 110)
  56. ax1.set_xticks(list(range(regErrs.size)))
  57. ax1.set_xticklabels(['job {}'.format(x) for x in range(len(parsList))], rotation=90)
  58. ax1.set_ylabel('\% of points closer than\n lowest grid size')
  59. for ind, f in enumerate([fig, fig1]):
  60. # f.canvas.draw()
  61. f.tight_layout()
  62. return fig, fig1
  63. # ----------------------------------------------------------------------------------------------------------------------
  64. if __name__ == '__main__':
  65. import sys
  66. assert len(sys.argv) == 2, 'Improper usage! Please use as \'python plotPairwiseDistance.py parFile\''
  67. parFile = sys.argv[1]
  68. figs = plotPairwiseDistancesNN(parFile)
  69. input('Press any key to close figures and quit:')