1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- from networkunit import models, tests, scores
- import neo
- import quantities as pq
- import numpy as np
- import matplotlib.pyplot as plt
- import argparse
- from scipy.linalg import eigh
class polychrony_data(models.loaded_spiketrains):
    """Spike-train model loaded from a text file of (sec, msec, neuron) rows.

    Each non-blank line of ``file_path`` must start with three
    whitespace-separated fields: seconds, milliseconds, and the integer
    neuron index. The spike time in ms is ``sec * 1000 + msec``. Reading
    stops at the first spike later than ``t_stop``, so the file is assumed
    to be ordered by time. Spikes earlier than ``t_start`` are skipped.

    Neurons with index ``< n_exc`` are annotated ``n_type='exc'``, all
    others ``n_type='inh'``. With ``filter_inh`` only the 'exc' neurons are
    returned.
    """
    default_params = {'align_to_0': True,
                      'filter_inh': False,
                      'N': 1000,
                      'n_exc': 800,
                      't_start': 0*pq.ms,
                      't_stop': 60000*pq.ms,
                      'file_path': ''
                      }

    @staticmethod
    def _to_ms(value):
        """Return *value* as a float number of milliseconds.

        Quantities are rescaled to ms; plain numbers are taken to be in ms
        already.
        """
        if isinstance(value, pq.Quantity):
            return float(value.rescale(pq.ms).magnitude)
        return float(value)

    def load(self):
        """Read ``file_path`` and build one ``neo.SpikeTrain`` per neuron.

        Returns
        -------
        list of neo.SpikeTrain
            Trains in order of neuron index (only indices ``< n_exc`` when
            ``filter_inh`` is set). Each train is in ms, spans
            ``[t_start, t_stop]`` and is annotated with ``n_type`` and
            ``unitID``.

        Raises
        ------
        ValueError
            If a non-blank line has fewer than three fields or non-numeric
            fields.
        IndexError
            If a neuron index is ``>= N``.
        """
        n_neurons = self.params['N']
        # .get keeps working if params were supplied without the newer key.
        n_exc = self.params.get('n_exc', 800)
        t_start = self.params['t_start']
        t_stop = self.params['t_stop']
        # Compare against plain ms floats so other units (e.g. seconds)
        # are handled correctly.
        t_start_ms = self._to_ms(t_start)
        t_stop_ms = self._to_ms(t_stop)

        # One independent list per neuron; append is O(1) per spike.
        spike_times = [[] for _ in range(n_neurons)]
        with open(self.params['file_path'], 'r') as spike_file:
            for line in spike_file:
                fields = line.split()
                if not fields:
                    # Blank line (e.g. trailing newline at end of file).
                    continue
                sec, msec, n = fields[:3]
                t = float(sec) * 1000. + float(msec)
                if t > t_stop_ms:
                    # File is time-ordered: no later line can be in range.
                    break
                if t < t_start_ms:
                    # neo.SpikeTrain rejects spikes before t_start.
                    continue
                spike_times[int(n)].append(t)

        spiketrains = []
        for n, st in enumerate(spike_times):
            n_type = 'exc' if n < n_exc else 'inh'
            if self.params['filter_inh'] and n_type == 'inh':
                continue
            spiketrains.append(neo.SpikeTrain(np.sort(st), units='ms',
                                              t_start=t_start,
                                              t_stop=t_stop,
                                              n_type=n_type, unitID=n))
        return spiketrains
class correlation_matrix_test(tests.correlation_matrix_test):
    """networkunit correlation-matrix test scored with ``scores.effect_size``.

    Sets the parent test's ``cluster_matrix``, ``remove_autocorr`` and
    ``nan_to_num`` options explicitly (the first two off, the last on).
    """
    params = dict(cluster_matrix=False,
                  remove_autocorr=False,
                  nan_to_num=True)
    score_type = scores.effect_size
if __name__ == '__main__':
    # Command-line entry point: load a spike file, compute its correlation
    # matrix with correlation_matrix_test and save it with np.save.
    CLI = argparse.ArgumentParser(
        description="Compute the correlation matrix of a polychrony spike "
                    "file and save it as a .npy array.")
    # Both paths are mandatory; previously omitting them produced None and a
    # confusing TypeError deep inside open()/np.save().
    CLI.add_argument("--spikes", type=str, required=True,
                     help="path to the spike file (sec msec neuron rows)")
    CLI.add_argument("--output", type=str, required=True,
                     help="path of the .npy file to write")
    # No nargs='?': a bare flag would otherwise yield None and crash at
    # None*pq.ms.
    CLI.add_argument("--t_start", type=float, default=0,
                     help="start time in ms (default: %(default)s)")
    CLI.add_argument("--t_stop", "--T", type=float, default=60000,
                     help="stop time in ms (default: %(default)s)")
    # Unknown arguments are ignored rather than rejected; presumably so a
    # wrapper can pass extra flags -- confirm before tightening.
    args, unknown = CLI.parse_known_args()

    spikes_model = polychrony_data(file_path=args.spikes,
                                   t_start=args.t_start*pq.ms,
                                   t_stop=args.t_stop*pq.ms,
                                   align_to_0=True)
    corr_test = correlation_matrix_test()
    cc_matrix = corr_test.generate_prediction(spikes_model)
    np.save(args.output, cc_matrix)
|