fr_analytics.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import os
  4. from scipy import signal
  5. class fr_analytics:
  6. def __init__(self, data, params, file_name):
  7. self.data = data
  8. self.params = params
  9. self.file_name = file_name
  10. self.base_name = os.path.basename(file_name).split('.')[0]
  11. self.results_path = self.params.file_handling.results + 'yes_no/'
  12. os.makedirs(self.results_path, exist_ok=True)
  13. return None
  14. def plot_pdf(self, f_id=0, ch_ids=range(32), xmax=50, ymax=1000):
  15. data = self.data[f_id, 0]
  16. plt.figure()
  17. plt.clf()
  18. for ch_id in range(len(ch_ids)):
  19. ax = plt.subplot(8, 8, ch_id + 1)
  20. hist, bin_edges = np.histogram(data[:, ch_id], range(-1, 50))
  21. plt.bar(bin_edges[:-1], hist, width=1, label=f'{ch_ids[ch_id]}')
  22. plt.xlim(0, xmax)
  23. plt.ylim(0, ymax)
  24. plt.legend(loc=1, prop={'size': 6})
  25. if ch_id < 56:
  26. ax.set_xticks([])
  27. else:
  28. ax.set_xlabel('sp/sec')
  29. if np.mod(ch_id, 8) > 0:
  30. ax.set_yticks([])
  31. return None
  32. def plot_spectra(self, f_id=0, recompute=True, fs=20, v_thr=5, ymin=0, ymax=7, arr_id=1, save_fig=False):
  33. data = self.data[f_id, 0]
  34. if arr_id == 1:
  35. ch_ids = range(32, 96)
  36. else:
  37. ch_ids = list(range(32)) + list(range(96, 128))
  38. if not hasattr(self, 'Y') or recompute:
  39. # Y1, ff1 = analytics1.calc_fft(data, fs=fs, i_stop=-1, axis=0)
  40. ff, Y = signal.welch(data, fs=fs, axis=0, nperseg=1000)
  41. # fidx1 = (ff1 > 0.1) & (ff1 < 5)
  42. self.Y = Y
  43. self.ff = ff
  44. fidx = (self.ff > 0.1) & (self.ff < 5)
  45. v_mean = data.mean(axis=0)
  46. v_ids = np.argwhere(v_mean > v_thr) # get ids of channels with mean firing rate higher than v_thr
  47. fh = plt.figure(1, figsize=[19.2, 9.55])
  48. plt.clf()
  49. for ii, ch_id in enumerate(ch_ids):
  50. ax = plt.subplot(8, 8, ii + 1)
  51. # plt.plot(ff1[fidx1], Y1[fidx1, ch_id], label=f'{ch_id}')
  52. if ch_id in v_ids:
  53. plt.plot(self.ff[fidx], self.Y[fidx, ch_id], 'C0', lw=2, label=f'{ch_id}')
  54. else:
  55. plt.plot(self.ff[fidx], self.Y[fidx, ch_id], 'k', lw=1, alpha=0.5, label=f'{ch_id}')
  56. plt.legend(loc=1, prop={'size': 6})
  57. plt.ylim(ymin, ymax)
  58. if ch_id < 56:
  59. ax.set_xticks([])
  60. else:
  61. ax.set_xlabel('Hz')
  62. if np.mod(ch_id, 8) > 0:
  63. ax.set_yticks([])
  64. plt.tight_layout()
  65. if save_fig:
  66. fig_name = f'{self.results_path}{self.base_name.split("-")[0]}_arr_{arr_id}.png'
  67. plt.savefig(fig_name)
  68. print(f'saving figure {fh.number} in {fig_name}\n')
  69. return None