regErrorVsAnisotropicscaling.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import os
  2. import numpy as np
  3. from matplotlib import pyplot as plt
  4. import seaborn as sns
  5. import json
  6. from regmaxsn.core.RegMaxSPars import RegMaxSParNames
  7. from regmaxsn.core.misc import parFileCheck
  8. from regmaxsn.core.matplotlibRCParams import mplPars
  9. homeFolder = os.path.expanduser('~')
  10. plt.ion()
  11. # Example colFunc, takes refSWC and testSWC and returns an object that can be passed to matplotlib plotting argument
  12. # color. For example return objects could be 'b' (blue), 'r' (red) and [0, 0, 1] (green).
  13. # def colFunc(refSWC, testSWC):
  14. #
  15. # testInd = int(os.path.split(testSWC)[1][25:-4])
  16. #
  17. # if testInd in range(2, 5):
  18. # return 'r'
  19. # else:
  20. # return 'b'
  21. colFunc = None
  22. def regErrorVsAIScaling(parFile, colFunc=None):
  23. # Axis 1: neuron pairs; Axis 2: (reg accuracy, anisotropic scaling)
  24. parsList = parFileCheck(parFile, RegMaxSParNames)
  25. translErrStats = np.empty((len(parsList), 2))
  26. for parInd, par in enumerate(parsList):
  27. refSWC = par['refSWC']
  28. testSWC = par['testSWC']
  29. resFile = par['resFile']
  30. thresh = par['gridSizes'][-1]
  31. print('Doing ' + repr((refSWC, resFile)))
  32. origJSON = testSWC[:-4] + '.json'
  33. if os.path.isfile(origJSON):
  34. with open(origJSON, 'r') as fle:
  35. pars = json.load(fle)
  36. scales = np.array(pars['scale'])
  37. else:
  38. raise(IOError('File not found: {}'.format(origJSON)))
  39. scalesOrdered = np.sort(scales)
  40. scalesRelative = np.mean([scalesOrdered[0] / scalesOrdered[1],
  41. scalesOrdered[0] / scalesOrdered[2],
  42. scalesOrdered[1] / scalesOrdered[2]])
  43. refPts = np.loadtxt(refSWC)[:, 2:5]
  44. testPts = np.loadtxt(resFile)[:, 2:5]
  45. if refPts.shape[0] != testPts.shape[0]:
  46. print('Number of points do not match for ' + refSWC + 'and' + testSWC)
  47. continue
  48. ptDiff = np.linalg.norm(refPts - testPts, axis=1)
  49. translErrStats[parInd, 0] = 100 * sum(ptDiff <= thresh) / float(ptDiff.shape[0])
  50. translErrStats[parInd, 1] = scalesRelative
  51. sns.set(rc=mplPars)
  52. with sns.axes_style('whitegrid'):
  53. fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 8))
  54. for parInd, vals in enumerate(translErrStats):
  55. try:
  56. if colFunc:
  57. col = colFunc(parsList[parInd]['refSWC'], parsList[parInd]['testSWC'])
  58. else:
  59. col = 'b'
  60. ax.plot(vals[1], vals[0], color=col, marker='o', ls='None', ms=10)
  61. except Exception as e:
  62. raise(Exception('Problem with plotting. There could be a problem with argument colFunc'))
  63. ax.set_xlabel('measure of anisotropic scaling')
  64. ax.set_ylabel('\% points closer than \nthe lowest grid size')
  65. ax.set_ylim(-10, 110)
  66. fig.tight_layout()
  67. return fig
  68. if __name__ == '__main__':
  69. import sys
  70. assert len(sys.argv) == 2, 'Improper usage! Please use as \'python refErrorVsAnisotropicscaling.py parFile\''
  71. parFile = sys.argv[1]
  72. fig = regErrorVsAIScaling(parFile, colFunc)
  73. raw_input('Press any key to close figures and quit:')