plotOccupancyMeasureVsIterations.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from regmaxsn.core.occupancyBasedMeasure import occupancyEMD, calcOccupancyDistribution
  2. from regmaxsn.core.matplotlibRCParams import mplPars
  3. from matplotlib import pyplot as plt
  4. from regmaxsn.core.misc import parFileCheck
  5. import pandas as pd
  6. import os
  7. import seaborn as sns
  8. import sys
  9. import numpy as np
  10. def plotMaxDistEMDVsIteration(parFile, parNames):
  11. sns.set(rc=mplPars)
  12. plt.ion()
  13. metricsDF = pd.DataFrame()
  14. parsList = parFileCheck(parFile, parNames)
  15. for parInd, pars in enumerate(parsList):
  16. resDir = pars['resDir']
  17. swcList = pars['swcList']
  18. gridSizes = pars["gridSizes"]
  19. expNames = [os.path.split(swc)[1][:-4] for swc in swcList]
  20. if os.path.isdir(resDir):
  21. print('Current refSWC={}'.format(pars["initRefSWC"]))
  22. iters = sorted([int(fle[3:-4]) for fle in os.listdir(resDir) if fle.find('ref') == 0])
  23. nIter = max(iters)
  24. # distributionsDF = pd.DataFrame()
  25. for iterInd in range(nIter + 1):
  26. print('Doing {}/{}'.format(iterInd + 1, nIter + 1))
  27. iterSWCs = [os.path.join(resDir, '{}{}.swc'.format(expName, iterInd)) for expName in expNames]
  28. for gridSize in gridSizes:
  29. metric = occupancyEMD(iterSWCs, gridSize)
  30. occupancyDist = calcOccupancyDistribution(iterSWCs, gridSize)
  31. tempDict = {}
  32. for k, v in occupancyDist.iteritems():
  33. tempDict["Occupancy"] = k
  34. tempDict["Occupancy PMF"] = v
  35. tempDict["gridSize"] = gridSize
  36. tempDict["Iteration"] = iterInd + 1
  37. # distributionsDF = distributionsDF.append(tempDict, ignore_index=True)
  38. tempDict = {"Job Number": parInd + 1, "Iteration Number": iterInd + 1,
  39. "Occupancy based Metric": metric, "gridSize": gridSize}
  40. metricsDF = metricsDF.append(tempDict, ignore_index=True)
  41. # distributionsDFGBGS = distributionsDF.groupby("gridSize")
  42. # fig, axs = plt.subplots(figsize=(14, 11.2), nrows=len(distributionsDFGBGS.groups))
  43. # for ax, (gridSize, distributionsDFGS) in zip(axs, distributionsDFGBGS):
  44. # sns.pointplot(data=distributionsDFGS,
  45. # x="Occupancy", y="Occupancy PMF",
  46. # hue="Iteration",
  47. # palette=sns.diverging_palette(255, 133, l=60, n=7, center="dark"),
  48. # ci=0,
  49. # linestyles="-", markers="o", ax=ax)
  50. # ax.set_ylabel("Occupancy PMF")
  51. # ax.set_title("gridSize={}".format(gridSize))
  52. # fig.suptitle("Job number {}".format(parInd + 1))
  53. # fig.tight_layout()
  54. metricGBgridSize = metricsDF.groupby("gridSize")
  55. fig, axs = plt.subplots(figsize=(14, 11.2), nrows=len(metricGBgridSize.groups))
  56. for ax, (gridSize, metricsDFGS) in zip(axs, metricGBgridSize):
  57. sns.pointplot(data=metricsDFGS,
  58. x="Iteration Number", y="Occupancy based Metric", hue="Job Number",
  59. ci=None, linestyles="-", markers="o", ax=ax)
  60. ax.set_ylabel("Occupancy\n based\n Metric")
  61. ax.set_title("gridSize={}".format(gridSize))
  62. fig.tight_layout()
  63. return fig
  64. if __name__ == '__main__':
  65. from regmaxsn.core.RegMaxSPars import RegMaxSNParNames
  66. assert len(sys.argv) == 2, 'Improper usage! Please use as \'python {} parFile\''.format(sys.argv[1])
  67. parFile = sys.argv[1]
  68. figs = plotMaxDistEMDVsIteration(parFile, RegMaxSNParNames)
  69. raw_input('Press any key to close figures and quit:')