# generate_poisson_correlation_matrix.py (2.4 KB)

  1. from networkunit import models, tests, scores
  2. import numpy as np
  3. from quantities import ms, Hz
  4. import argparse
  5. class activity_model(models.stochastic_activity):
  6. params = {**models.stochastic_activity.default_params}
  7. class corr_test(tests.correlation_matrix_test):
  8. score_type = scores.effect_size
  9. params = {'cluster_matrix':False,
  10. 'remove_autocorr':False}
  11. def generate_matrix(N, t_start, t_stop, binsize, rate, assembly_sizes,
  12. correlations, bkgr_correlation, corr_method):
  13. params = {'size': N,
  14. 't_start': t_start * ms,
  15. 't_stop': t_stop * ms,
  16. 'rate': rate * Hz,
  17. 'statistic': 'poisson',
  18. 'correlation_method': corr_method,
  19. 'expected_bin_size': binsize * ms,
  20. 'correlations': correlations,
  21. 'assembly_sizes': assembly_sizes,
  22. 'bkgr_correlation': bkgr_correlation,
  23. 'max_pattern_length':100 * ms,
  24. 'shuffle': False,
  25. 'shuffle_seed': None}
  26. activity_model_inst = activity_model(**params)
  27. test = corr_test()
  28. return test.generate_prediction(activity_model_inst)
  29. if __name__ == '__main__':
  30. CLI = argparse.ArgumentParser()
  31. CLI.add_argument("--N", nargs='?', type=int)
  32. CLI.add_argument("--t_start", nargs='?', type=float)
  33. CLI.add_argument("--t_stop", nargs='?', type=float)
  34. CLI.add_argument("--binsize", nargs='?', type=float)
  35. CLI.add_argument("--rate", nargs='?', type=float)
  36. CLI.add_argument("--assembly_sizes", nargs='?', type=lambda s: [int(i) for i in s.split(',')])
  37. CLI.add_argument("--correlations", nargs='?', type=lambda s: [float(i) for i in s.split(',')])
  38. CLI.add_argument("--bkgr_correlation", nargs='?', type=float)
  39. CLI.add_argument("--corr_method", nargs='?', type=str)
  40. CLI.add_argument("--output", nargs='?', type=str)
  41. args, unknown = CLI.parse_known_args()
  42. M = generate_matrix(N=args.N,
  43. t_start=args.t_start,
  44. t_stop=args.t_stop,
  45. binsize=args.binsize,
  46. rate=args.rate,
  47. assembly_sizes=args.assembly_sizes,
  48. correlations=args.correlations,
  49. bkgr_correlation=args.bkgr_correlation,
  50. corr_method=args.corr_method)
  51. np.save(args.output, M)