correctReg-MaxS-N_finalChoice.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import sys
  2. from regmaxsn.scripts.algorithms.RegMaxSN import normalizeFinally
  3. import os
  4. from regmaxsn.core.misc import parFileCheck
  5. from regmaxsn.core.RegMaxSPars import RegMaxSNParNames
  6. from regmaxsn.core.occupancyBasedMeasure import occupancyEMD
  7. import numpy as np
  8. import json
  9. def getRegMaxSNIterVsMeasure(resDir, swcList, voxelSize):
  10. itersAll = sorted([int(fle[3:-4]) for fle in os.listdir(resDir) if fle.find('ref') == 0])
  11. itersAll = [x for x in itersAll if x >= 0]
  12. occupancyMeasures = []
  13. for iterInd in itersAll:
  14. iterSWCs = []
  15. for swc in swcList:
  16. swcStem = os.path.split(swc)[1][:-4]
  17. iterSWC = os.path.join(resDir, "{}{}.swc".format(swcStem, iterInd))
  18. iterSWCs.append(iterSWC)
  19. occupancyMeasures.append(occupancyEMD(iterSWCs, voxelSize))
  20. return itersAll, occupancyMeasures
  21. def correctRegMaxSNChoice(parFile, parNames):
  22. assert os.path.isfile(parFile), "{} not found".format(parFile)
  23. ch = raw_input('Using parameter File {}.\n Continue?(y/n)'.format(parFile))
  24. if ch != 'y':
  25. print('User Abort!')
  26. sys.exit()
  27. parsList = parFileCheck(parFile, parNames)
  28. for pars in parsList:
  29. refSWC = pars['initRefSWC']
  30. swcList = pars['swcList']
  31. fnwrt = pars['finallyNormalizeWRT']
  32. assert os.path.isfile(refSWC), 'Could not find {}'.format(refSWC)
  33. for swc in swcList:
  34. assert os.path.isfile(swc), 'Could not find {}'.format(swc)
  35. assert swc.endswith('.swc'), 'Elements of swcList must be of SWC format with extension \'.swc\''
  36. assert fnwrt in swcList, 'The parameter finallyNormalizeWRT must be an element of the parameter swcList'
  37. for pars in parsList:
  38. refSWC = pars['initRefSWC']
  39. swcList = pars['swcList']
  40. fnwrt = pars['finallyNormalizeWRT']
  41. voxelSizes = pars['gridSizes']
  42. resDir = pars["resDir"]
  43. if os.path.isdir(resDir):
  44. iters, measures = getRegMaxSNIterVsMeasure(resDir=resDir, swcList=swcList, voxelSize=voxelSizes[-1])
  45. bestIterInd = iters[int(np.argmin(measures))]
  46. bestMeasure = min(measures)
  47. ipFiles = []
  48. opFiles = []
  49. thrash, fnwrtName = os.path.split(fnwrt[:-4])
  50. for swc in swcList:
  51. dirPath, expName = os.path.split(swc[:-4])
  52. ipFiles.append(os.path.join(resDir, '{}{}.swc'.format(expName, bestIterInd)))
  53. opFiles.append(os.path.join(resDir, '{}.swc'.format(expName)))
  54. normalizeFinally(ipFiles, resDir, opFiles, fnwrtName, bestIterInd)
  55. finalSolFile = os.path.join(resDir, "bestIterInd.json")
  56. with open(finalSolFile, 'w') as fle:
  57. json.dump({'finalVal': bestMeasure,
  58. 'bestIteration': bestIterInd}, fle)
  59. if __name__ == "__main__":
  60. assert len(sys.argv) == 2, "Improper usage! Please use as: \n " \
  61. "python {} <Reg-MaxS-N Par File>".format(sys.argv[1])
  62. correctRegMaxSNChoice(sys.argv[1], RegMaxSNParNames)