import matplotlib.pyplot as plt
import numpy as np
import os
from scipy import signal


class fr_analytics:
    """Plotting helpers for the per-channel data of one recording.

    ``data[f_id, 0]`` must yield a 2-D array whose columns are channels
    (``data[:, ch]``) and whose rows are samples along axis 0 (the axis the
    Welch spectrum is taken over in ``plot_spectra``).
    """

    def __init__(self, data, params, file_name):
        """Store the inputs and create the output directory for figures.

        Args:
            data: container indexed as ``data[f_id, 0]`` (see class docstring).
            params: config object; ``params.file_handling.results`` is used as
                the root output path.
            file_name: path of the recording; its base name up to the first
                '.' is stored in ``self.base_name`` and used to name figures.
        """
        self.data = data
        self.params = params
        self.file_name = file_name
        self.base_name = os.path.basename(file_name).split('.')[0]
        # NOTE(review): plain concatenation requires `results` to end with a
        # path separator; kept as-is so the output location does not change.
        self.results_path = self.params.file_handling.results + 'yes_no/'
        os.makedirs(self.results_path, exist_ok=True)

    @staticmethod
    def _format_grid_axis(ax, pos, n_panels, xlabel, n_cols=8):
        """Trim tick labels on panel ``pos`` of a row-major subplot grid.

        x ticks are removed unless no panel sits directly below (then
        ``xlabel`` is set instead); y ticks are kept only in the first column.
        """
        if pos + n_cols < n_panels:
            ax.set_xticks([])
        else:
            ax.set_xlabel(xlabel)
        if pos % n_cols > 0:
            ax.set_yticks([])

    def plot_pdf(self, f_id=0, ch_ids=range(32), xmax=50, ymax=1000):
        """Plot a histogram of each requested channel in an 8 x 8 grid.

        Values are binned with unit-width bins whose edges run from -1 up to
        49 (or up to ``ceil(xmax)`` when ``xmax`` is larger, so the visible
        x-range is never empty of bins).

        Args:
            f_id: index selecting ``self.data[f_id, 0]``.
            ch_ids: channel (column) indices to plot; at most 64 fit the grid.
            xmax: upper x-axis limit of every panel.
            ymax: upper y-axis limit of every panel.
        """
        data = self.data[f_id, 0]
        ch_ids = list(ch_ids)
        bins = range(-1, max(50, int(np.ceil(xmax))))
        plt.figure()
        plt.clf()
        for pos, ch in enumerate(ch_ids):
            ax = plt.subplot(8, 8, pos + 1)
            # Bin the channel named in the legend, not the panel position
            # (the two only coincide when ch_ids == range(n)).
            hist, bin_edges = np.histogram(data[:, ch], bins)
            plt.bar(bin_edges[:-1], hist, width=1, label=f'{ch}')
            plt.xlim(0, xmax)
            plt.ylim(0, ymax)
            plt.legend(loc=1, prop={'size': 6})
            self._format_grid_axis(ax, pos, len(ch_ids), 'sp/sec')
        return None

    def plot_spectra(self, f_id=0, recompute=True, fs=20, v_thr=5, ymin=0,
                     ymax=7, arr_id=1, save_fig=False, fmin=0.1, fmax=5):
        """Plot the Welch power spectrum of 64 channels in figure 1.

        Spectra are computed with ``scipy.signal.welch`` along axis 0
        (``nperseg=1000``) and cached in ``self.ff`` / ``self.Y``. Channels
        whose mean value exceeds ``v_thr`` are drawn thick in colour 'C0';
        the rest thin, black and semi-transparent.

        Args:
            f_id: index selecting ``self.data[f_id, 0]``.
            recompute: force recomputation even if a cached spectrum exists.
            fs: sampling frequency passed to ``welch``.
            v_thr: mean-value threshold for highlighting a channel.
            ymin, ymax: y-axis limits of every panel.
            arr_id: 1 plots channels 32..95; any other value plots 0..31 and
                96..127.
            save_fig: if True, save the figure as
                ``<results_path><base_name before '-'>_arr_<arr_id>.png``.
            fmin, fmax: open frequency band (Hz) shown in each panel.
        """
        data = self.data[f_id, 0]
        if arr_id == 1:
            ch_ids = range(32, 96)
        else:
            ch_ids = list(range(32)) + list(range(96, 128))

        # Reuse the cached spectrum only when it was computed for the same
        # recording and sampling rate; otherwise it would not match `data`.
        spec_key = (f_id, fs)
        if (recompute or not hasattr(self, 'Y')
                or getattr(self, '_spec_key', None) != spec_key):
            self.ff, self.Y = signal.welch(data, fs=fs, axis=0, nperseg=1000)
            self._spec_key = spec_key

        fidx = (self.ff > fmin) & (self.ff < fmax)
        # Channels whose mean firing rate is higher than v_thr are highlighted.
        active = set(np.flatnonzero(data.mean(axis=0) > v_thr).tolist())

        fh = plt.figure(1, figsize=[19.2, 9.55])
        plt.clf()
        for pos, ch in enumerate(ch_ids):
            ax = plt.subplot(8, 8, pos + 1)
            if ch in active:
                plt.plot(self.ff[fidx], self.Y[fidx, ch], 'C0', lw=2, label=f'{ch}')
            else:
                plt.plot(self.ff[fidx], self.Y[fidx, ch], 'k', lw=1, alpha=0.5, label=f'{ch}')
            plt.legend(loc=1, prop={'size': 6})
            plt.ylim(ymin, ymax)
            # Tick trimming depends on the panel position, not the channel id.
            self._format_grid_axis(ax, pos, len(ch_ids), 'Hz')
        plt.tight_layout()

        if save_fig:
            fig_name = f'{self.results_path}{self.base_name.split("-")[0]}_arr_{arr_id}.png'
            plt.savefig(fig_name)
            print(f'saving figure {fh.number} in {fig_name}\n')
        return None