123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import matplotlib.pyplot as plt
- import numpy as np
- import os
- from scipy import signal
class fr_analytics:
    """Plotting helpers for per-channel data of a set of recordings.

    ``data[f_id, 0]`` must yield a 2-D array whose axis 0 is time (it is
    reduced/transformed along axis 0) and whose columns are channels.
    NOTE(review): ``data`` is presumably an object array of recordings
    (e.g. a loaded MATLAB cell array) -- confirm against the caller.

    ``params.file_handling.results`` is the base results directory; figures
    are written to its ``yes_no`` sub-directory, created on construction.
    """

    def __init__(self, data, params, file_name):
        self.data = data
        self.params = params
        self.file_name = file_name
        # Stem up to the FIRST dot, e.g. 'rec-01.v2.mat' -> 'rec-01'.
        self.base_name = os.path.basename(file_name).split('.')[0]
        # join(..., '') keeps the trailing separator that plot file names rely
        # on, and works whether or not `results` already ends with one.
        self.results_path = os.path.join(self.params.file_handling.results, 'yes_no', '')
        os.makedirs(self.results_path, exist_ok=True)

    @staticmethod
    def _format_panel_axes(ax, panel, n_panels, xlabel, n_cols=8):
        """Declutter one panel of a row-major grid of ``n_panels`` subplots.

        Panels that have another panel directly below them lose their x ticks;
        the others get ``xlabel``. Panels outside the first column lose their
        y ticks.
        """
        if panel < n_panels - n_cols:
            ax.set_xticks([])
        else:
            ax.set_xlabel(xlabel)
        if panel % n_cols > 0:
            ax.set_yticks([])

    def plot_pdf(self, f_id=0, ch_ids=range(32), xmax=50, ymax=1000, bins=None):
        """Plot a histogram of each selected channel of file ``f_id`` on an 8x8 grid.

        Args:
            f_id: index of the recording in ``self.data``.
            ch_ids: channel (column) ids to plot; at most 64 fit the 8x8 grid.
            xmax, ymax: upper x / y axis limits of every panel.
            bins: histogram bin edges passed to ``np.histogram``;
                defaults to ``range(-1, 50)``.
        """
        data = self.data[f_id, 0]
        if bins is None:
            bins = range(-1, 50)
        plt.figure()
        n_panels = len(ch_ids)
        for panel, ch in enumerate(ch_ids):
            ax = plt.subplot(8, 8, panel + 1)
            # Select the column by channel id (not by panel position) so the
            # plotted data matches the legend label.
            hist, bin_edges = np.histogram(data[:, ch], bins)
            plt.bar(bin_edges[:-1], hist, width=1, label=f'{ch}')
            plt.xlim(0, xmax)
            plt.ylim(0, ymax)
            plt.legend(loc=1, prop={'size': 6})
            self._format_panel_axes(ax, panel, n_panels, 'sp/sec')

    def _compute_spectra(self, f_id, fs, recompute, nperseg=1000):
        """Return ``(ff, Y)``: Welch spectra of file ``f_id`` along axis 0.

        Results are cached on ``self.ff`` / ``self.Y``. The cache is reused only
        when ``recompute`` is False and it was built for the same
        ``(f_id, fs, nperseg)``; a cache of unknown provenance (``self.Y`` set
        without going through this method) is trusted, as before.
        """
        key = (f_id, fs, nperseg)
        reuse = (
            not recompute
            and hasattr(self, 'Y')
            and getattr(self, '_spectra_key', key) == key
        )
        if not reuse:
            self.ff, self.Y = signal.welch(self.data[f_id, 0], fs=fs, axis=0, nperseg=nperseg)
            self._spectra_key = key
        return self.ff, self.Y

    def plot_spectra(self, f_id=0, recompute=True, fs=20, v_thr=5, ymin=0, ymax=7, arr_id=1,
                     save_fig=False, nperseg=1000, fband=(0.1, 5)):
        """Plot Welch spectra of 64 channels of file ``f_id`` on an 8x8 grid.

        Args:
            f_id: index of the recording in ``self.data``.
            recompute: force recomputing the spectra instead of reusing the cache.
            fs: sampling frequency passed to ``scipy.signal.welch``.
            v_thr: channels whose mean over axis 0 exceeds this are highlighted.
            ymin, ymax: y axis limits of every panel.
            arr_id: 1 plots channels 32..95; any other value plots 0..31 and 96..127.
            save_fig: if True, save the figure as
                ``<results_path><base_name prefix before '-'>_arr_<arr_id>.png``.
            nperseg: Welch segment length.
            fband: open interval ``(fmin, fmax)`` of frequencies to display.
        """
        data = self.data[f_id, 0]
        if arr_id == 1:
            ch_ids = range(32, 96)
        else:
            ch_ids = list(range(32)) + list(range(96, 128))

        ff, Y = self._compute_spectra(f_id, fs, recompute, nperseg)
        fmin, fmax = fband
        fidx = (ff > fmin) & (ff < fmax)
        # Channels whose mean over axis 0 exceeds v_thr are drawn highlighted.
        active = set(np.flatnonzero(data.mean(axis=0) > v_thr).tolist())

        fh = plt.figure(1, figsize=[19.2, 9.55])
        plt.clf()
        n_panels = len(ch_ids)
        for panel, ch in enumerate(ch_ids):
            ax = plt.subplot(8, 8, panel + 1)
            if ch in active:
                plt.plot(ff[fidx], Y[fidx, ch], 'C0', lw=2, label=f'{ch}')
            else:
                plt.plot(ff[fidx], Y[fidx, ch], 'k', lw=1, alpha=0.5, label=f'{ch}')
            plt.legend(loc=1, prop={'size': 6})
            plt.ylim(ymin, ymax)
            # Layout depends on the panel position, not on the channel id.
            self._format_panel_axes(ax, panel, n_panels, 'Hz')
        plt.tight_layout()
        if save_fig:
            fig_name = os.path.join(
                self.results_path, f'{self.base_name.split("-")[0]}_arr_{arr_id}.png'
            )
            plt.savefig(fig_name)
            print(f'saving figure {fh.number} in {fig_name}\n')
|