123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
# Analysis script: per-block PCA of spike-rate data, per-channel firing-rate
# histograms, and a filtered trace / spectrum of one channel.
import importlib

import numpy as np
import matplotlib.pyplot as plt
# Public import path. The private module `sklearn.decomposition.pca` was
# deprecated in scikit-learn 0.22 and removed in 0.24.
from sklearn.decomposition import PCA

# Project-local modules (original order kept in case of import side effects).
import aux
from helpers import data_management as dm
from analytics import analytics1

# Re-import analytics1 so that edits to it are picked up when this script is
# re-run inside the same interpreter session.
importlib.reload(analytics1)
# Global matplotlib defaults for every figure drawn by this script:
# small legend/label/title/tick fonts and a 10x5 inch default figure.
params_plot = dict.fromkeys(
    (
        'legend.fontsize',
        'axes.labelsize',
        'axes.titlesize',
        'xtick.labelsize',
        'ytick.labelsize',
    ),
    'small',
)
params_plot['figure.figsize'] = (10, 5)
plt.rcParams.update(params_plot)
# Total channel count used to derive the exclusion list from include_channels.
# NOTE(review): hard-coded as 128 in the original; confirm it matches the
# recording setup.
N_CHANNELS_TOTAL = 128

params = aux.load_config()
data_tot, tt, triggers, ch_rec_list = dm.get_raw(n_triggers=params.classifier.n_classes)

# Optionally restrict the analysis to the `max_active_ch_nr` channels whose
# maximum over axis 0 of data_tot[0, 0] is largest (presumably the first
# block's firing rates — TODO confirm).
if params.classifier.max_active_ch_nr != []:
    max_rates = np.max(data_tot[0, 0], axis=0)
    params.classifier.include_channels = np.argsort(max_rates)[-params.classifier.max_active_ch_nr:]

# Channels to drop. With no explicit exclusion list configured, drop everything
# that is not in include_channels; otherwise use the configured list. The
# original only assigned `exclude_channels` in the first case, so a non-empty
# config.exclude_channels raised NameError on the call below.
if params.classifier.exclude_channels == []:
    exclude_channels = sorted(
        set(range(N_CHANNELS_TOTAL)) - set(params.classifier.include_channels)
    )
else:
    exclude_channels = params.classifier.exclude_channels
data_tot = analytics1.exclude_channels(data_tot, exclude_channels)  # exclude channels from analysis

if params.daq.spike_rates.correct_bl_model:
    # FIXME(review): `idx1` is not defined anywhere in the visible script, so
    # this branch raises NameError whenever correct_bl_model is enabled.
    # Supply the index argument that correct_init_baseline expects.
    data_tot = analytics1.correct_init_baseline(
        data_tot, idx1, params.daq.spike_rates.bl_offset
    )  # correct baseline as well
# For each block:
#   - print the 20 channels with the highest peak rate;
#   - project the block onto its first two principal components;
#   - correlate each channel's peak rate with the absolute PC1/PC2 loadings;
#   - redraw figure 1 (PC time courses) and figure 2 (all channels as an image).
# Figures 1 and 2 are cleared every iteration, so after the loop they show the
# last block only.
for block_id in range(data_tot.shape[0]):
    data = data_tot[block_id, 0]

    # channels ordered by their maximum value; keep the top 20
    peak_rates = np.max(data, axis=0)
    ch_ids = np.argsort(peak_rates)[-20:]
    print(f'session: {block_id}, ch_ids with highest max firing rates: {ch_ids}')

    pca1 = PCA(n_components=2)
    pca1.fit(data)
    res1 = pca1.transform(data)

    # components_ has one row per component (n_components=2 above)
    abs_load_pc1, abs_load_pc2 = np.abs(pca1.components_)
    cc1 = np.corrcoef(peak_rates, abs_load_pc1)
    cc2 = np.corrcoef(peak_rates, abs_load_pc2)
    print(f'block: {block_id}, corr coefs: {cc1[0,1]}, {cc2[0,1]}')

    # figure 1: PC1 and PC2 scores over samples
    plt.figure(1)
    plt.clf()
    for pc_scores in (res1[:, 0], res1[:, 1]):
        plt.plot(pc_scores)
    plt.draw()

    # figure 2: every channel of the block as an image
    plt.figure(2)
    plt.clf()
    plt.imshow(data.T, aspect='auto')
    plt.draw()
# One figure per block with a histogram of every channel's values, one
# subplot per channel.
# The grid is 8x8 for up to 64 channels (the original fixed layout). Beyond
# that it grows to a near-square grid; the fixed 8x8 grid made plt.subplot
# raise for the 65th channel.
for block_id in range(data_tot.shape[0]):
    data = data_tot[block_id, 0]
    n_ch = data.shape[1]
    n_cols = max(8, int(np.ceil(np.sqrt(n_ch))))
    n_rows = max(8, int(np.ceil(n_ch / n_cols)))
    # The bottom-left subplot keeps its ticks and carries the x label
    # (index 56 in the 8x8 case).
    bottom_left = (n_rows - 1) * n_cols

    plt.figure(figsize=(6, 6))  # new, empty figure; no clf() needed
    for ii in range(n_ch):
        plt.subplot(n_rows, n_cols, ii + 1)
        plt.hist(data[:, ii], bins=range(25))
        plt.xlim(-1, 30)
        if ii == bottom_left:
            plt.xlabel('Sp/sec', fontsize=8)
        else:
            plt.xticks([])
            plt.yticks([])
    plt.suptitle(f'Firing rate distributions, array: M1, block: {block_id}')
    plt.draw()
plt.show()
# Filter one channel of the last block with analytics1.bw_filter and plot:
#   - the original vs filtered trace over time;
#   - the spectrum returned by analytics1.calc_fft for the filtered signal.
fs = 20.  # samples per second: the time axis below is arange(n) / fs, in seconds
fc = [1, 0]  # cutoff spec forwarded to bw_filter (meaning defined there — TODO confirm)
order = 5
ch_idx = 55  # channel to inspect

# Use the last block explicitly. The original relied on `data` leaking out of
# the preceding histogram loop.
data = data_tot[-1, 0]
if not 0 <= ch_idx < data.shape[1]:
    # Fail clearly instead of silently slicing an empty (n_samples, 0) array.
    raise ValueError(f'channel {ch_idx} out of range for {data.shape[1]} channels')
data1 = data[:, ch_idx:ch_idx + 1]  # (n_samples, 1)
data2, _, _ = analytics1.bw_filter(data1.T, fc, fs, order, plot=True)
xx = np.arange(data1.shape[0]) / fs
Y, ff = analytics1.calc_fft(data2, fs)

# bw_filter is fed data1.T (1, n_samples) and calc_fft's result is indexed as
# Y[0, :], which suggests data2 is also (1, n_samples). In that case
# data2[:, 0] would be a single sample and would not match xx. np.ravel gives
# the full trace for either (1, n) or (n, 1).
# NOTE(review): confirm bw_filter's output orientation.
filtered_trace = np.ravel(data2)

plt.figure(3)
plt.clf()
plt.subplot(211)
plt.plot(xx, data1[:, 0], 'C1', label='original')
plt.plot(xx, filtered_trace, 'C0', label='filtered')
plt.xlabel('Time (sec)')
plt.ylabel('Sp/sec')
plt.legend()
plt.subplot(212)
plt.plot(ff, Y[0, :], 'C0', label='power\nspectrum')
plt.xlim(0, 1)
plt.xlabel('Frequency (Hz)')
plt.legend()
plt.show()
|