classifier.py

import pickle

import matplotlib.pyplot as plt
import munch
import numpy as np
import scipy.signal as sgn
from scipy import io
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
import joblib  # sklearn.externals.joblib is deprecated; use the standalone joblib package
from sklearn.model_selection import (StratifiedKFold, StratifiedShuffleSplit,
                                     cross_val_predict, cross_val_score,
                                     cross_validate)
from sklearn.feature_selection import RFE

from aux import log
import aux
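
# Overview: this module implements an LDA-based decoder for multichannel
# spike-rate data. Feature vectors are built by sampling every channel at the
# lags given by params.classifier.template around a trigger. train_LDA fits two
# models on the same features: a scikit-learn LDA (clf1) and an "explicit" LDA
# (clf2), stored as the class means plus the transposed Cholesky factor of the
# inverse pooled covariance. The get_class*/online_decoder methods use these
# models to turn incoming data into class probabilities and decisions.
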
class Classifier:

    def __init__(self, params, block_phase=[]):
        self.block_phase = block_phase
        self.init_buffer = 1  # avoid decoding while the buffer holds fewer samples than a full template
        try:
            res = self.set_params(params)
        except Exception as e:
            log.error(e)
        return None

    def set_params(self, params):
        params.classifier.template = np.array(params.classifier.template)
        params.classifier.n_features = params.classifier.template.size * params.daq.n_channels
        self.params = params
        if params.classifier.online:
            try:
                # self.clf = joblib.load('/kiap/data/classifier/monkey_model.joblib')  # scikit LDA
                # self.clf2 = joblib.load('/kiap/data/classifier/monkey_model2.joblib')  # explicit equations LDA
                self.clf1 = joblib.load(self.params.classifier.path_model1)  # scikit LDA
                self.clf2 = joblib.load(self.params.classifier.path_model2)  # explicit equations LDA
            except Exception as e:
                log.error('Classifier model(s) not found. Sending stop signal.')
                self.params.classifier.online = False
                raise e

        # online classifier
        self.online_n_classes = self.params.classifier.n_classes
        self.online_history = []
        self.online_template = self.params.classifier.template
        self.online_template = self.online_template - self.online_template.min() + 1  # shift the template, since it is applied to a sliding window
        self.online_template = -self.online_template
        # TODO: check that the code below is valid (params.daq.exclude_channels is 1-based!)
        self.channel_mask = [ii for ii in range(self.params.daq.n_channels_max) if ii + 1 not in self.params.daq.exclude_channels]
        self.channel_mask = [ii for ii in range(len(self.channel_mask)) if ii not in self.params.classifier.exclude_channels]
        self.online_n_ch = len(self.channel_mask)
        self.online_features = np.zeros((1, len(self.online_template) * self.online_n_ch))

        if params.classifier.online:
            if self.online_features.shape[1] != self.clf1.coef_.shape[1]:
                log.error(f'# of features: {self.online_features.shape[1]}, # of coef: {self.clf1.coef_.shape[1]}')
                log.error('Mismatch in # of features. No online decoding possible.')
                self.params.classifier.online = False
                self.online_decision = aux.decision.error2.value
                self.online_sig = [0] * self.online_n_classes
                return None
        if params.classifier.online:
            try:
                self.compare_params()
            except Exception as e:
                log.error(e)
                log.error('Parameter mismatch. Deactivating online classifier!')
                self.params.classifier.online = False
                raise e
        return None

    def compare_params(self):
        p1 = self.params.classifier.copy()
        p2 = self.clf1.params.copy()
        p3 = self.clf2.params.copy()
        for par in p1.items():
            if par[0] not in ['online', 'path_model1', 'path_model2']:
                if np.any(p1[par[0]] != p2[par[0]]) or np.any(p1[par[0]] != p3[par[0]]):
                    if par[0] == 'n_classes':
                        log.warning(f'Parameter "{par[0]}" mismatch. Current: {par[1]}, models: {p2[par[0]]}, {p3[par[0]]}')
                    raise ValueError('Parameter mismatch!')
        return None

    def train_LDA(self, data_train=[], baseline=0, features_fname='', model_fname1='', model_fname2='', solver='lsqr', save_model=0):
        '''Train an LDA with sklearn and with the explicit formulations; save the sklearn model, the class means etc.

        Parameters
        ----------
        data_train: 1d-list of arrays, shape: (n_classes, 1)
            arrays in the list have shape: (n_features, n_trials_train)
        solver: 'lsqr' or 'svd' ('lsqr' is faster, but 'svd' gives better results)
        '''
        log.debug('Training LDA model...')
        if len(data_train) == 0:  # load data from file, data_train shape: (features x trials)
            log.info('load features from file')
            with open(features_fname, 'rb') as f:
                data_train = pickle.load(f)
        else:
            log.debug('get features from classifier instance')
        n_features_data = data_train[0, 0].shape[0]
        n_features = self.params.classifier.n_features
        if n_features_data != n_features:
            n_features = n_features_data
            log.debug(f'Used the number of features from the data: {n_features}')
        n_trials_tot = sum([data_train[ii, 0].shape[1] for ii in range(len(data_train))])
        n_classes = self.triggers_shape[1]
        # if baseline == 0:  # add one additional class for the baseline
        #     n_classes -= 1
        reg_fact = self.params.classifier.model_training.reg_fact
        X_tot = np.empty((0, n_features))
        y_tot = np.empty((0, 1))
        for ii in range(n_classes):
            X_tot = np.vstack((X_tot, data_train[ii, 0].T))
            y_tot = np.vstack((y_tot, data_train[ii, 0][0][:, None] * 0 + ii))  # label ii for every trial of class ii
        y_tot = y_tot.flatten()
        means = np.zeros((n_features, n_classes))
        tmp_cov = np.zeros((n_features, n_trials_tot))

        # train the model using the sklearn LDA
        shrinkage = self.params.classifier.reg_fact
        if solver == 'svd':  # the svd solver does not support shrinkage
            shrinkage = None
        if self.params.classifier.model_training.shrinkage == 'auto':
            shrinkage = 'auto'
        clf1 = LDA(solver=solver, shrinkage=shrinkage)

        # CROSS-VALIDATION PARAMETERS
        test_size = self.params.classifier.model_training.test_size
        n_splits = self.params.classifier.model_training.n_splits
        clf1.scores = []
        if self.params.classifier.model_training.cross_validation:
            # use stratified shuffling to get trials from all classes in every split
            kf = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=None)
            for train_index, test_index in kf.split(X_tot, y_tot):
                X_train, X_test = X_tot[train_index], X_tot[test_index]
                y_train, y_test = y_tot[train_index], y_tot[test_index]
                for class_id in range(n_classes):
                    log.debug(f'class {class_id} split samples: {np.sum(y_train == class_id)}')
                try:
                    clf1.fit(X_train, y_train)
                    clf1.scores.append(clf1.score(X_test, y_test))
                except Exception as e:
                    log.error('need more trials!')
                    log.error(e)
                    raise
        else:
            X_train = X_tot
            y_train = y_tot
            X_test = X_tot
            y_test = y_tot
            clf1.fit(X_train, y_train)
            clf1.scores.append(clf1.score(X_test, y_test))
        if self.params.classifier.model_training.fsel:
            selector = RFE(clf1, n_features_to_select=20, step=1)  # keep the 20 best features
            selector = selector.fit(X_train, y_train)

        # train the model using the explicit LDA equations
        counter = 0
        for class_id in range(n_classes):
            means[:, class_id] = np.mean(X_train[y_train == class_id, :], axis=0)
            class_len = X_train[y_train == class_id, :].shape[0]
            tmp_cov[:, counter:counter + class_len] = X_train[y_train == class_id, :].T - np.repeat(means[:, class_id][:, None], class_len, axis=1)
            counter = counter + class_len
        cov_mat = np.cov(tmp_cov[:, :counter])  # use only the filled columns (with cross-validation X_train has fewer than n_trials_tot trials)

        # regularization and Cholesky factorization
        if (np.linalg.matrix_rank(cov_mat) < cov_mat.shape[0]) and (reg_fact == 0):
            reg_fact = 1e-6
            log.warning(f'Covariance matrix is singular! Using regularization {reg_fact}')
        cov_diag_mean = np.mean(np.diag(cov_mat))
        cov_mat_reg = (1 - reg_fact) * cov_mat + reg_fact * cov_diag_mean * np.eye(n_features)
        cov_mat_inv = np.linalg.pinv(cov_mat_reg)
        chol_mat = np.linalg.cholesky(cov_mat_inv).T  # chol_mat.T @ chol_mat == cov_mat_inv, so ||chol_mat @ (x - mean)||**2 is the Mahalanobis distance
        clf2 = {'means': means, 'chol_mat': chol_mat}
        self.clf1 = clf1
        if self.params.classifier.model_training.fsel:
            self.clf1.ranking = selector.ranking_
        self.clf2 = munch.munchify(clf2)
        self.clf1.params = self.params.classifier.copy()  # save all information for later comparison in online decoding
        self.clf2.params = self.params.classifier.copy()  # save all information for later comparison in online decoding
        if save_model:
            log.info(f'Saving model1 in {model_fname1}')
            log.info(f'Saving model2 in {model_fname2}')
            joblib.dump(self.clf1, model_fname1)  # scikit LDA
            joblib.dump(self.clf2, model_fname2)  # explicit LDA
        return X_test, y_test, X_train, y_train

    def test_LDA(self, test_trials, session_id, compare_results=0):
        '''Test the model on test_trials; compare the probabilities with the ones from the Matlab code.

        Parameters
        ----------
        test_trials: (n_samples, n_features)
        '''
        n_features = self.params.classifier.n_features
        n_classes = self.params.classifier.n_triggers + 1
        for sess in range(session_id + 1, session_id + 2):
            if compare_results:
                classProb = io.loadmat('/kiap/data/tom/model/Q19_20131122/classProb{}.mat'.format(sess))['classProb']
                triggers0 = io.loadmat('/kiap/data/tom/model/Q19_20131122/triggers.mat')['hand_triggers']
            # calculate the probabilities using the sklearn LDA
            prob = self.clf1.predict_proba(test_trials)
            # calculate the probabilities explicitly
            cumSesLen = test_trials.shape[0]
            log_prob = np.zeros((cumSesLen, n_classes))
            rel_prob = np.zeros((cumSesLen, n_classes))
            for class_id in range(n_classes):
                vect = test_trials - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
                tmp = np.matmul(self.clf2['chol_mat'], vect.T)
                log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
            for class_id in range(n_classes):
                log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], n_classes, axis=1)
                rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
            if compare_results:
                self.triggers0 = triggers0
                self.classProb = classProb  # imported from the Matlab code
            self.rel_prob = rel_prob  # computed here explicitly
            self.prob = prob  # computed via sklearn
            if compare_results:
                self.plot_results(sess, fig_number=1)
        return None

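    # Note on rel_prob: for each class c the loops above compute
    #   rel_prob[:, c] = 1 / sum_k exp(log_prob[:, k] - log_prob[:, c]),
    # which equals exp(log_prob[:, c]) / sum_k exp(log_prob[:, k]), i.e. the
    # softmax over the class log-likelihoods (the posterior under equal class
    # priors), written so that the raw log-likelihoods are never exponentiated
    # directly.
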
    def test_LDA_kiap(self, session_id, compare_results=0):
        '''Test the model on self.test_trials; optionally plot the results (compare_results=1).'''
        test_trials = self.test_trials
        n_features = self.params.classifier.n_features
        # n_classes = self.params.classifier.n_triggers + 1
        n_classes = self.params.classifier.n_classes
        for sess in range(session_id + 1, session_id + 2):
            # calculate the probabilities using the sklearn LDA
            prob = self.clf1.predict_proba(test_trials)
            # calculate the probabilities explicitly
            cumSesLen = test_trials.shape[0]
            log_prob = np.zeros((cumSesLen, n_classes))
            rel_prob = np.zeros((cumSesLen, n_classes))
            for class_id in range(n_classes):
                vect = test_trials - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
                tmp = np.matmul(self.clf2['chol_mat'], vect.T)
                log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
            for class_id in range(n_classes):
                log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], n_classes, axis=1)
                rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
            self.rel_prob = rel_prob  # computed here explicitly
            self.prob = prob  # computed via sklearn
            if compare_results:
                self.plot_results_kiap(sess, fig_number=1)
        return None

    def plot_results(self, sess, compare_results=0, fig_number=1):
        col = plt.rcParams['axes.prop_cycle'].by_key()['color']
        self.triggers0 = io.loadmat('/kiap/data/tom/model/Q19_20131122/triggers.mat')['hand_triggers']
        for class_id in range(self.params.classifier.n_triggers):
            triggers = np.hstack(self.triggers0[sess - 1, class_id])
            ax1 = plt.subplot(311)  # Matlab code
            plt.plot(self.classProb[:, class_id], lw=1, alpha=0.5, color=col[class_id])
            plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
            plt.xlim(0, 17500)
            plt.ylim(0, 1.2)
            plt.title('Session {}'.format(sess))
            ax2 = plt.subplot(312, sharex=ax1)  # explicit equations (Python)
            plt.plot(self.rel_prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
            plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
            plt.xlim(0, 17500)
            plt.ylim(0, 1.2)
            ax3 = plt.subplot(313, sharex=ax1)  # sklearn (Python)
            plt.plot(self.prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
            plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
            plt.xlim(0, 17500)
            plt.ylim(0, 1.2)
        plt.draw()
        plt.show()
        input('press enter for next session')
        return None

    def plot_results_kiap(self, sess, compare_results=0, fig_number=1):
        col = plt.rcParams['axes.prop_cycle'].by_key()['color']
        for class_id in range(self.params.classifier.n_triggers):
            triggers = np.hstack(self.triggers[sess - 1, class_id])
            ax1 = plt.subplot(311)  # kept for the shared x-axis; no Matlab comparison available here
            ax2 = plt.subplot(312, sharex=ax1)  # explicit equations (Python)
            plt.plot(self.rel_prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
            plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
            plt.xlim(0, 500)
            plt.ylim(0, 1.2)
            ax3 = plt.subplot(313, sharex=ax1)  # sklearn (Python)
            plt.plot(self.prob[:, class_id], lw=1, alpha=0.5, color=col[class_id])
            plt.vlines(triggers - 601, 0, 2, color=col[class_id], lw=2)
            plt.xlim(0, 500)
            plt.ylim(0, 1.2)
        plt.draw()
        plt.show()
        input('press enter for next session')
        return None

    def get_trials_kiap(self, train_data, triggers, features_fname, save_features=0):
        log.debug('Computing feature matrix from training data ...')
        n_train_sess = train_data.size
        n_ch = train_data[0][0].shape[1]
        n_triggers = triggers.shape[1]
        n_classes = n_triggers
        deadtime = self.params.classifier.deadtime  # number of sample points to exclude around each trigger
        template = self.params.classifier.template
        log.debug(f'# of channels: {n_ch}')

        # GET CUED TRIALS
        n_pos_train = np.zeros((n_triggers,), dtype=int)  # total number of triggers per trigger type
        for trg_ind in range(n_triggers):
            n_pos_train[trg_ind] = np.concatenate(triggers[:, trg_ind], axis=1).size
        n_train_tot = 0
        for sess in range(n_train_sess):
            n_train_tot += train_data[sess, 0].shape[0]  # total number of samples

        # EXTRACTING TRAINING TRIALS
        # --------------------------
        train_trials = [0] * n_classes
        psth = [0] * n_classes
        cut_win = self.params.classifier.psth.cut[1] - self.params.classifier.psth.cut[0]
        for trg_id in range(n_triggers):
            train_trials[trg_id] = np.zeros((n_pos_train[trg_id], len(template) * n_ch))
        trialCounter = [1] * n_classes
        for sess in range(n_train_sess):
            for trg_id in range(n_triggers):
                psth[trg_id] = np.zeros((triggers[sess, trg_id].size, cut_win, n_ch))
                log.debug(psth[trg_id].shape)
            # CONSTRUCT FEATURE MATRICES FOR EACH CLASS
            for trg_id in range(n_triggers):
                for ii in range(triggers[sess, trg_id].size):  # skip ii if sample points are < 0 or > max size of the available data
                    if (triggers[sess, trg_id][0, ii] + min(template) <= 0) or (triggers[sess, trg_id][0, ii] + max(template) > train_data[sess, 0].shape[0]):
                        log.debug('excluded {} {} {}'.format(sess, trg_id, ii))
                        continue
                    for ch in range(n_ch):
                        idx = list(range(ch * len(template), (ch + 1) * len(template)))  # CAUTION: why -1 below? is the trigger an absolute index?
                        train_trials[trg_id][trialCounter[trg_id] - 1, idx] = train_data[sess, 0][triggers[sess, trg_id][0, ii] + template - 1, ch]
                        try:
                            psth[trg_id][ii, :, ch] = train_data[sess, 0][triggers[sess, trg_id][0, ii] + self.params.classifier.psth.cut[0]:triggers[sess, trg_id][0, ii] + self.params.classifier.psth.cut[1], ch]
                        except Exception:
                            log.debug(f'Trial {ii} out of limits')
                    trialCounter[trg_id] = trialCounter[trg_id] + 1
                log.debug('features for the classes extracted')
            log.debug(f'Session {sess}, trialCounter {trialCounter}')
        if trialCounter == [1] * n_classes:
            raise ValueError('No feature trial(s) extracted, because the triggers are too close to the edges!')

        # bring the data into the format used downstream
        train_trials2 = np.zeros((n_classes, 1), dtype='O')
        for cl in range(n_classes):
            train_trials2[cl, 0] = train_trials[cl][:trialCounter[cl] - 1, :].T  # drop trials that fell outside the margins (to be confirmed!)
        if features_fname:
            log.debug('Saving feature array...')
            with open(features_fname, 'wb') as f:
                pickle.dump(train_trials2, f)
        self.train_trials = train_trials2
        self.psth = psth
        self.psth_xx = range(self.params.classifier.psth.cut[0], self.params.classifier.psth.cut[1])
        log.debug(f'feature-array session 0 shape: {train_trials2[0, 0].shape}')
        return None

    def get_trials_kiap_bl(self, train_data, triggers, save_features=0):
        log.debug('Computing feature matrix from training data ...')
        n_train_sess = train_data.size
        n_ch = train_data[0][0].shape[1]
        n_triggers = triggers.shape[1]
        n_classes = n_triggers
        deadtime = self.params.classifier.deadtime  # number of sample points to exclude around each trigger
        template = self.params.classifier.template
        log.debug(f'# of channels: {n_ch}')

        # GET CUED TRIALS
        n_pos_train = np.zeros((n_triggers,), dtype=int)  # total number of triggers per trigger type
        for trg_ind in range(n_triggers):
            n_pos_train[trg_ind] = np.concatenate(triggers[:, trg_ind], axis=1).size
        n_train_tot = 0
        for sess in range(n_train_sess):
            n_train_tot += train_data[sess, 0].shape[0]  # total number of samples

        # EXTRACTING TRAINING TRIALS
        # --------------------------
        train_trials = [0] * n_classes
        psth = [0] * n_classes
        cut_win = self.params.classifier.psth.cut[1] - self.params.classifier.psth.cut[0]
        for trg_id in range(n_triggers):
            train_trials[trg_id] = np.zeros((n_pos_train[trg_id], len(template) * n_ch))
        trialCounter = [1] * n_classes
        for sess in range(n_train_sess):
            for trg_id in range(n_triggers):
                psth[trg_id] = np.zeros((triggers[sess, trg_id].size, cut_win, n_ch))
                log.debug(psth[trg_id].shape)
            # CONSTRUCT FEATURE MATRICES FOR EACH CLASS
            for trg_id in range(n_triggers):
                for ii in range(triggers[sess, trg_id].size):  # skip ii if sample points are < 0 or > max size of the available data
                    if (triggers[sess, trg_id][0, ii] + min(template) <= 0) or (triggers[sess, trg_id][0, ii] + max(template) > train_data[sess, 0].shape[0]):
                        log.debug('excluded {} {} {}'.format(sess, trg_id, ii))
                        continue
                    for ch in range(n_ch):
                        idx = list(range(ch * len(template), (ch + 1) * len(template)))  # CAUTION: why -1 below? is the trigger an absolute index?
                        train_trials[trg_id][trialCounter[trg_id] - 1, idx] = train_data[sess, 0][triggers[sess, trg_id][0, ii] + template - 1, ch]
                        try:
                            psth[trg_id][ii, :, ch] = train_data[sess, 0][triggers[sess, trg_id][0, ii] + self.params.classifier.psth.cut[0]:triggers[sess, trg_id][0, ii] + self.params.classifier.psth.cut[1], ch]
                        except Exception:
                            log.debug(f'Trial {ii} out of limits')
                    trialCounter[trg_id] = trialCounter[trg_id] + 1
                log.debug('features for the classes extracted')
            log.debug(f'Session {sess}, trialCounter {trialCounter}')
        if trialCounter == [1] * n_classes:
            raise ValueError('No feature trial(s) extracted, because the triggers are too close to the edges!')

        # bring the data into the format used downstream
        train_trials2 = np.zeros((n_classes, 1), dtype='O')
        for cl in range(n_classes):
            train_trials2[cl, 0] = train_trials[cl][:trialCounter[cl] - 1, :].T  # drop trials that fell outside the margins (to be confirmed!)
        self.train_trials = train_trials2
        self.psth = psth
        self.psth_xx = range(self.params.classifier.psth.cut[0], self.params.classifier.psth.cut[1])
        log.debug(f'feature-array session 0 shape: {train_trials2[0, 0].shape}')
        return None

    def get_trials(self, train_data, triggers, features_fname='', train_trials_comparison=[], assert_results=0):
        '''Import spike rates, extract the trials and construct features.

        For the monkey data it yields results identical to the Matlab code provided by Tom
        (set assert_results=1 to compare against train_trials_comparison, which is imported
        only for that purpose).

        Parameters
        ----------
        train_data: (n_train_sess, 1)
            train_data[ii, 0]: (n_samples, n_ch) per element
        triggers: (n_train_sess, n_triggers)

        Saves as object variable
        ------------------------
        self.train_trials: (n_classes, 1) with (n_feat, n_trials) per class
        '''
        log.info('Computing feature matrix from training data ...')
        n_train_sess = train_data.size
        n_ch = train_data[0][0].shape[1]
        n_triggers = triggers.shape[1]
        n_classes = n_triggers + 1
        deadtime = self.params.classifier.deadtime  # number of sample points to exclude around each trigger
        n_neg_train = self.params.classifier.n_neg_train  # number of sample points to consider for the baseline
        # template = [0, -150, -300, -450, -600]
        template = self.params.classifier.template

        # GET CUED TRIALS
        n_pos_train = np.zeros((n_triggers,), dtype=int)  # total number of triggers per trigger type
        for trg_ind in range(n_triggers):
            n_pos_train[trg_ind] = np.concatenate(triggers[:, trg_ind], axis=1).size
        n_train_tot = 0
        for sess in range(n_train_sess):
            n_train_tot += train_data[sess, 0].shape[0]  # total number of samples

        # GET BASELINE TRIALS
        neg_train = np.zeros((2, n_train_tot), dtype=int)  # row 0: session number, row 1: baseline indices for this session
        counter = 0
        for sess in range(n_train_sess):  # build an index of sample points to exclude from the baseline
            idx1 = list(range(deadtime + 1))  # exclude the first and last deadtime samples of each session
            idx1.extend(list(range(train_data[sess, 0].shape[0] - deadtime, train_data[sess, 0].shape[0] + 1)))
            for trg_ind in range(n_triggers):  # also exclude +/- deadtime samples around each trigger
                for tr in range(triggers[sess, trg_ind].size):
                    idx1.extend(range(triggers[sess, trg_ind][0, tr] - deadtime, triggers[sess, trg_ind][0, tr] + deadtime + 1))
            tmpNegStarts = list(set(range(train_data[sess, 0].shape[0])) - set(idx1))
            neg_train[0, counter:counter + len(tmpNegStarts)] = sess
            neg_train[1, counter:counter + len(tmpNegStarts)] = tmpNegStarts
            counter = counter + len(tmpNegStarts)
        neg_train = neg_train[:, :counter]  # remove the zero elements from the end
        # note: the baseline indices in the 2nd row are still checked against the template margins below
        for ii in range(n_train_sess):
            log.debug(f'session {ii}, # baseline indices: {sum(neg_train[0, :] == ii)}')
        # limit the number of sample points for the baseline
        if neg_train.shape[1] > n_neg_train:
            perm_idx = np.random.permutation(range(neg_train.shape[1]))
            neg_train = neg_train[:, perm_idx[:n_neg_train]]
        else:
            n_neg_train = neg_train.shape[1]

        # EXTRACTING TRAINING TRIALS
        # --------------------------
        train_trials = [0] * n_classes
        for trg_id in range(n_triggers):
            train_trials[trg_id] = np.zeros((n_pos_train[trg_id], len(template) * n_ch))
        train_trials[n_classes - 1] = np.zeros((n_neg_train, len(template) * n_ch))
        train_trials_f = np.zeros((n_neg_train, len(template) * n_ch))  # avoids the inner loop over channels
        trialCounter = [1] * n_classes
        for sess in range(n_train_sess):
            # CONSTRUCT FEATURE MATRICES FOR EACH CLASS
            for trg_id in range(n_triggers):
                for ii in range(triggers[sess, trg_id].size):  # skip ii if sample points are < 0 or > max size of the available data
                    if (triggers[sess, trg_id][0, ii] + min(template) <= 0) or (triggers[sess, trg_id][0, ii] + max(template) > train_data[sess, 0].shape[0]):
                        log.debug('excluded {} {} {}'.format(sess, trg_id, ii))
                        continue
                    for ch in range(n_ch):
                        idx = list(range(ch * len(template), (ch + 1) * len(template)))  # CAUTION: why -1 below? is the trigger an absolute index?
                        train_trials[trg_id][trialCounter[trg_id] - 1, idx] = train_data[sess, 0][triggers[sess, trg_id][0, ii] + template - 1, ch]
                        if assert_results:
                            assert np.array_equiv(train_trials[trg_id][trialCounter[trg_id] - 1, idx], train_trials_comparison[trg_id, 0].T[trialCounter[trg_id] - 1, idx])
                    trialCounter[trg_id] = trialCounter[trg_id] + 1
            log.debug('features for the classes extracted')
            # CONSTRUCT FEATURE MATRICES FOR THE BASELINE
            neg_ind = np.where(neg_train[0, :] == sess)[0]
            for ii in range(len(neg_ind)):  # throw away indices that fall outside the margins
                if (neg_train[1, neg_ind[ii]] + min(template) <= 0) or (neg_train[1, neg_ind[ii]] + max(template) > train_data[sess, 0].shape[0]):
                    continue
                # avoid the inner channel loop to speed up the calculation
                fff = train_data[sess, 0][neg_train[1, neg_ind[ii]] + template - 1, :].flatten('F')
                train_trials_f[trialCounter[n_classes - 1] - 1] = fff
                trialCounter[n_classes - 1] = trialCounter[n_classes - 1] + 1
            train_trials[n_classes - 1] = train_trials_f
            log.debug('features for the baseline extracted')
            log.debug(f'Session {sess}, trialCounter {trialCounter}')

        # bring the data into the format used downstream
        train_trials2 = np.zeros((n_classes, 1), dtype='O')
        for cl in range(n_classes):
            train_trials2[cl, 0] = train_trials[cl][:trialCounter[cl] - 1, :].T  # drop trials that fell outside the margins (to be confirmed!)
        if features_fname:
            log.info('Saving feature array...')
            with open(features_fname, 'wb') as f:
                pickle.dump(train_trials2, f)
        self.train_trials = train_trials2
        return None

    def get_test_trials(self, test_data):
        '''Extract the trials for TESTING from test_data and construct features.'''
        n_test_sess = test_data.size  # number of test sessions
        n_ch = test_data[0][0].shape[1]
        n_triggers = self.params.classifier.n_triggers
        n_classes = n_triggers + 1
        deadtime = self.params.classifier.deadtime  # number of sample points to exclude around each trigger
        n_neg_train = self.params.classifier.n_neg_train  # number of sample points to consider for the baseline
        template = self.params.classifier.template
        n_test = 0
        for sess in range(n_test_sess):  # total number of sample points
            n_test = n_test + test_data[sess, 0].shape[0]

        # extract the testing trials
        test_trials = np.zeros((n_test, len(template) * n_ch))
        test_ind = np.zeros((n_test, 1), dtype=int)
        test_cut = np.zeros((n_test, 1))
        cum_ses_len = 0
        for sess in range(n_test_sess):
            start_sess = 1 - min(np.append(template, 0))  # get the start index
            end_sess = test_data[sess, 0].shape[0] - max(np.append(template, 1)) + 1  # get the end index
            sess_size = end_sess - start_sess + 1
            if sess_size <= 0:
                continue
            test_ind[cum_ses_len:cum_ses_len + sess_size, 0] = np.arange(start_sess, end_sess + 1) + max(np.append(template, 1)) - 1
            test_cut[cum_ses_len:cum_ses_len + sess_size, 0] = sess
            for ch in range(n_ch):
                for ii in range(len(template)):
                    tmp = test_data[sess, 0][np.arange(start_sess - 1, end_sess) + template[ii], ch]
                    test_trials[cum_ses_len:cum_ses_len + sess_size, ch * len(template) + ii] = tmp
                    # Matlab equivalent: inputData{sess}((startSes:endSess) + template_detection(ii), ch)
            cum_ses_len = cum_ses_len + sess_size
        test_trials = test_trials[:cum_ses_len, :]
        test_ind = test_ind[:cum_ses_len]
        test_cut = test_cut[:cum_ses_len]
        return test_trials

    def get_test_trials_kiap(self, test_data):
        '''Extract the trials for TESTING from test_data and construct features.'''
        n_test_sess = test_data.size  # number of test sessions
        n_ch = test_data[0][0].shape[1]
        template = self.params.classifier.template
        template = template - template.min() + 1  # shifted template, applied below with a negative sign
        n_test = 0
        for sess in range(n_test_sess):  # total number of sample points
            n_test = n_test + test_data[sess, 0].shape[0]

        # extract the testing trials
        test_trials = np.zeros((n_test, len(template) * n_ch))
        test_ind = np.zeros((n_test, 1), dtype=int)
        test_cut = np.zeros((n_test, 1))
        cum_ses_len = 0
        for sess in range(n_test_sess):
            start_sess = 1 - min(np.append(template, 0))  # get the start index
            end_sess = test_data[sess, 0].shape[0] - max(np.append(template, 1)) + 1  # get the end index
            sess_size = end_sess - start_sess + 1
            if sess_size <= 0:
                continue
            test_ind[cum_ses_len:cum_ses_len + sess_size, 0] = np.arange(start_sess, end_sess + 1) + max(np.append(template, 1)) - 1
            test_cut[cum_ses_len:cum_ses_len + sess_size, 0] = sess
            for ch in range(n_ch):
                for ii in range(len(template)):
                    tmp = test_data[sess, 0][np.arange(start_sess - 1, end_sess) - template[ii] - 1, ch]  # negative template!
                    test_trials[cum_ses_len:cum_ses_len + sess_size, ch * len(template) + ii] = tmp
            cum_ses_len = cum_ses_len + sess_size
        test_trials = test_trials[:cum_ses_len, :]
        test_ind = test_ind[:cum_ses_len]
        test_cut = test_cut[:cum_ses_len]
        self.test_trials = test_trials
        return None

    def online_decoder(self, cur_data, session_id=0):
        plt.cla()
        plt.ylim(0, 1)
        plt.xlim(0, 10000)
        for jj in range(600, 2000, 20):
            self.get_class(cur_data[jj - 600:jj], jj)
        return None

    def get_class(self, cur_data, tt):
        '''Used for Tom's data.'''
        n_ch = cur_data.shape[1]
        n_triggers = 4
        n_classes = n_triggers + 1
        template = [0, -150, -300, -450, -600]
        n_test = 1
        test_trials = np.zeros((n_test, len(template) * n_ch))
        cum_ses_len = 0
        sess_size = 1
        for ch in range(n_ch):
            for ii in range(len(template)):
                tmp_data = cur_data[template[ii], ch]
                test_trials[cum_ses_len:cum_ses_len + sess_size, ch * len(template) + ii] = tmp_data
        col = plt.rcParams['axes.prop_cycle'].by_key()['color']
        cumSesLen = 1
        log_prob = np.zeros((cumSesLen, n_classes))
        rel_prob = np.zeros((cumSesLen, n_classes))
        for class_id in range(n_classes):
            vect = test_trials - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
            tmp = np.matmul(self.clf2['chol_mat'], vect.T)
            log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
        for class_id in range(n_classes):
            log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], n_classes, axis=1)
            rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
        if self.params.system.plot:
            for class_id in range(self.params.classifier.n_classes - 1):
                plt.plot(tt - 600, rel_prob[:, class_id], lw=1, alpha=0.5, color=col[class_id], marker='.')
            plt.pause(.01)
        my_decision = int(np.any(rel_prob[0, 0:4] > 0.7))  # 1 if any non-baseline class exceeds the threshold
        return my_decision, rel_prob

    def get_class2(self, cur_data, tt, decoder_decision):
        '''Used for data from the NSP.'''
        cur_data = cur_data[:, self.channel_mask]  # keep only the channels used during training
        if self.init_buffer:  # not enough data at the beginning
            self.online_decision = aux.decision.error1.value
            self.online_sig = [0] * self.online_n_classes
            return None
        for ch in range(self.online_n_ch):
            for ii in range(len(self.online_template) - 1, -1, -1):
                tmp_data = cur_data[self.online_template[ii], ch]
                self.online_features[0:1, ch * len(self.online_template) + ii] = tmp_data
        col = plt.rcParams['axes.prop_cycle'].by_key()['color']
        cumSesLen = 1
        log_prob = np.zeros((cumSesLen, self.online_n_classes))
        rel_prob = np.zeros((cumSesLen, self.online_n_classes))
        for class_id in range(self.online_n_classes):
            vect = self.online_features - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
            tmp = np.matmul(self.clf2['chol_mat'], vect.T)
            log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
        for class_id in range(self.online_n_classes):
            log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], self.online_n_classes, axis=1)
            rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
        with np.errstate(divide='ignore', invalid='ignore'):  # suppress the sklearn warning if the probabilities are zero
            prob = self.clf1.predict_proba(self.online_features)

        # use either probabilities or class predictions
        sig = np.zeros((1, self.online_n_classes))
        if self.params.classifier.model_training.model == 'scikit' and self.params.classifier.peaks.sig == 'prob':
            sig = np.copy(prob)
        elif self.params.classifier.model_training.model == 'scikit' and self.params.classifier.peaks.sig == 'pred':
            pred_cl = int(self.clf1.predict(self.online_features))
            sig[0, pred_cl] = 1
        elif self.params.classifier.model_training.model == 'explicit':
            sig = np.copy(rel_prob)
        # append only if the probability crosses the threshold and the neural response has started
        if np.any(sig[0] > self.params.classifier.thr_prob) and (self.block_phase.value == 2):
            if self.params.classifier.peaks.sig == 'pred':
                self.online_history.append(int(np.argmax(np.array(sig[0]))))  # 0: yes, 1: no, 2: baseline
            else:
                self.online_history.append(np.array(sig[0]))  # 0: yes, 1: no, 2: baseline
            log.debug(sig)
        decoder_decision.value = -1
        # set the decision if the same class appears in all samples within the last thr_window
        if self.online_decision < 0 and len(self.online_history) >= self.params.classifier.thr_window:
            if np.unique(self.online_history[-self.params.classifier.thr_window:]).size == 1 and (self.online_history[-1] < self.online_n_classes):
                if (self.online_n_classes == 3) and (self.online_history[-1] < 2):  # do not send a decision for the baseline class
                    decoder_decision.value = self.online_history[-1]
                    self.online_history = []
                elif (self.online_n_classes == 2) and (self.online_history[-1] < 1):
                    decoder_decision.value = self.online_history[-1]
                    self.online_history = []
        if self.block_phase.value != 2:
            self.online_history = []
        self.online_sig = sig
        return None
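

if __name__ == '__main__':
    # Minimal self-check sketch (not part of the recording/decoding pipeline):
    # on synthetic, balanced data the explicit Gaussian-LDA posterior used in
    # test_LDA/get_class should match sklearn's predict_proba, up to small
    # differences in how the pooled covariance is normalized. All names below
    # are local to this sketch; running it requires the module imports above
    # to resolve (aux is project-local).
    rng = np.random.RandomState(0)
    n_classes, n_features, n_per_class = 3, 5, 200
    X = np.vstack([rng.randn(n_per_class, n_features) + 3 * c for c in range(n_classes)])
    y = np.repeat(np.arange(n_classes), n_per_class)
    clf = LDA(solver='lsqr', shrinkage=None).fit(X, y)
    # explicit equations: class means and the Cholesky factor of the inverse pooled covariance
    means = np.stack([X[y == c].mean(axis=0) for c in range(n_classes)], axis=1)
    resid = np.hstack([(X[y == c] - means[:, c]).T for c in range(n_classes)])
    chol_mat = np.linalg.cholesky(np.linalg.pinv(np.cov(resid))).T
    log_prob = np.zeros((X.shape[0], n_classes))
    for c in range(n_classes):
        tmp = chol_mat @ (X - means[:, c]).T
        log_prob[:, c] = -np.sum(tmp ** 2, axis=0) / 2.
    # stable softmax over the class log-likelihoods (equal priors)
    rel_prob = np.exp(log_prob - log_prob.max(axis=1, keepdims=True))
    rel_prob /= rel_prob.sum(axis=1, keepdims=True)
    print('max |explicit - sklearn| difference:', np.abs(rel_prob - clf.predict_proba(X)).max())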