Scheduled service maintenance on November 22

On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtimes to a minimum, but recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience. 44 KB

  1. import pickle
  2. import matplotlib.pyplot as plt
  3. import munch
  4. import numpy as np
  5. import scipy.signal as sgn
  6. from scipy import io
  7. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
  8. from sklearn.externals import joblib
  9. from sklearn.model_selection import (StratifiedKFold, StratifiedShuffleSplit,
  10. cross_val_predict, cross_val_score,
  11. cross_validate)
  12. from sklearn.feature_selection import RFE
  13. from aux import log
  14. import aux
  15. class Classifier:
  16. def __init__(self, params, block_phase=[]):
  17. self.block_phase = block_phase
  18. self.init_buffer = 1 # this is to avoid decoding if buffer has not enough samples for full template
  19. try:
  20. res = self.set_params(params)
  21. except Exception as e:
  22. log.error(e)
  23. return None
  def set_params(self, params):
      '''Store ``params`` and prepare the state used for online decoding.

      ``params.classifier.template`` is converted to an array and
      ``params.classifier.n_features`` is derived from it. When
      ``params.classifier.online`` is set, both models are loaded with joblib
      and checked against the current parameters.

      Raises
      ------
      Exception: re-raised after logging when a model cannot be loaded or
          when ``compare_params`` fails; ``online`` is switched off first.
      '''
      classifier = params.classifier
      classifier.template = np.array(classifier.template)
      classifier.n_features = classifier.template.size * params.daq.n_channels
      self.params = params

      if classifier.online:
          try:
              self.clf1 = joblib.load(classifier.path_model1)  # scikit LDA
              self.clf2 = joblib.load(classifier.path_model2)  # explicit equations LDA
          except Exception:
              log.error('classifier model(s) not found. Send stop signal')
              classifier.online = False
              raise  # bare raise keeps the original traceback

      # online classifier
      self.online_n_classes = classifier.n_classes
      self.online_history = []
      # shift template since it's a sliding window, then flip its sign
      shifted_template = classifier.template - classifier.template.min() + 1
      self.online_template = -shifted_template

      # TODO: Check that code below is valid (params.daq.exclude_channels is 1-based!)
      self.channel_mask = [ii for ii in range(params.daq.n_channels_max) if ii + 1 not in params.daq.exclude_channels]
      self.channel_mask = [ii for ii in range(len(self.channel_mask)) if ii not in classifier.exclude_channels]
      self.online_n_ch = len(self.channel_mask)
      self.online_features = np.zeros((1, len(self.online_template) * self.online_n_ch))

      if classifier.online:
          n_online = self.online_features.shape[1]
          n_model = self.clf1.coef_.shape[1]
          if n_online != n_model:
              log.error(f'# of features:{n_online}, # of coef:{n_model}')
              log.error('Mismatch in # of features. No online decoding possible')
              classifier.online = False
              self.online_decision = aux.decision.error2.value
              self.online_sig = [0] * self.online_n_classes
              return None

      if classifier.online:
          try:
              self.compare_params()
          except Exception as e:
              log.error(e)
              log.error('Parameter mismatch. Deactivating online classifier !')
              classifier.online = False
              raise
      return None
  66. def compare_params(self):
  67. p1 = self.params.classifier.copy()
  68. p2 = self.clf1.params.copy()
  69. p3 = self.clf2.params.copy()
  70. for par in p1.items():
  71. if par[0] not in ['online', 'path_model1', 'path_model2']:
  72. if np.any(p1[par[0]] != p2[par[0]]) or np.any(p1[par[0]] != p3[par[0]]):
  73. if par[0] == 'n_classes':
  74. log.warning(f'Parameter "{par[0]}" mismatch. Current {par[1]}, models: {p2[par[0]]}, {p3[par[0]]} ')
  75. raise ValueError('Parameter mismatch!')
  76. return None
  def train_LDA(self, data_train=None, baseline=0, features_fname='', model_fname1='', model_fname2='', solver='lsqr', save_model=0):
      '''Train an LDA with sklearn and using explicit formulations, optionally save both models.

      Parameters
      ----------
      data_train: object array indexed as data_train[class, 0];
          each entry has shape (n_features, n_trials_train). When omitted
          or empty, the features are unpickled from ``features_fname``.
      baseline: unused, kept for interface compatibility
      features_fname: pickle file read when ``data_train`` is not given
      model_fname1, model_fname2: joblib targets for the sklearn model and the
          explicit model; written only when ``save_model`` is truthy
      solver: lsqr, svd (lsqr is faster, but svd gives better results)
      save_model: truthy to dump both models with joblib

      Returns
      -------
      X_test, y_test, X_train, y_train of the last fit
      '''
      log.debug('Training LDA model...')
      # ``data_train == []`` is ambiguous for arrays, so test emptiness explicitly.
      if data_train is None or len(data_train) == 0:  # load data from file, data_train shape: (features x trials)
          log.info('load features from file')
          # NOTE(review): pickle.load executes arbitrary code; only use trusted files.
          with open(features_fname, 'rb') as f:
              data_train = pickle.load(f)
      else:
          log.debug('get features from classifier instance')

      training = self.params.classifier.model_training
      n_features = self.params.classifier.n_features
      n_features_data = data_train[0, 0].shape[0]
      if n_features_data != n_features:
          n_features = n_features_data
          log.debug(f'Used the number of features from data: {n_features}')

      n_classes = self.triggers_shape[1]
      reg_fact = training.reg_fact

      # Stack trials of all classes: rows are trials, labels are class indices.
      X_blocks = [np.empty((0, n_features))]
      y_blocks = [np.empty(0)]
      for ii in range(n_classes):
          class_data = data_train[ii, 0]
          X_blocks.append(class_data.T)
          # Explicit label vector; multiplying data by zero would propagate NaN.
          y_blocks.append(np.full(class_data.shape[1], ii, dtype=float))
      X_tot = np.vstack(X_blocks)
      y_tot = np.concatenate(y_blocks)

      # train model using sklearn LDA
      # NOTE(review): this reads classifier.reg_fact while the explicit model
      # below reads classifier.model_training.reg_fact; confirm which is meant.
      shrinkage = self.params.classifier.reg_fact
      if training.shrinkage == 'auto':
          shrinkage = 'auto'
      if solver == 'svd':
          shrinkage = None  # the svd solver does not support shrinkage
      clf1 = LDA(solver=solver, shrinkage=shrinkage)

      test_size = training.test_size
      n_splits = training.n_splits
      clf1.scores = []
      if training.cross_validation:
          # use stratified shuffle to get trials from all classes
          kf = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=None)
          for train_index, test_index in kf.split(X_tot, y_tot):
              X_train, X_test = X_tot[train_index], X_tot[test_index]
              y_train, y_test = y_tot[train_index], y_tot[test_index]
              for class_id in range(n_classes):
                  log.debug(f'class {class_id} split samples: {np.sum(y_train == class_id)}')
              try:
                  clf1.fit(X_train, y_train)
                  clf1.scores.append(clf1.score(X_test, y_test))
              except Exception as e:
                  log.error('need more trials!')
                  log.error(e)
                  raise
      else:
          X_train = X_test = X_tot
          y_train = y_test = y_tot
          clf1.fit(X_train, y_train)
          clf1.scores.append(clf1.score(X_test, y_test))

      if training.fsel:
          selector = RFE(clf1, n_features_to_select=20, step=1)  # get best features
          selector = selector.fit(X_train, y_train)

      # train model using explicit LDA equations
      means = np.zeros((n_features, n_classes))
      residuals = [np.empty((n_features, 0))]
      for class_id in range(n_classes):
          X_class = X_train[y_train == class_id, :]
          means[:, class_id] = np.mean(X_class, axis=0)
          residuals.append((X_class - means[:, class_id]).T)
      # Only the trials actually used for training enter the covariance;
      # zero-padded columns would shrink the estimate.
      cov_mat = np.cov(np.hstack(residuals))

      # Cholesky factorization
      if np.linalg.matrix_rank(cov_mat) < cov_mat.shape[0] and reg_fact == 0:
          reg_fact = 1e-6
          log.warning(f'Covariance matrix is singular! Using regularization {reg_fact}')
      cov_diag_mean = np.mean(np.diag(cov_mat))
      cov_mat_reg = (1 - reg_fact) * cov_mat + reg_fact * cov_diag_mean * np.eye(n_features)
      cov_mat_inv = np.linalg.pinv(cov_mat_reg)
      chol_mat = np.linalg.cholesky(cov_mat_inv).T

      self.clf1 = clf1
      if training.fsel:
          self.clf1.ranking = selector.ranking_
      self.clf2 = munch.munchify({'means': means, 'chol_mat': chol_mat})
      self.clf1.params = self.params.classifier.copy()  # save all information for later comparison in online decoding
      self.clf2.params = self.params.classifier.copy()  # save all information for later comparison in online decoding

      if save_model:
          log.info(f'Saving model1 in {model_fname1}')
          log.info(f'Saving model2 in {model_fname2}')
          joblib.dump(self.clf1, model_fname1)  # scikit LDA
          joblib.dump(self.clf2, model_fname2)  # explicit LDA
      return X_test, y_test, X_train, y_train
  179. def test_LDA(self, test_trials, session_id, compare_results=0):
  180. '''test model using test_trials, compare probabilities with the ones from matlab code
  181. Parameters
  182. ----------
  183. test_trials: (n_samples, n_features)
  184. '''
  185. n_features = self.params.classifier.n_features
  186. # n_trials_test = self.params.classifier.n_trials_test
  187. n_classes = self.params.classifier.n_triggers + 1
  188. for sess in range(session_id + 1, session_id + 2):
  189. if compare_results:
  190. classProb = io.loadmat('/kiap/data/tom/model/Q19_20131122/classProb{}.mat'.format(sess))['classProb']
  191. triggers0 = io.loadmat('/kiap/data/tom/model/Q19_20131122/triggers.mat'.format(sess))['hand_triggers']
  192. # calculate probabilities using lda from sklearn
  193. prob = self.clf1.predict_proba(test_trials)
  194. # calculate probabilities explicitly
  195. cumSesLen = test_trials.shape[0]
  196. log_prob = np.zeros((cumSesLen, n_classes))
  197. rel_prob = np.zeros((cumSesLen, n_classes))
  198. for class_id in range(n_classes):
  199. vect = test_trials - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
  200. tmp = np.matmul(self.clf2['chol_mat'], vect.T)
  201. log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
  202. for class_id in range(n_classes):
  203. log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], n_classes, axis=1)
  204. rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
  205. if compare_results:
  206. self.triggers0 = triggers0
  207. self.classProb = classProb # imported from Matlab code
  208. self.rel_prob = rel_prob # computed here explicitly
  209. self.prob = prob # computed via sklearn
  210. # for class_id in range(n_classes-1):
  211. if compare_results:
  212. self.plot_results(sess, fig_number=1)
  213. return None
  214. def test_LDA_kiap(self, session_id, compare_results=0):
  215. '''test model using test_trials, compare probabilities with the ones from matlab code
  216. Parameters
  217. ----------
  218. test_trials: (n_samples, n_features)
  219. '''
  220. test_trials = self.test_trials
  221. n_features = self.params.classifier.n_features
  222. # n_trials_test = self.params.classifier.n_trials_test
  223. # n_classes = self.params.classifier.n_triggers + 1
  224. n_classes = self.params.classifier.n_classes
  225. for sess in range(session_id + 1, session_id + 2):
  226. # if compare_results:
  227. # classProb = io.loadmat('/kiap/data/tom/model/Q19_20131122/classProb{}.mat'.format(sess))['classProb']
  228. # triggers0 = io.loadmat('/kiap/data/tom/model/Q19_20131122/triggers.mat'.format(sess))['hand_triggers']
  229. # calculate probabilities using lda from sklearn
  230. prob = self.clf1.predict_proba(test_trials)
  231. # calculate probabilities explicitly
  232. cumSesLen = test_trials.shape[0]
  233. log_prob = np.zeros((cumSesLen, n_classes))
  234. rel_prob = np.zeros((cumSesLen, n_classes))
  235. for class_id in range(n_classes):
  236. vect = test_trials - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
  237. tmp = np.matmul(self.clf2['chol_mat'], vect.T)
  238. log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
  239. for class_id in range(n_classes):
  240. log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], n_classes, axis=1)
  241. rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
  242. # rel_prob[:, class_id] = log_prob_cl # 20.02.19
  243. # if compare_results:
  244. # self.triggers0 = triggers0
  245. # self.classProb = classProb # imported from Matlab code
  246. self.rel_prob = rel_prob # computed here explicitly
  247. self.prob = prob # computed via sklearn
  248. # self.predict = self.clf1
  249. # for class_id in range(n_classes-1):
  250. if compare_results:
  251. self.plot_results_kiap(sess, fig_number=1)
  252. return None
  def plot_results(self, sess, compare_results=0, fig_number=1):
      '''Plot Matlab, explicit and sklearn class probabilities for one session.

      Reloads ``self.triggers0`` from the reference file, draws three stacked
      panels with a shared x axis and then waits for the user to press enter.
      ``compare_results`` and ``fig_number`` are accepted but not used.
      '''
      palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
      self.triggers0 = io.loadmat('/kiap/data/tom/model/Q19_20131122/triggers.mat'.format(sess))['hand_triggers']

      for class_id in range(self.params.classifier.n_triggers):
          colour = palette[class_id]
          marks = np.hstack(self.triggers0[sess - 1, class_id]) - 601
          # subplot code per panel: matlab code, python, python sklearn
          panels = ((311, 'classProb'), (312, 'rel_prob'), (313, 'prob'))
          top_axis = None
          for code, attribute in panels:
              if top_axis is None:
                  top_axis = plt.subplot(code)
              else:
                  plt.subplot(code, sharex=top_axis)
              plt.plot(getattr(self, attribute)[:, class_id], lw=1, alpha=0.5, color=colour)
              plt.vlines(marks, 0, 2, color=colour, lw=2)
              plt.xlim(0, 17500)
              plt.ylim(0, 1.2)
              if code == 311:
                  plt.title('Session {}'.format(sess))

      plt.draw()
      # NOTE(review): the scraped listing skips one numbered line here
      # (possibly a stripped plt.show() call) - confirm against the repository.
      input('press enter for next session')
      return None
  def plot_results_kiap(self, sess, compare_results=0, fig_number=1):
      '''Plot explicit and sklearn class probabilities for one session.

      Uses ``self.triggers`` for the trigger markers, leaves the top panel
      empty and then waits for the user to press enter. ``compare_results``
      and ``fig_number`` are accepted but not used.
      '''
      palette = plt.rcParams['axes.prop_cycle'].by_key()['color']

      for class_id in range(self.params.classifier.n_triggers):
          colour = palette[class_id]
          marks = np.hstack(self.triggers[sess - 1, class_id]) - 601
          top_axis = plt.subplot(311)  # matlab code (nothing drawn here)
          # subplot code per panel: python, python sklearn
          for code, attribute in ((312, 'rel_prob'), (313, 'prob')):
              plt.subplot(code, sharex=top_axis)
              plt.plot(getattr(self, attribute)[:, class_id], lw=1, alpha=0.5, color=colour)
              plt.vlines(marks, 0, 2, color=colour, lw=2)
              plt.xlim(0, 500)
              plt.ylim(0, 1.2)

      plt.draw()
      # NOTE(review): the scraped listing skips one numbered line here
      # (possibly a stripped plt.show() call) - confirm against the repository.
      input('press enter for next session')
      return None
  def _collect_kiap_trials(self, train_data, triggers):
      '''Build per-class feature matrices and PSTH cut-outs around each trigger.

      Returns
      -------
      features: object array (n_classes, 1), entries (n_features, n_trials)
      psth: list with one (n_triggers_in_session, cut_win, n_ch) array per class
      '''
      log.debug('Computing feature matrix from training data ...')
      n_train_sess = train_data.size
      n_ch = train_data[0][0].shape[1]
      n_triggers = triggers.shape[1]
      n_classes = n_triggers
      template = self.params.classifier.template
      cut_start = self.params.classifier.psth.cut[0]
      cut_stop = self.params.classifier.psth.cut[1]
      log.debug(f'# of channels: {n_ch}')

      # One row per trigger (all sessions), one column per template point and channel.
      train_trials = [
          np.zeros((np.concatenate(triggers[:, trg_id], axis=1).size, len(template) * n_ch))
          for trg_id in range(n_triggers)
      ]
      psth = [0] * n_classes
      trial_counter = [1] * n_classes  # 1-based index of the next free row
      lowest, highest = min(template), max(template)

      for sess in range(n_train_sess):
          session_data = train_data[sess, 0]
          n_samples = session_data.shape[0]
          # NOTE(review): psth is re-allocated for every session, so only the
          # last session's cut-outs are kept - confirm this is intended.
          for trg_id in range(n_triggers):
              psth[trg_id] = np.zeros((triggers[sess, trg_id].size, cut_stop - cut_start, n_ch))
              log.debug(f'psth shape: {psth[trg_id].shape}')

          for trg_id in range(n_triggers):
              session_triggers = triggers[sess, trg_id]
              for ii in range(session_triggers.size):  # skip ii if sample points are < 0 or > max size of available data
                  trig = session_triggers[0, ii]
                  if (trig + lowest <= 0) or (trig + highest > n_samples):
                      log.debug('exluded {} {} {}'.format(sess, trg_id, ii))
                      continue
                  # CAUTION: why -1 below? is trigger absolute index?
                  # Column-major flattening yields the channel-by-channel layout.
                  row = trial_counter[trg_id] - 1
                  train_trials[trg_id][row, :] = session_data[trig + template - 1, :].flatten('F')
                  try:
                      psth[trg_id][ii, :, :] = session_data[trig + cut_start:trig + cut_stop, :]
                  except ValueError:
                      # the cut window reaches beyond the recorded samples
                      log.debug(f'Trial {ii} out of limits')
                  trial_counter[trg_id] += 1
          log.debug('feature for classes extracted')
          log.debug(f'Session {sess}, trialCounter {trial_counter}')

      if trial_counter == [1] * n_classes:
          raise ValueError('No feature trial(s) extracted, because triggers are too close to the edges!')

      # bring data into correct format to use afterwards
      features = np.zeros((n_classes, 1), dtype='O')
      for cl in range(n_classes):
          # remove trials if they were outside margins. to be confirmed !
          features[cl, 0] = train_trials[cl][:trial_counter[cl] - 1, :].T
      return features, psth

  def _store_kiap_trials(self, features, psth):
      '''Publish extracted features and PSTH data on the instance.'''
      self.train_trials = features
      self.psth = psth
      self.psth_xx = range(self.params.classifier.psth.cut[0], self.params.classifier.psth.cut[1])
      log.debug(f'feature-array session 0 shape: {features[0,0].shape}')

  def get_trials_kiap(self, train_data, triggers, features_fname, save_features=0):
      '''Extract trial features and PSTHs, optionally pickle the features.

      Parameters
      ----------
      train_data: indexed as train_data[sess, 0], each entry (n_samples, n_ch)
      triggers: indexed as triggers[sess, trigger_type], each entry (1, n)
      features_fname: pickle target; nothing is written when empty
      save_features: unused, kept for interface compatibility

      Sets ``self.train_trials``, ``self.psth`` and ``self.psth_xx``.

      Raises
      ------
      ValueError: when no trial could be extracted.
      '''
      features, psth = self._collect_kiap_trials(train_data, triggers)
      if features_fname:
          log.debug('Saving feature array...')
          with open(features_fname, 'wb') as f:
              pickle.dump(features, f)
      self._store_kiap_trials(features, psth)
      return None

  def get_trials_kiap_bl(self, train_data, triggers, save_features=0):
      '''Extract trial features and PSTHs without writing them to disk.

      Parameters
      ----------
      train_data: indexed as train_data[sess, 0], each entry (n_samples, n_ch)
      triggers: indexed as triggers[sess, trigger_type], each entry (1, n)
      save_features: unused, kept for interface compatibility

      Sets ``self.train_trials``, ``self.psth`` and ``self.psth_xx``.

      Raises
      ------
      ValueError: when no trial could be extracted.
      '''
      features, psth = self._collect_kiap_trials(train_data, triggers)
      self._store_kiap_trials(features, psth)
      return None
  def get_trials(self, train_data, triggers, features_fname='', train_trials_comparison=None, assert_results=0):
      '''Import spike rates, extract the trials and construct features.

      It yields identical results for the monkey data to the matlab code provided by Tom.

      Parameters
      ----------
      train_data: (n_train_sess, 1)
          train_data[ii, 0]: (n_samples, n_ch) per element
      triggers: (n_train_sess, n_triggers)
      features_fname: pickle target; nothing is written when empty
      train_trials_comparison: reference features, only used to assert that
          the resulting train_trials is equivalent to the Matlab code
      assert_results: set to 1 to compare with results from Matlab code provided by Tom

      Saves as object variable
      ------------------------
      self.train_trials: (n_classes, 1) with (n_feat, n_trials) per class;
          the last class holds the baseline samples
      '''
      log.info('Computing feature matrix from training data ...')
      n_train_sess = train_data.size
      n_ch = train_data[0][0].shape[1]
      n_triggers = triggers.shape[1]
      n_classes = n_triggers + 1
      deadtime = self.params.classifier.deadtime  # number of samplepoints to consider around each trigger
      n_neg_train = self.params.classifier.n_neg_train  # number of sample-points to consider for baseline
      template = self.params.classifier.template
      lowest, highest = min(template), max(template)

      # total number of triggers per trigger type
      n_pos_train = [np.concatenate(triggers[:, trg_id], axis=1).size for trg_id in range(n_triggers)]

      # Baseline candidates: row 0 holds the session number, row 1 the sample
      # index. Samples near the session edges or a trigger are excluded.
      per_session = []
      for sess in range(n_train_sess):
          n_samples = train_data[sess, 0].shape[0]
          usable = np.ones(n_samples, dtype=bool)
          usable[:deadtime + 1] = False
          usable[max(n_samples - deadtime, 0):] = False
          for trg_id in range(n_triggers):
              for trig in triggers[sess, trg_id].ravel():
                  start = max(int(trig) - deadtime, 0)
                  stop = max(int(trig) + deadtime + 1, 0)
                  usable[start:stop] = False
          starts = np.flatnonzero(usable)  # sorted, unlike iterating a set
          per_session.append(np.vstack((np.full(starts.size, sess, dtype=int), starts)))
      neg_train = np.hstack([np.zeros((2, 0), dtype=int)] + per_session)
      # note that 2nd row has to be modified for neg_train, it is being taken care of below

      for sess in range(n_train_sess):
          log.debug(f'session {sess}, # baseline indices: {np.count_nonzero(neg_train[0, :] == sess)}')

      # limit the number of sample points for baseline
      if neg_train.shape[1] > n_neg_train:
          perm_idx = np.random.permutation(neg_train.shape[1])
          neg_train = neg_train[:, perm_idx[:n_neg_train]]
      else:
          n_neg_train = neg_train.shape[1]

      n_columns = len(template) * n_ch
      train_trials = [np.zeros((n_pos_train[trg_id], n_columns)) for trg_id in range(n_triggers)]
      train_trials.append(np.zeros((n_neg_train, n_columns)))
      trial_counter = [1] * n_classes  # 1-based index of the next free row

      for sess in range(n_train_sess):
          session_data = train_data[sess, 0]
          n_samples = session_data.shape[0]

          for trg_id in range(n_triggers):
              session_triggers = triggers[sess, trg_id]
              for ii in range(session_triggers.size):  # skip ii if sample points are < 0 or > max size of available data
                  trig = session_triggers[0, ii]
                  if (trig + lowest <= 0) or (trig + highest > n_samples):
                      log.debug('exluded {} {} {}'.format(sess, trg_id, ii))
                      continue
                  # CAUTION: why -1 below? is trigger absolute index?
                  # Column-major flattening yields the channel-by-channel layout.
                  row = trial_counter[trg_id] - 1
                  train_trials[trg_id][row, :] = session_data[trig + template - 1, :].flatten('F')
                  if assert_results:
                      assert np.array_equiv(train_trials[trg_id][row, :], train_trials_comparison[trg_id, 0].T[row, :])
                  trial_counter[trg_id] += 1
          log.debug('feature for classes extracted')

          for start in neg_train[1, neg_train[0, :] == sess]:  # throw away indices if they fall outside margins
              if (start + lowest <= 0) or (start + highest > n_samples):
                  continue
              row = trial_counter[n_classes - 1] - 1
              train_trials[n_classes - 1][row, :] = session_data[start + template - 1, :].flatten('F')
              trial_counter[n_classes - 1] += 1
          log.debug('feature for baseline extracted')
          log.debug(f'Session {sess}, trialCounter {trial_counter}')

      # bring data into correct format to use afterwards
      train_trials2 = np.zeros((n_classes, 1), dtype='O')
      for cl in range(n_classes):
          # remove trials if they were outside margins. to be confirmed !
          train_trials2[cl, 0] = train_trials[cl][:trial_counter[cl] - 1, :].T

      if features_fname:
          log.info('Saving feature array...')
          with open(features_fname, 'wb') as f:
              pickle.dump(train_trials2, f)
      self.train_trials = train_trials2
      return None
  540. def get_test_trials(self, test_data):
  541. '''this function imports spike rates, extracts the trials for TESTING and constructs features.'''
  542. # data_tot = io.loadmat('/kiap/data/tom/model/trainData_rates2.mat')['trainData']
  543. # test_data = data_tot[session_id:session_id + 1]
  544. # # this file is imported only to assert that test_trials is equivalent to dt
  545. # dt = io.loadmat('/home/vlachos/devel/iv_misc/decoding/tom/scripts/model/Q19_20131122/testTrials.mat')['testTrials']
  546. n_test_sess = test_data.size # number of test sessions
  547. n_ch = test_data[0][0].shape[1]
  548. n_triggers = self.params.classifier.n_triggers
  549. n_classes = n_triggers + 1
  550. deadtime = self.params.classifier.deadtime # number of samplepoints to consider around each trigger
  551. n_neg_train = self.params.classifier.n_neg_train # number of sample-points to consider for baseline
  552. template = self.params.classifier.template
  553. n_test = 0
  554. for sess in range(n_test_sess): # total number of sample points
  555. n_test = n_test + test_data[sess, 0].shape[0]
  556. # % Extracting testing trials
  557. test_trials = np.zeros((n_test, len(template) * n_ch))
  558. test_ind = np.zeros((n_test, 1), dtype=int)
  559. test_cut = np.zeros((n_test, 1))
  560. cum_ses_len = 0
  561. for sess in range(n_test_sess):
  562. start_sess = 1 - min(np.append(template, 0)) # get start index
  563. # start_sess = 0 - min(np.append(template, 0)) # get start index
  564. end_sess = test_data[sess, 0].shape[0] - max(np.append(template, 1)) + 1 # get end index
  565. sess_size = end_sess - start_sess + 1
  566. if sess_size <= 0:
  567. continue
  568. test_ind[cum_ses_len:cum_ses_len + sess_size, 0] = np.arange(start_sess, end_sess + 1) + max(np.append(template, 1)) - 1
  569. test_cut[cum_ses_len:cum_ses_len + sess_size, 0] = sess
  570. for ch in range(n_ch):
  571. for ii in range(len(template)):
  572. tmp = test_data[sess, 0][np.arange(start_sess - 1, end_sess) + template[ii], ch]
  573. test_trials[cum_ses_len:cum_ses_len + sess_size, ch * len(template) + ii] = tmp
  574. # print(tmp[-10:-1])
  575. # inputData{sess}((startSes:endSess) + template_detection(ii),ch);
  576. cum_ses_len = cum_ses_len + sess_size
  577. test_trials = test_trials[: cum_ses_len, :]
  578. test_ind = test_ind[:cum_ses_len]
  579. test_cut = test_cut[:cum_ses_len]
  580. return test_trials
  def get_test_trials_kiap(self, test_data):
      '''Build the lagged feature matrix for TESTING from per-session spike rates.

      test_data is indexed as test_data[sess, 0]; each entry is used as a 2-D
      array (samples x channels) and the channel count is read from the first
      session. The template from self.params.classifier.template is shifted
      so that its smallest entry becomes 1. Every session holding at least
      max(shifted template) samples contributes one feature row per
      admissible sample; sessions shorter than that are skipped.

      Feature layout: column ch * len(template) + ii holds channel ch read at
      template entry ii.

      The stacked rows (in session order) are stored in self.test_trials.
      Returns None.
      '''
      n_test_sess = test_data.size  # number of test sessions
      n_ch = test_data[0][0].shape[1]

      template = self.params.classifier.template
      template = template - template.min() + 1  # smallest entry becomes 1
      n_lags = len(template)
      max_lag = template.max()

      sessions = [test_data[sess, 0] for sess in range(n_test_sess)]
      n_test = sum(data.shape[0] for data in sessions)  # upper bound on rows

      test_trials = np.zeros((n_test, n_lags * n_ch))
      n_filled = 0
      for data in sessions:
          sess_size = data.shape[0] - max_lag + 1
          if sess_size <= 0:
              continue
          # Source row for output row k and template entry t is k - t - 1.
          # NOTE(review): this is negative whenever k <= t, and NumPy then
          # counts from the END of the session (wrap-around). That matches
          # the previous implementation ("negative template !!!") and is
          # kept as is -- confirm that the wrap-around is intended.
          rows = np.arange(sess_size)[:, None] - template[None, :] - 1
          out_rows = slice(n_filled, n_filled + sess_size)
          for ch in range(n_ch):
              test_trials[out_rows, ch * n_lags:(ch + 1) * n_lags] = data[rows, ch]
          n_filled += sess_size

      # drop the preallocated rows that were never filled
      self.test_trials = test_trials[:n_filled, :]
      return None
  def online_decoder(self, cur_data, session_id=0):
      '''Replay cur_data through get_class with a sliding window.

      The current axes are cleared and their limits set (y: 0..1,
      x: 0..10000). get_class is then called once for every window end
      600, 620, ..., 1980, each time with the 600 rows of cur_data that
      precede the window end and with the window end as time stamp.
      session_id is accepted but not used. Returns None.
      '''
      window_len = 600
      last_stop = 2000  # exclusive
      step = 20

      plt.cla()
      plt.ylim(0, 1)
      plt.xlim(0, 10000)

      stop = window_len
      while stop < last_stop:
          window = cur_data[stop - window_len:stop]
          self.get_class(window, stop)
          stop += step
      return None
  def get_class(self, cur_data, tt):
      '''Classify one data window with the explicit LDA model (used for Tom's data).

      cur_data is indexed as cur_data[row, channel]. One feature vector is
      built by reading every channel at the rows 0, -150, -300, -450 and -600
      (negative rows count from the end of the window); the feature for
      channel ch and template entry ii sits in column ch * 5 + ii.

      For each of the 5 classes the log-likelihood is
      -0.5 * || chol_mat @ (features - mean_class) ||^2, using
      self.clf2['chol_mat'] and column class of self.clf2['means']. The
      relative class probabilities are the softmax of these values.

      If self.params.system.plot is set, the probabilities of the first
      self.params.classifier.n_classes - 1 classes are drawn at x = tt - 600.

      Returns (my_decision, rel_prob): my_decision is 1 if any of the first
      4 class probabilities exceeds 0.7, else 0; rel_prob has shape (1, 5).
      '''
      n_triggers = 4
      n_classes = n_triggers + 1
      template = [0, -150, -300, -450, -600]
      n_ch = cur_data.shape[1]

      # channel-major feature row: column ch * len(template) + ii
      test_trials = np.zeros((1, len(template) * n_ch))
      test_trials[0, :] = cur_data[template, :].T.ravel()

      means = self.clf2['means']
      chol_mat = self.clf2['chol_mat']
      log_prob = np.zeros((1, n_classes))
      for class_id in range(n_classes):
          centred = test_trials - means[:, class_id]
          whitened = np.matmul(chol_mat, centred.T)
          log_prob[:, class_id] = -np.sum(whitened**2, axis=0) / 2.

      # Softmax with the row maximum subtracted first. Same value as
      # 1 / sum_j exp(log_prob_j - log_prob_c), but exp() never sees a large
      # positive argument, so it cannot overflow.
      weights = np.exp(log_prob - log_prob.max(axis=1, keepdims=True))
      rel_prob = weights / weights.sum(axis=1, keepdims=True)

      if self.params.system.plot:
          # matplotlib is only touched when plotting is enabled
          col = plt.rcParams['axes.prop_cycle'].by_key()['color']
          for class_id in range(self.params.classifier.n_classes - 1):
              plt.plot(tt - 600, rel_prob[:, class_id], lw=1, alpha=0.5, color=col[class_id], marker='.')
          # NOTE(review): one pause per call, after all classes are drawn --
          # confirm this matches the intended nesting.
          plt.pause(.01)

      my_decision = int(np.any(rel_prob[0, 0:n_triggers] > 0.7))
      return my_decision, rel_prob
  673. def get_class2(self, cur_data, tt, decoder_decision):
  674. '''used for data from NSP'''
  # Decodes one online step. Reads the masked channels of cur_data at the
  # offsets in self.online_template, evaluates both the explicit model
  # (self.clf2) and the scikit model (self.clf1), picks one of their outputs
  # as `sig` according to self.params.classifier, updates self.online_history,
  # writes decoder_decision.value and stores `sig` in self.online_sig.
  # Every visible return statement returns None; `tt` is not read in the body.
  # NOTE(review): this listing has lost its indentation, so the nesting of the
  # statements below (e.g. which branch `log.debug(sig)`,
  # `decoder_decision.value = -1` and the `self.online_history = []` resets
  # belong to) cannot be read off here -- confirm against the indented file.
  675. cur_data = cur_data[:, self.channel_mask] # get only channels used during training
  # While init_buffer is truthy nothing is decoded: an error code and an
  # all-zero signal are published instead.
  676. if self.init_buffer: #self.params.classifier.thr_window:
  677. self.online_decision = aux.decision.error1.value # not enough data at the beginning
  678. self.online_sig = [0] * self.online_n_classes
  679. # log.warning(f'not enough data, online_sig: {self.online_sig}')
  680. return None
  681. # log.warning(f'cur_data: {cur_data.shape}, {cur_data[:1,:3]}')
  # Fill row 0 of the feature buffer self.online_features;
  # column = ch * len(self.online_template) + ii.
  682. for ch in range(self.online_n_ch):
  683. for ii in range(len(self.online_template)-1,-1,-1):
  684. # log.warning(f'----> {ii} {self.online_template[ii]}')
  685. tmp_data = cur_data[self.online_template[ii], ch]
  686. # log.warning(f'self.online_template: {self.online_template}, {ii}, {self.online_template[ii]}')
  687. self.online_features[0:1, ch * len(self.online_template) + ii] = tmp_data
  # NOTE(review): `col` is assigned here but not used again in this method.
  688. col = plt.rcParams['axes.prop_cycle'].by_key()['color']
  689. cumSesLen = 1
  690. log_prob = np.zeros((cumSesLen, self.online_n_classes))
  691. rel_prob = np.zeros((cumSesLen, self.online_n_classes))
  # Explicit model: per class, -0.5 * sum of squares of
  # chol_mat @ (features - class mean).
  692. for class_id in range(self.online_n_classes):
  693. vect = self.online_features - np.repeat(self.clf2['means'][:, class_id][:, None], cumSesLen, axis=1).T
  694. tmp = np.matmul(self.clf2['chol_mat'], vect.T)
  695. log_prob[:, class_id] = -np.sum(tmp**2, axis=0) / 2.
  # Relative probability of class c: 1 / sum_j exp(log_prob_j - log_prob_c).
  696. for class_id in range(self.online_n_classes):
  697. log_prob_cl = log_prob - np.repeat(log_prob[:, class_id][:, None], self.online_n_classes, axis=1)
  698. rel_prob[:, class_id] = 1. / np.sum(np.exp(log_prob_cl), axis=1)
  699. with np.errstate(divide='ignore',invalid='ignore'): # supress sklearn warning if probs are zero
  700. prob = self.clf1.predict_proba(self.online_features)
  701. # use either probabilites or class predictions
  # `sig` stays all-zero if none of the three configurations below matches.
  702. sig = np.zeros((1,self.online_n_classes))
  703. if self.params.classifier.model_training.model == 'scikit' and self.params.classifier.peaks.sig == 'prob':
  704. sig = np.copy(prob)
  705. elif self.params.classifier.model_training.model == 'scikit' and self.params.classifier.peaks.sig == 'pred':
  706. pred_cl = int(self.clf1.predict(self.online_features))
  707. sig[0,pred_cl] = 1
  708. elif self.params.classifier.model_training.model == 'explicit':
  709. sig = np.copy(rel_prob)
  # History grows only when some entry of sig[0] exceeds thr_prob and
  # block_phase.value == 2. In 'pred' mode the entry is the argmax class index;
  # otherwise it is the whole sig[0] array.
  # NOTE(review): with array entries, `self.online_history[-1] <
  # self.online_n_classes` below is an elementwise comparison -- the decision
  # logic looks written for 'pred' mode only; confirm.
  710. if np.any(sig[0]>self.params.classifier.thr_prob) and (self.block_phase.value == 2): # only if above thr and if in response state
  711. if self.params.classifier.peaks.sig == 'pred':
  712. self.online_history.append(int(np.argmax(np.array(sig[0])))) # 0: yes, 1:no, 2:baseline
  713. # log.warning(f'{sig}, {self.online_history[-1]}')
  714. else:
  715. self.online_history.append((np.array(sig[0]))) # 0: yes, 1:no, 2:baseline
  716. log.debug(sig)
  # -1 is written before the decision logic; the branches below overwrite it
  # with the class index held in online_history[-1].
  717. decoder_decision.value = -1
  718. # log.warning(f'sig:{sig}, prob: {prob}, {self.online_history}')
  719. # append only if prob cross threshold, and only if neural response has started
  720. # set decision if consecutive number of prob samples cross threshold
  721. if self.online_decision < 0 and len(self.online_history) >= self.params.classifier.thr_window:
  722. # set decision if same class appears in all samples within last thr_window
  723. # first if conditions: check if only decision in thr_window samples;
  724. if np.unique(self.online_history[-self.params.classifier.thr_window:]).size == 1 and (self.online_history[-1] < self.online_n_classes):
  # With 3 classes only indices 0 and 1 are reported, with 2 classes only
  # index 0; the history is emptied alongside each reported decision.
  725. if (self.online_n_classes ==3) and (self.online_history[-1] < 2): # DO NOT SEND DECISION FOR BASELINE
  726. decoder_decision.value = self.online_history[-1]
  727. self.online_history = []
  728. elif (self.online_n_classes ==2) and (self.online_history[-1] < 1):
  729. decoder_decision.value = self.online_history[-1]
  730. self.online_history = []
  # History is also emptied when block_phase.value is not 2.
  731. if self.block_phase.value!=2:
  732. self.online_history = []
  733. # log.error(f'{self.online_decision}, {self.online_history}, {prob}')
  734. # log.debug(f'{prob}')
  735. self.online_sig = sig
  736. # self.online_sig = rel_prob
  737. return None #prob