firing_rate_correlation.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import argparse
  2. import pandas as pd
  3. import numpy as np
  4. from elephant.statistics import mean_firing_rate
  5. import neo
  6. from pathlib import Path
  7. import os
  8. import sys
  9. sys.path.append(str(Path.cwd().parents[0] / 'scripts'))
  10. from utils import load_df, compact_column
  11. def load_rate_vectors(path):
  12. try:
  13. spiketrains = load_spiketrains(path)
  14. spiketrains_exc = filter_spiketrains(spiketrains, neuron_type='excitatory')
  15. spiketrains_inh = filter_spiketrains(spiketrains, neuron_type='inhibitory')
  16. return [calc_rate_vector(st) for st in [spiketrains,
  17. spiketrains_exc, spiketrains_inh]]
  18. except FileNotFoundError:
  19. return [np.array([]), np.array([]), np.array([])]
  20. def filter_spiketrains(spiketrains, **kwargs):
  21. for key, value in kwargs.items():
  22. spiketrains = [st for st in spiketrains if st.annotations[key] == value]
  23. return spiketrains
  24. def load_spiketrains(path):
  25. io = neo.io.get_io(path)
  26. return io.read_block().segments[0].spiketrains
  27. def calc_rate_vector(spiketrains):
  28. return np.array([mean_firing_rate(st) for st in spiketrains])
  29. def calc_vector_correlation(vector_a, vector_b):
  30. if len(vector_a) and len(vector_b):
  31. return np.corrcoef(vector_a, vector_b)[0,1]
  32. else:
  33. return np.nan
  34. if __name__ == '__main__':
  35. CLI = argparse.ArgumentParser()
  36. CLI.add_argument("--input", nargs='?', type=lambda s: s.split(' '))
  37. CLI.add_argument("--output", nargs='?', type=Path)
  38. args, unknown = CLI.parse_known_args()
  39. if len(args.input) != 2:
  40. raise ValueError(f'Expected two input files, got {len(args.input)}!')
  41. rates_a, rates_exc_a, rates_inh_a = load_rate_vectors(args.input[0])
  42. rates_b, rates_exc_b, rates_inh_b = load_rate_vectors(args.input[1])
  43. ratecorr = calc_vector_correlation(rates_a, rates_b)
  44. ratecorr_exc = calc_vector_correlation(rates_exc_a, rates_exc_b)
  45. ratecorr_inh = calc_vector_correlation(rates_inh_a, rates_inh_b)
  46. params = dict([(k.strip('-'),v) for k,v in zip(unknown[:-1:2],unknown[1::2])])
  47. params.update(rate_correlation=ratecorr,
  48. rate_correlation_exc=ratecorr_exc,
  49. rate_correlation_inh=ratecorr_inh)
  50. df = pd.Series(params).to_frame().T
  51. df.to_csv(args.output)