- import pickle
- import matplotlib.pyplot as plt
- import munch
- import numpy as np
- import scipy.signal as sgn
- from scipy import io
- from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
- import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23; the standalone joblib package provides the same load/dump API
- from sklearn.model_selection import (StratifiedKFold, StratifiedShuffleSplit,
- cross_val_predict, cross_val_score,
- cross_validate)
- from sklearn.feature_selection import RFE
- from aux import log
- import aux
- class Classifier:
- def __init__(self, params, block_phase=[]):
-
- self.block_phase = block_phase
- self.init_buffer = 1 # skip decoding until the buffer holds enough samples for a full template
- try:
- res = self.set_params(params)
- except Exception as e:
- log.error(e)
- return None
- def set_params(self, params):
- params.classifier.template = np.array(params.classifier.template)
- params.classifier.n_features = params.classifier.template.size * params.daq.n_channels
- self.params = params
- if params.classifier.online:
- try:
- # self.clf = joblib.load('/kiap/data/classifier/monkey_model.joblib') # scikit LDA
- # self.clf2 = joblib.load('/kiap/data/classifier/monkey_model2.joblib') # explicit equations LDA
- self.clf1 = joblib.load(self.params.classifier.path_model1) # scikit LDA
- self.clf2 = joblib.load(self.params.classifier.path_model2) # explicit equations LDA
- except Exception as e:
- log.error('classifier model(s) not found. Send stop signal')
- self.params.classifier.online = False
- raise e
- # online classifier
- self.online_n_classes = self.params.classifier.n_classes
- self.online_history = []
- self.online_template = self.params.classifier.template
- self.online_template = self.online_template - self.online_template.min() + 1 # shift the template lags to 1-based window positions (sliding window)
- self.online_template = -self.online_template # negate: the offsets now count backwards from the newest sample in the buffer
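- # e.g. with template [0, -150, -300, -450, -600] this gives [601, 451, 301, 151, 1] after the shift
- # and [-601, -451, -301, -151, -1] after negation, i.e. offsets counted back from the newest sample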
- # TODO: Check that code below is valid (params.daq.exclude_channels is 1-based!)
- self.channel_mask = [ii for ii in range(self.params.daq.n_channels_max) if ii + 1 not in self.params.daq.exclude_channels]
- self.channel_mask = [ii for ii in range(len(self.channel_mask)) if ii not in self.params.classifier.exclude_channels]
- self.online_n_ch = len(self.channel_mask)
- self.online_features = np.zeros((1, len(self.online_template) * self.online_n_ch))
- if params.classifier.online:
- if self.online_features.shape[1] != self.clf1.coef_.shape[1]:
- log.error(f'# of features:{self.online_features.shape[1]}, # of coef:{self.clf1.coef_.shape[1]}')
- log.error('Mismatch in # of features. No online decoding possible')
- self.params.classifier.online = False
- self.online_decision = aux.decision.error2.value
- self.online_sig = [0] * self.online_n_classes
-
- return None
- if params.classifier.online:
- try:
- self.compare_params()
- except Exception as e:
- log.error(e)
- log.error('Parameter mismatch. Deactivating online classifier !')
- self.params.classifier.online = False
- raise e
- return None
- def compare_params(self):
- p1 = self.params.classifier.copy()
- p2 = self.clf1.params.copy()
- p3 = self.clf2.params.copy()
- for key, value in p1.items():
- if key not in ['online', 'path_model1', 'path_model2']:
- if np.any(p1[key] != p2[key]) or np.any(p1[key] != p3[key]):
- if key == 'n_classes':
- log.warning(f'Parameter "{key}" mismatch. Current: {value}, models: {p2[key]}, {p3[key]}')
- raise ValueError('Parameter mismatch!')
- return None
- def train_LDA(self, data_train=[], baseline=0, features_fname='', model_fname1='', model_fname2='', solver='lsqr', save_model=0):
- '''Train an LDA both with sklearn and with explicit equations, and optionally save both models (the sklearn model plus the class means and whitening matrix).
- Parameters
- ----------
- data_train: 1d-list of arrays, shape: (n_classes, 1)
- arrays in list have shape: (n_features, n_trials_train)
- solver: 'lsqr' or 'svd' ('lsqr' is faster, but 'svd' gives better results)
- '''
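- # Illustrative shapes (hypothetical numbers): with n_classes = 3, a template of 5 lags and 128 channels,
- # n_features = 640 and data_train[k, 0] has shape (640, n_trials_k); X_tot then stacks all trials into
- # (sum_k n_trials_k, 640) with the integer class labels collected in y_tot.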
- log.debug('Training LDA model...')
- if len(data_train) == 0: # load data from file, data_train shape: (features x trials)
- log.info('load features from file')
- with open(features_fname, 'rb') as f:
- data_train = pickle.load(f)
- else:
- log.debug('get features from classifier instance')
- n_features_data = data_train[0, 0].shape[0]
- n_features = self.params.classifier.n_features
- if n_features_data != n_features:
- n_features = n_features_data
- log.debug(f'Used the number of features from data: {n_features}')
-
- # n_trials_tot = self.params.classifier.n_trials_tot
- n_trials_tot = sum([data_train[ii, 0].shape[1] for ii in range(len(data_train))])
- n_classes = self.triggers_shape[1]
- # if baseline == 0: # add one for one additional class for the baseline
- # n_classes -= 1
- reg_fact = self.params.classifier.model_training.reg_fact
- X_tot = np.empty((0, n_features))
- y_tot = np.empty((0, 1))
- for ii in range(n_classes):
- X_tot = np.vstack((X_tot, data_train[ii, 0].T))
- y_tot = np.vstack((y_tot, np.full((data_train[ii, 0].shape[1], 1), float(ii)))) # label column: class index repeated once per trial
- y_tot = y_tot.flatten()
- means = np.zeros((n_features, n_classes))
- tmp_cov = np.zeros((n_features, n_trials_tot))
- # train model using sklearn LDA
- shrinkage = self.params.classifier.reg_fact
- if solver == 'svd':
- shrinkage = None
- if self.params.classifier.model_training.shrinkage == 'auto':
- shrinkage = 'auto'
- clf1 = LDA(solver=solver, shrinkage=shrinkage)
-
- # CROSS_VALIDATION PARAMETERS
- test_size = self.params.classifier.model_training.test_size
- n_splits = self.params.classifier.model_training.n_splits
- clf1.scores = []
- if self.params.classifier.model_training.cross_validation:
- # train model using scikit LDA
- # use stratified shuffle to get trials from all classes
- kf = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=None)
- for train_index, test_index in kf.split(X_tot, y_tot):
- X_train, X_test = X_tot[train_index], X_tot[test_index]
- y_train, y_test = y_tot[train_index], y_tot[test_index]
- # X_train = X_tot
- # y_train = y_tot
- for class_id in range(n_classes):
- log.debug(f'class {class_id} split samples: {np.sum(y_train == class_id)}')
- try:
- clf1.fit(X_train, y_train)
- clf1.scores.append(clf1.score(X_test,y_test))
- except Exception as e:
- log.error('need more trials!')
- log.error(e)
- raise
- else:
- X_train = X_tot
- y_train = y_tot
- X_test = X_tot
- y_test = y_tot
- clf1.fit(X_train, y_train)
- clf1.scores.append(clf1.score(X_test,y_test))
-
- if self.params.classifier.model_training.fsel:
- selector = RFE(clf1, n_features_to_select=20, step=1) # recursive feature elimination: keep the 20 best features
- selector = selector.fit(X_train, y_train)
- # train model using explicit LDA equations
- counter = 0
- for class_id in range(n_classes):
- means[:, class_id] = np.mean(X_train[y_train == class_id, :], axis=0)
- class_len = X_train[y_train == class_id, :].shape[0]
- tmp_cov[:, counter:counter + class_len] = X_train[y_train == class_id, :].T - np.repeat(means[:, class_id][:, None], class_len, axis=1)
- counter = counter + class_len
- cov_mat = np.cov(tmp_cov)
- # Cholesky factorization
- if (np.linalg.matrix_rank(cov_mat) < cov_mat.shape[0]) and (reg_fact == 0):
- reg_fact = 1e-6
- log.warning(f'Covariance matrix is singular! Using regularization {reg_fact}')
- cov_diag_mean = np.mean(np.diag(cov_mat))
- cov_mat_reg = (1 - reg_fact) * cov_mat + reg_fact * cov_diag_mean * np.eye(n_features)
- cov_mat_inv = np.linalg.pinv(cov_mat_reg)
- chol_mat = np.linalg.cholesky(cov_mat_inv).T
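- # note: cov_mat_reg shrinks the covariance toward a scaled identity (mean of its diagonal) so it is
- # invertible; chol_mat is the transposed Cholesky factor of the inverse covariance, so
- # ||chol_mat @ (x - mean)||^2 equals the squared Mahalanobis distance used for the class
- # log-probabilities in test_LDA / get_class below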
- clf2 = {'means': means, 'chol_mat': chol_mat}
- self.clf1 = clf1
- if self.params.classifier.model_training.fsel:
- self.clf1.ranking = selector.ranking_
- self.clf2 = munch.munchify(clf2)
- self.clf1.params = self.params.classifier.copy() # save all information for later comparison in online decoding
- self.clf2.params = self.params.classifier.copy() # save all information for later comparison in online decoding
- if save_model:
- log.info(f'Saving model1 in {model_fname1}')
- log.info(f'Saving model2 in {model_fname2}')
- joblib.dump(self.clf1, model_fname1) # scikit LDA
- joblib.dump(self.clf2, model_fname2) # explicit LDA
- return X_test, y_test, X_train, y_train
- def test_LDA(self, test_trials, session_id, compare_results=0):
- '''test model using test_trials, compare probabilities with the ones from matlab code
- Parameters
- ----------
- test_trials: (n_samples, n_features)
- '''
- n_features = self.params.classifier.n_features
- # n_trials_test = self.params.classifier.n_trials_test
- n_classes = self.params.classifier.n_triggers + 1
- for sess in range(session_id + 1, session_id + 2):
- if compare_results:
- classProb = io.loadmat('/kiap/data/tom/model/Q19_20131122/classProb{}.mat'.format(sess))['classProb']
- triggers0 = io.loadmat('/kiap/data/tom/model/Q19_20131122/triggers.mat')['hand_triggers']
- # calculate probabilities using lda from sklearn
- prob = self.clf1.predict_proba(test_trials)
- # calculate probabilities explicitly
- cumSesLen = test_trials.shape[0]
- log_prob = np.zeros((cumSesLen, n_classes))
- rel_prob = np.zeros((cumSesLen, n_classes))
- for class_id in range(n_classes):
- vect = test_trials - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
- tmp = np.matmul(self.clf2['chol_mat'], vect.T)
- log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
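- # convert per-class log-likelihoods to relative probabilities: 1 / sum_k exp(log_prob_k - log_prob_c)
- # equals exp(log_prob_c) / sum_k exp(log_prob_k), i.e. a numerically stable softmax over classes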
- for class_id in range(n_classes):
- log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], n_classes, axis=1)
- rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
- if compare_results:
- self.triggers0 = triggers0
- self.classProb = classProb # imported from Matlab code
- self.rel_prob = rel_prob # computed here explicitly
- self.prob = prob # computed via sklearn
- # for class_id in range(n_classes-1):
- if compare_results:
- self.plot_results(sess, fig_number=1)
- return None
- def test_LDA_kiap(self, session_id, compare_results=0):
- '''test the model on self.test_trials (set by get_test_trials_kiap), shape (n_samples, n_features);
- optionally compare probabilities with the ones from the Matlab code
- '''
- test_trials = self.test_trials
- n_features = self.params.classifier.n_features
- # n_trials_test = self.params.classifier.n_trials_test
- # n_classes = self.params.classifier.n_triggers + 1
- n_classes = self.params.classifier.n_classes
- for sess in range(session_id + 1, session_id + 2):
- # if compare_results:
- # classProb = io.loadmat('/kiap/data/tom/model/Q19_20131122/classProb{}.mat'.format(sess))['classProb']
- # triggers0 = io.loadmat('/kiap/data/tom/model/Q19_20131122/triggers.mat'.format(sess))['hand_triggers']
- # calculate probabilities using lda from sklearn
- prob = self.clf1.predict_proba(test_trials)
- # calculate probabilities explicitly
- cumSesLen = test_trials.shape[0]
- log_prob = np.zeros((cumSesLen, n_classes))
- rel_prob = np.zeros((cumSesLen, n_classes))
- for class_id in range(n_classes):
- vect = test_trials - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
- tmp = np.matmul(self.clf2['chol_mat'], vect.T)
- log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
- for class_id in range(n_classes):
- log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], n_classes, axis=1)
- rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
- # rel_prob[:, class_id] = log_prob_cl # 20.02.19
-
- # if compare_results:
- # self.triggers0 = triggers0
- # self.classProb = classProb # imported from Matlab code
- self.rel_prob = rel_prob # computed here explicitly
- self.prob = prob # computed via sklearn
- # self.predict = self.clf1
- # for class_id in range(n_classes-1):
- if compare_results:
- self.plot_results_kiap(sess, fig_number=1)
- return None
- def plot_results(self, sess, compare_results=0, fig_number=1):
- col = plt.rcParams['axes.prop_cycle'].by_key()['color']
- self.triggers0 = io.loadmat('/kiap/data/tom/model/Q19_20131122/triggers.mat')['hand_triggers']
- # plt.figure(fig_number)
- # plt.clf()
- for class_id in range(self.params.classifier.n_triggers):
- triggers = np.hstack(self.triggers0[sess-1, class_id])
- ax1 = plt.subplot(311) # matlab code
- plt.plot(self.classProb[:, class_id], lw=1, alpha=0.5, color=col[class_id])
- plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
- plt.xlim(0, 17500)
- plt.ylim(0, 1.2)
- plt.title('Session {}'.format(sess))
- ax2 = plt.subplot(312, sharex=ax1) # python
- plt.plot(self.rel_prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
- # plt.plot(prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
- plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
- plt.xlim(0, 17500)
- plt.ylim(0, 1.2)
- ax3 = plt.subplot(313, sharex=ax1) # python sklearn
- plt.plot(self.prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
- plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
- plt.xlim(0, 17500)
- plt.ylim(0, 1.2)
- plt.draw()
- plt.show()
- input('press enter for next session')
- return None
- def plot_results_kiap(self, sess, compare_results=0, fig_number=1):
- col = plt.rcParams['axes.prop_cycle'].by_key()['color']
- # self.triggers0 = io.loadmat('/kiap/data/tom/model/Q19_20131122/triggers.mat'.format(sess))['hand_triggers']
- # plt.figure(fig_number)
- # plt.clf()
- for class_id in range(self.params.classifier.n_triggers):
- # triggers = np.hstack(self.triggers0[sess-1, class_id])
- triggers = np.hstack(self.triggers[sess-1, class_id])
- ax1 = plt.subplot(311) # matlab code
- # plt.plot(self.classProb[:, class_id], lw=1, alpha=0.5, color=col[class_id])
- # plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
- # plt.xlim(0, 17500)
- # plt.ylim(0, 1.2)
- # plt.title('Session {}'.format(sess))
- ax2 = plt.subplot(312, sharex=ax1) # python
- plt.plot(self.rel_prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
- # plt.plot(prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
- plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
- plt.xlim(0, 500)
- plt.ylim(0, 1.2)
- ax3 = plt.subplot(313, sharex=ax1) # python sklearn
- plt.plot(self.prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
- plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
- plt.xlim(0, 500)
- plt.ylim(0, 1.2)
- plt.draw()
- plt.show()
- input('press enter for next session')
- return None
- def get_trials_kiap(self, train_data, triggers, features_fname, save_features=0):
- log.debug('Computing feature matrix from training data ...')
-
- n_train_sess = train_data.size
- n_ch = train_data[0][0].shape[1]
- # n_triggers = self.params.classifier.n_triggers
- n_triggers = triggers.shape[1]
- n_classes = n_triggers
- deadtime = self.params.classifier.deadtime # number of samplepoints to consider around each trigger
- # n_neg_train = self.params.classifier.n_neg_train # number of sample-points to consider for baseline
- template = self.params.classifier.template
- log.debug(f'# of channels: {n_ch}')
- # GET CUED TRIALS
- n_pos_train = np.zeros((n_triggers,), dtype=int) # total number of triggers per trigger type
- for trg_ind in range(n_triggers):
- n_pos_train[trg_ind] = np.concatenate(triggers[:, trg_ind], axis=1).size
- n_train_tot = 0
- for sess in range(n_train_sess):
- n_train_tot += train_data[sess, 0].shape[0] # total number of samples
- # EXTRACTING TRAINING TRIALS
- # --------------------------
- train_trials = [0] * n_classes
- psth = [0] * n_classes
- cut_win = self.params.classifier.psth.cut[1] - self.params.classifier.psth.cut[0]
- for trg_id in range(n_triggers):
- train_trials[trg_id] = np.zeros((n_pos_train[trg_id], len(template) * n_ch))
- # train_trials_f = np.zeros((n_neg_train, len(template) * n_ch)) # optimize code by avoiding inner loop over channels
- trialCounter = [1] * n_classes
- for sess in range(n_train_sess):
-
- for trg_id in range(n_triggers):
- psth[trg_id] = np.zeros((triggers[sess, trg_id].size, cut_win, n_ch))
- log.debug(f'psth[{trg_id}] shape: {psth[trg_id].shape}')
- # CONSTRUCT FEATURE MATRICES FOR EACH CLASS
- for trg_id in range(n_triggers):
- for ii in range(triggers[sess, trg_id].size): # skip ii if sample points are < 0 or > max size of available data
- if (triggers[sess, trg_id][0, ii] + min(template) <= 0) or (triggers[sess, trg_id][0, ii] + max(template) > train_data[sess, 0].shape[0]):
- log.debug('excluded {} {} {}'.format(sess, trg_id, ii))
- continue
- for ch in range(0, n_ch):
- idx = list(range(ch * len(template), (ch + 1) * len(template))) # CAUTION: why -1 below? is trigger absolute index?
- train_trials[trg_id][trialCounter[trg_id] - 1, idx] = train_data[sess, 0][triggers[sess, trg_id][0, ii] + template - 1, ch]
- try:
- psth[trg_id][ii, :, ch] = train_data[sess, 0][triggers[sess, trg_id][0, ii]+self.params.classifier.psth.cut[0]:triggers[sess, trg_id][0, ii]+self.params.classifier.psth.cut[1], ch]
- except Exception:
- log.debug(f'Trial {ii} out of limits')
- trialCounter[trg_id] = trialCounter[trg_id] + 1
- log.debug('feature for classes extracted')
- log.debug(f'Session {sess}, trialCounter {trialCounter}')
- if trialCounter == [1] * n_classes:
- raise ValueError('No feature trials extracted because the triggers are too close to the edges!')
- # bring data into correct format to use afterwards
- train_trials2 = np.zeros((n_classes, 1), dtype='O')
- for cl in range(n_classes):
- train_trials2[cl, 0] = train_trials[cl][:trialCounter[cl] - 1, :].T # remove trials if they were outside margins. to be confirmed !
- if features_fname:
- log.debug('Saving feature array...')
- with open(features_fname, 'wb') as f:
- pickle.dump(train_trials2, f)
- self.train_trials = train_trials2
- self.psth = psth
- self.psth_xx = range(self.params.classifier.psth.cut[0], self.params.classifier.psth.cut[1])
- log.debug(f'feature-array session 0 shape: {train_trials2[0,0].shape}')
- return None
- def get_trials_kiap_bl(self, train_data, triggers, save_features=0):
- log.debug('Computing feature matrix from training data ...')
-
- n_train_sess = train_data.size
- n_ch = train_data[0][0].shape[1]
- # n_triggers = self.params.classifier.n_triggers
- n_triggers = triggers.shape[1]
- n_classes = n_triggers
- deadtime = self.params.classifier.deadtime # number of samplepoints to consider around each trigger
- # n_neg_train = self.params.classifier.n_neg_train # number of sample-points to consider for baseline
- template = self.params.classifier.template
- log.debug(f'# of channels: {n_ch}')
- # GET CUED TRIALS
- n_pos_train = np.zeros((n_triggers,), dtype=int) # total number of triggers per trigger type
- for trg_ind in range(n_triggers):
- n_pos_train[trg_ind] = np.concatenate(triggers[:, trg_ind], axis=1).size
- n_train_tot = 0
- for sess in range(n_train_sess):
- n_train_tot += train_data[sess, 0].shape[0] # total number of samples
- # EXTRACTING TRAINING TRIALS
- # --------------------------
- train_trials = [0] * n_classes
- psth = [0] * n_classes
- cut_win = self.params.classifier.psth.cut[1] - self.params.classifier.psth.cut[0]
- for trg_id in range(n_triggers):
- train_trials[trg_id] = np.zeros((n_pos_train[trg_id], len(template) * n_ch))
- # train_trials_f = np.zeros((n_neg_train, len(template) * n_ch)) # optimize code by avoiding inner loop over channels
- trialCounter = [1] * n_classes
- for sess in range(n_train_sess):
-
- for trg_id in range(n_triggers):
- psth[trg_id] = np.zeros((triggers[sess, trg_id].size, cut_win, n_ch))
- log.debug(f'psth[{trg_id}] shape: {psth[trg_id].shape}')
- # CONSTRUCT FEATURE MATRICES FOR EACH CLASS
- for trg_id in range(n_triggers):
- for ii in range(triggers[sess, trg_id].size): # skip ii if sample points are < 0 or > max size of available data
- if (triggers[sess, trg_id][0, ii] + min(template) <= 0) or (triggers[sess, trg_id][0, ii] + max(template) > train_data[sess, 0].shape[0]):
- log.debug('excluded {} {} {}'.format(sess, trg_id, ii))
- continue
- for ch in range(0, n_ch):
- idx = list(range(ch * len(template), (ch + 1) * len(template))) # CAUTION: why -1 below? is trigger absolute index?
- train_trials[trg_id][trialCounter[trg_id] - 1, idx] = train_data[sess, 0][triggers[sess, trg_id][0, ii] + template - 1, ch]
- try:
- psth[trg_id][ii, :, ch] = train_data[sess, 0][triggers[sess, trg_id][0, ii]+self.params.classifier.psth.cut[0]:triggers[sess, trg_id][0, ii]+self.params.classifier.psth.cut[1], ch]
- except Exception:
- log.debug(f'Trial {ii} out of limits')
- trialCounter[trg_id] = trialCounter[trg_id] + 1
- log.debug('feature for classes extracted')
- log.debug(f'Session {sess}, trialCounter {trialCounter}')
- if trialCounter == [1] * n_classes:
- raise ValueError('No feature trials extracted because the triggers are too close to the edges!')
- # bring data into correct format to use afterwards
- train_trials2 = np.zeros((n_classes, 1), dtype='O')
- for cl in range(n_classes):
- train_trials2[cl, 0] = train_trials[cl][:trialCounter[cl] - 1, :].T # remove trials if they were outside margins. to be confirmed !
- self.train_trials = train_trials2
- self.psth = psth
- self.psth_xx = range(self.params.classifier.psth.cut[0], self.params.classifier.psth.cut[1])
- log.debug(f'feature-array session 0 shape: {train_trials2[0,0].shape}')
- return None
- def get_trials(self, train_data, triggers, features_fname='', train_trials_comparison=[], assert_results=0):
- '''import spike rates, extract the trials and construct features.
- For the monkey data this yields results identical to the Matlab code provided by Tom.
-
- Parameters
- ----------
- train_data: (n_train_sess, 1)
- train_data[ii, 0]: (n_samples, n_ch) per element
- triggers: (n_train_sess, n_triggers)
- train_trials_comparison: imported only to assert that the resulting train_trials matches the Matlab code
- Saves as object variable
- ------------------------
- self.train_trials: (n_classes, 1) with (n_feat, n_trials) per class
- Note: set assert_results=1 to compare with the results from the Matlab code provided by Tom
- '''
-
-
- log.info('Computing feature matrix from training data ...')
- n_train_sess = train_data.size
- n_ch = train_data[0][0].shape[1]
- # n_triggers = self.params.classifier.n_triggers
- n_triggers = triggers.shape[1]
- n_classes = n_triggers + 1
- deadtime = self.params.classifier.deadtime # number of samplepoints to consider around each trigger
- n_neg_train = self.params.classifier.n_neg_train # number of sample-points to consider for baseline
- # template = [0, -150, -300, -450, -600]
- template = self.params.classifier.template
- # GET CUED TRIALS
- n_pos_train = np.zeros((n_triggers,), dtype=int) # total number of triggers per trigger type
- for trg_ind in range(n_triggers):
- n_pos_train[trg_ind] = np.concatenate(triggers[:, trg_ind], axis=1).size
- n_train_tot = 0
- for sess in range(n_train_sess):
- n_train_tot += train_data[sess, 0].shape[0] # total number of samples
- # GET BASELINE TRIALS
- neg_train = np.zeros((2, n_train_tot), dtype=int) # row 0: session number, row 1: baseline indices for this session
- counter = 0
- for sess in range(n_train_sess): # build index with sample-points to exclude from baseline
- idx1 = list(range(deadtime + 1)) # exclude first and last samples of each session of length deadtime
- idx1.extend(list(range(train_data[sess, 0].shape[0] - deadtime, train_data[sess, 0].shape[0] + 1)))
- for trg_ind in range(n_triggers): # exclude also +/- deadtime samples around each trigger
- for tr in range(triggers[sess, trg_ind].size):
- idx1.extend(range(triggers[sess, trg_ind][0, tr] - deadtime, triggers[sess, trg_ind][0, tr] + deadtime + 1))
- tmpNegStarts = list(set(range(train_data[sess, 0].shape[0])) - set(idx1))
-
- neg_train[0, counter:counter + len(tmpNegStarts)] = sess
- neg_train[1, counter:counter + len(tmpNegStarts)] = tmpNegStarts
- counter = counter + len(tmpNegStarts)
-
- neg_train = neg_train[:, :counter] # remove zero elements from the end
- # note that 2nd row has to be modified for neg_train, it is being taken care of below
- # print(negTrain)
- # assert(np.array_equiv(negTrain, neg_train))
- for ii in range(n_train_sess): # report number of baseline indices per session
- log.debug(f'session {ii}, # baseline indices: {sum(neg_train[0, :] == ii)}')
- # limit the number of sample points for baseline
- if neg_train.shape[1] > n_neg_train:
- perm_idx = np.random.permutation(range(neg_train.shape[1]))
- neg_train = neg_train[:, perm_idx[:n_neg_train]]
- else:
- n_neg_train = neg_train.shape[1]
- # EXTRACTING TRAINING TRIALS
- # --------------------------
- train_trials = [0] * n_classes
- for trg_id in range(n_triggers):
- train_trials[trg_id] = np.zeros((n_pos_train[trg_id], len(template) * n_ch))
- train_trials[n_classes - 1] = np.zeros((n_neg_train, len(template) * n_ch))
- train_trials_f = np.zeros((n_neg_train, len(template) * n_ch)) # optimize code by avoiding inner loop over channels
- trialCounter = [1] * n_classes
- for sess in range(n_train_sess):
- # CONSTRUCT FEATURE MATRICES FOR EACH CLASS
- for trg_id in range(n_triggers):
- for ii in range(triggers[sess, trg_id].size): # skip ii if sample points are < 0 or > max size of available data
- if (triggers[sess, trg_id][0, ii] + min(template) <= 0) or (triggers[sess, trg_id][0, ii] + max(template) > train_data[sess, 0].shape[0]):
- log.debug('excluded {} {} {}'.format(sess, trg_id, ii))
- continue
- for ch in range(0, n_ch):
- idx = list(range(ch * len(template), (ch + 1) * len(template))) # CAUTION: why -1 below? is trigger absolute index?
- train_trials[trg_id][trialCounter[trg_id] - 1, idx] = train_data[sess, 0][triggers[sess, trg_id][0, ii] + template - 1, ch]
- if assert_results:
- assert(np.array_equiv(train_trials[trg_id][trialCounter[trg_id] - 1, idx], train_trials_comparison[trg_id, 0].T[trialCounter[trg_id] - 1, idx]))
- trialCounter[trg_id] = trialCounter[trg_id] + 1
- log.debug('feature for classes extracted')
- # CONSTRUCT FEATURE MATRICES FOR BASELINE
- neg_ind = np.where(neg_train[0, :] == sess)[0]
- for ii in range(0, len(neg_ind)): # throw away indices if they fall outside margins
- if (neg_train[1, neg_ind[ii]] + min(template) <= 0) or (neg_train[1, neg_ind[ii]] + max(template) > train_data[sess, 0].shape[0]):
- continue
- # print(ii)
- # for ch in range(0, n_ch):
- # idx = list(range(ch * len(template), (ch + 1) * len(template)))
- # train_trials[n_classes - 1][trialCounter[n_classes - 1]-1, idx] = train_data[sess, 0][neg_train[1, neg_ind[ii]] + template-1, ch]
- # # print(sess, ii, ch, 'neg trials')
- # if assert_results:
- # assert(np.array_equiv(train_trials[n_classes - 1][trialCounter[n_classes - 1] -1, idx], train_trials_comparison[n_classes - 1, 0].T[trialCounter[n_classes - 1] - 1, idx]))
- # avoid inner loop to speed up calculations
- fff = train_data[sess, 0][neg_train[1, neg_ind[ii]] + template-1, :].flatten('F')
- # fff = fff.flatten('F')
- train_trials_f[trialCounter[n_classes - 1] - 1] = fff
- trialCounter[n_classes - 1] = trialCounter[n_classes - 1] + 1
- train_trials[n_classes-1] = train_trials_f
- # print(np.array_equiv(train_trials_f, train_trials[n_classes - 1]))
- # assert(np.array_equiv(train_trials_f, train_trials[n_classes - 1]))
- log.debug('feature for baseline extracted')
- log.debug(f'Session {sess}, trialCounter {trialCounter}')
- # bring data into correct format to use afterwards
- train_trials2 = np.zeros((n_classes, 1), dtype='O')
- for cl in range(n_classes):
- train_trials2[cl, 0] = train_trials[cl][:trialCounter[cl]-1, :].T # remove trials if they were outside margins. to be confirmed !
- if features_fname:
- log.info('Saving feature array...')
- with open(features_fname, 'wb') as f:
- pickle.dump(train_trials2, f)
- self.train_trials = train_trials2
- return None
- def get_test_trials(self, test_data):
- '''import spike rates and build the TESTING feature matrix: one sliding-window feature row per valid sample point.'''
-
- # data_tot = io.loadmat('/kiap/data/tom/model/trainData_rates2.mat')['trainData']
- # test_data = data_tot[session_id:session_id + 1]
- # # this file is imported only to assert that test_trials is equivalent to dt
- # dt = io.loadmat('/home/vlachos/devel/iv_misc/decoding/tom/scripts/model/Q19_20131122/testTrials.mat')['testTrials']
- n_test_sess = test_data.size # number of test sessions
- n_ch = test_data[0][0].shape[1]
- n_triggers = self.params.classifier.n_triggers
- n_classes = n_triggers + 1
- deadtime = self.params.classifier.deadtime # number of samplepoints to consider around each trigger
- n_neg_train = self.params.classifier.n_neg_train # number of sample-points to consider for baseline
- template = self.params.classifier.template
- n_test = 0
- for sess in range(n_test_sess): # total number of sample points
- n_test = n_test + test_data[sess, 0].shape[0]
- # % Extracting testing trials
- test_trials = np.zeros((n_test, len(template) * n_ch))
- test_ind = np.zeros((n_test, 1), dtype=int)
- test_cut = np.zeros((n_test, 1))
- cum_ses_len = 0
- for sess in range(n_test_sess):
- start_sess = 1 - min(np.append(template, 0)) # get start index
- # start_sess = 0 - min(np.append(template, 0)) # get start index
- end_sess = test_data[sess, 0].shape[0] - max(np.append(template, 1)) + 1 # get end index
- sess_size = end_sess - start_sess + 1
- if sess_size <= 0:
- continue
- test_ind[cum_ses_len:cum_ses_len + sess_size, 0] = np.arange(start_sess, end_sess + 1) + max(np.append(template, 1)) - 1
- test_cut[cum_ses_len:cum_ses_len + sess_size, 0] = sess
- for ch in range(n_ch):
- for ii in range(len(template)):
- tmp = test_data[sess, 0][np.arange(start_sess - 1, end_sess) + template[ii], ch]
- test_trials[cum_ses_len:cum_ses_len + sess_size, ch * len(template) + ii] = tmp
- # print(tmp[-10:-1])
- # inputData{sess}((startSes:endSess) + template_detection(ii),ch);
- cum_ses_len = cum_ses_len + sess_size
- test_trials = test_trials[: cum_ses_len, :]
- test_ind = test_ind[:cum_ses_len]
- test_cut = test_cut[:cum_ses_len]
- return test_trials
- def get_test_trials_kiap(self, test_data):
- '''import spike rates and build the TESTING feature matrix: one sliding-window feature row per valid sample point.'''
-
- # data_tot = io.loadmat('/kiap/data/tom/model/trainData_rates2.mat')['trainData']
- # test_data = data_tot[session_id:session_id + 1]
- # # this file is imported only to assert that test_trials is equivalent to dt
- # dt = io.loadmat('/home/vlachos/devel/iv_misc/decoding/tom/scripts/model/Q19_20131122/testTrials.mat')['testTrials']
- n_test_sess = test_data.size # number of test sessions
- n_ch = test_data[0][0].shape[1]
- # n_triggers = self.params.classifier.n_triggers
- # n_classes = n_triggers + 1
- # deadtime = self.params.classifier.deadtime # number of samplepoints to consider around each trigger
- # n_neg_train = self.params.classifier.n_neg_train # number of sample-points to consider for baseline
- template = self.params.classifier.template
- template = template - template.min() + 1 # shift to 1-based positive lags (cf. set_params)
- n_test = 0
- for sess in range(n_test_sess): # total number of sample points
- n_test = n_test + test_data[sess, 0].shape[0]
- # % Extracting testing trials
- test_trials = np.zeros((n_test, len(template) * n_ch))
- test_ind = np.zeros((n_test, 1), dtype=int)
- test_cut = np.zeros((n_test, 1))
- cum_ses_len = 0
- for sess in range(n_test_sess):
- start_sess = 1 - min(np.append(template, 0)) # get start index
- # start_sess = 0 - min(np.append(template, 0)) # get start index
- end_sess = test_data[sess, 0].shape[0] - max(np.append(template, 1)) + 1 # get end index
- sess_size = end_sess - start_sess + 1
- if sess_size <= 0:
- continue
- test_ind[cum_ses_len:cum_ses_len + sess_size, 0] = np.arange(start_sess, end_sess + 1) + max(np.append(template, 1)) - 1
- test_cut[cum_ses_len:cum_ses_len + sess_size, 0] = sess
- for ch in range(n_ch):
- for ii in range(len(template)):
- tmp = test_data[sess, 0][np.arange(start_sess - 1, end_sess) - template[ii]-1, ch] # negative template !!!
- test_trials[cum_ses_len:cum_ses_len + sess_size, ch * len(template) + ii] = tmp
- # print(tmp[-10:-1])
- # inputData{sess}((startSes:endSess) + template_detection(ii),ch);
- cum_ses_len = cum_ses_len + sess_size
- test_trials = test_trials[: cum_ses_len, :]
- test_ind = test_ind[:cum_ses_len]
- test_cut = test_cut[:cum_ses_len]
- self.test_trials = test_trials
- return None
- def online_decoder(self, cur_data, session_id=0):
- plt.cla()
- # plt.ion()
- plt.ylim(0, 1)
- plt.xlim(0, 10000)
- # plt.ion()
- # td = cur_data[sess, 0]
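- # offline replay: step through the recording in increments of 20 samples, passing the most recent
- # 600-sample window to get_class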
- for jj in range(600, 2000, 20):
- self.get_class(cur_data[jj-600:jj], jj)
- return None
- def get_class(self, cur_data, tt):
- '''used for Tom's data'''
- # log.debug(cur_data.shape)
-
- n_ch = cur_data.shape[1]
- n_triggers = 4
- n_classes = n_triggers + 1
- template = [0, -150, -300, -450, -600]
- n_test=1
- test_trials = np.zeros((n_test, len(template) * n_ch))
- cum_ses_len = 0
- sess_size = 1
- for ch in range(n_ch):
- for ii in range(len(template)):
- tmp_data = cur_data[template[ii], ch]
- test_trials[cum_ses_len:cum_ses_len + sess_size, ch * len(template) + ii] = tmp_data
- col = plt.rcParams['axes.prop_cycle'].by_key()['color']
- cumSesLen = 1
- log_prob = np.zeros((cumSesLen, n_classes))
- rel_prob = np.zeros((cumSesLen, n_classes))
- for class_id in range(n_classes):
- vect = test_trials - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
- tmp = np.matmul(self.clf2['chol_mat'], vect.T)
- log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
- for class_id in range(n_classes):
- log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], n_classes, axis=1)
- rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
- if self.params.system.plot:
- for class_id in range(self.params.classifier.n_classes-1):
- plt.plot(tt-600, rel_prob[:, class_id], lw=1, alpha=0.5, color=col[class_id], marker='.')
- plt.pause(.01)
- # print(rel_prob.shape)
- # print(tt)
- # plt.draw()
- # plt.show()
- # input('press enter')
- # fig = plt.gcf()
- # fig.canvas.mpl_connect('close_event', handle_close)
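- # binary decision: 1 if any of the four trigger-class probabilities exceeds 0.7, 0 otherwise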
- my_decision = int(np.any(rel_prob[0, 0:4] > 0.7))
-
- return my_decision, rel_prob
- def get_class2(self, cur_data, tt, decoder_decision):
- '''used for data from NSP'''
-
- cur_data = cur_data[:, self.channel_mask] # get only channels used during training
-
-
- if self.init_buffer: #self.params.classifier.thr_window:
- self.online_decision = aux.decision.error1.value # not enough data at the beginning
- self.online_sig = [0] * self.online_n_classes
- # log.warning(f'not enough data, online_sig: {self.online_sig}')
- return None
- # log.warning(f'cur_data: {cur_data.shape}, {cur_data[:1,:3]}')
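- # online_template holds negative offsets (see set_params), so indexing cur_data with them reads the
- # most recent samples of the buffer, reproducing the training template lags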
- for ch in range(self.online_n_ch):
- for ii in range(len(self.online_template)-1,-1,-1):
- # log.warning(f'----> {ii} {self.online_template[ii]}')
- tmp_data = cur_data[self.online_template[ii], ch]
- # log.warning(f'self.online_template: {self.online_template}, {ii}, {self.online_template[ii]}')
- self.online_features[0:1, ch * len(self.online_template) + ii] = tmp_data
- col = plt.rcParams['axes.prop_cycle'].by_key()['color']
- cumSesLen = 1
- log_prob = np.zeros((cumSesLen, self.online_n_classes))
- rel_prob = np.zeros((cumSesLen, self.online_n_classes))
- for class_id in range(self.online_n_classes):
- vect = self.online_features - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
- tmp = np.matmul(self.clf2['chol_mat'], vect.T)
- log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
- for class_id in range(self.online_n_classes):
- log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], self.online_n_classes, axis=1)
- rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
-
- with np.errstate(divide='ignore', invalid='ignore'): # suppress sklearn warning if probs are zero
- prob = self.clf1.predict_proba(self.online_features)
-
- # use either probabilites or class predictions
- sig = np.zeros((1,self.online_n_classes))
- if self.params.classifier.model_training.model == 'scikit' and self.params.classifier.peaks.sig == 'prob':
- sig = np.copy(prob)
- elif self.params.classifier.model_training.model == 'scikit' and self.params.classifier.peaks.sig == 'pred':
- pred_cl = int(self.clf1.predict(self.online_features))
- sig[0,pred_cl] = 1
- elif self.params.classifier.model_training.model == 'explicit':
- sig = np.copy(rel_prob)
- if np.any(sig[0]>self.params.classifier.thr_prob) and (self.block_phase.value == 2): # only if above thr and if in response state
- if self.params.classifier.peaks.sig == 'pred':
- self.online_history.append(int(np.argmax(np.array(sig[0])))) # 0: yes, 1:no, 2:baseline
- # log.warning(f'{sig}, {self.online_history[-1]}')
- else:
- self.online_history.append((np.array(sig[0]))) # 0: yes, 1:no, 2:baseline
- log.debug(sig)
- decoder_decision.value = -1
- # log.warning(f'sig:{sig}, prob: {prob}, {self.online_history}')
- # append only if prob cross threshold, and only if neural response has started
- # set decision if consecutive number of prob samples cross threshold
- if self.online_decision < 0 and len(self.online_history) >= self.params.classifier.thr_window:
- # set decision if same class appears in all samples within last thr_window
- # first if conditions: check if only decision in thr_window samples;
- if np.unique(self.online_history[-self.params.classifier.thr_window:]).size == 1 and (self.online_history[-1] < self.online_n_classes):
- if (self.online_n_classes ==3) and (self.online_history[-1] < 2): # DO NOT SEND DECISION FOR BASELINE
- decoder_decision.value = self.online_history[-1]
- self.online_history = []
- elif (self.online_n_classes ==2) and (self.online_history[-1] < 1):
- decoder_decision.value = self.online_history[-1]
- self.online_history = []
- if self.block_phase.value!=2:
- self.online_history = []
- # log.error(f'{self.online_decision}, {self.online_history}, {prob}')
- # log.debug(f'{prob}')
- self.online_sig = sig
- # self.online_sig = rel_prob
-
- return None #prob
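- # Typical offline usage (sketch, not part of the module; assumes `params` is a munch structure with the
- # fields referenced above, that train_data/triggers follow the shapes documented in get_trials, and that
- # self.triggers_shape has been set elsewhere before train_LDA is called):
- # clf = Classifier(params)
- # clf.get_trials_kiap(train_data, triggers, features_fname='')
- # X_test, y_test, X_train, y_train = clf.train_LDA(data_train=clf.train_trials, save_model=0)
- # clf.get_test_trials_kiap(test_data)
- # clf.test_LDA_kiap(session_id=0)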