correlation_matrix_from_spikes.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from networkunit import models, tests, scores
  2. import neo
  3. import quantities as pq
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. import argparse
  7. from scipy.linalg import eigh
  8. class polychrony_data(models.loaded_spiketrains):
  9. default_params = {'align_to_0': True,
  10. 'filter_inh': False,
  11. 'N': 1000,
  12. 't_start':0*pq.ms,
  13. 't_stop':60000*pq.ms,
  14. 'file_path': ''
  15. }
  16. def load(self):
  17. f = open(self.params['file_path'], 'r')
  18. lines = f.readlines()
  19. # Read Spike Times
  20. spike_times = [[]] * self.params['N']
  21. for line in lines:
  22. sec, msec, n = line.split(' ')[:3]
  23. t = float(sec)*1000. + float(msec)
  24. n = int(n)
  25. if t > self.params['t_stop']:
  26. break
  27. spike_times[n] = spike_times[n] + [t]
  28. # Fill Spike Trains
  29. nbr_neurons = self.params['N']
  30. if self.params['filter_inh']:
  31. nbr_neurons = 800
  32. spiketrains = [[]] * nbr_neurons
  33. for n, st in enumerate(spike_times):
  34. if n < 800:
  35. n_type = 'exc'
  36. else:
  37. n_type = 'inh'
  38. if not self.params['filter_inh'] or n_type == 'exc':
  39. spiketrains[n] = neo.SpikeTrain(np.sort(st), units='ms',
  40. t_start=self.params['t_start'],
  41. t_stop=self.params['t_stop'],
  42. n_type=n_type, unitID=n)
  43. return spiketrains
  44. class correlation_matrix_test(tests.correlation_matrix_test):
  45. score_type = scores.effect_size
  46. params = {'cluster_matrix':False,
  47. 'remove_autocorr':False,
  48. 'nan_to_num':True}
  49. if __name__ == '__main__':
  50. CLI = argparse.ArgumentParser()
  51. CLI.add_argument("--spikes", nargs='?', type=str)
  52. CLI.add_argument("--output", nargs='?', type=str)
  53. CLI.add_argument("--t_start", nargs='?', type=float, default=0)
  54. CLI.add_argument("--t_stop", "--T", nargs='?', type=float, default=60000)
  55. args, unknown = CLI.parse_known_args()
  56. spikes_model = polychrony_data(file_path=args.spikes,
  57. t_start=args.t_start*pq.ms,
  58. t_stop=args.t_stop*pq.ms,
  59. align_to_0=True)
  60. corr_test = correlation_matrix_test()
  61. cc_matrix = corr_test.generate_prediction(spikes_model)
  62. np.save(args.output, cc_matrix)