data_management.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. '''
  2. description: script to read header and data from data.bin
  3. author: Ioannis Vlachos
  4. date: 02.11.18
  5. Copyright (c) 2018 Ioannis Vlachos.
  6. All rights reserved.
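HEADER OF EACH RECORD IN A BINARY FILE
(records are stored back-to-back until EOF; see get_session)
---------------------
laptop timestamp np.datetime64[us] bytes: 8
NSP timestamp np.int64 bytes: 8
number of bytes int (little-endian) bytes: 8
number of samples int (little-endian) bytes: 8
number of channels int (little-endian) bytes: 8
length of channel list int (little-endian) bytes: 2
PAYLOAD OF EACH RECORD
---------------------
channel list np.uint16 bytes: <length of channel list>
data np.float32 bytes: <number of bytes>, reshaped to (-1, number of channels)
sample timestamps np.uint64 bytes: 8 per data row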
  14. '''
  15. import csv
  16. import datetime as dt
  17. import glob
  18. import os
  19. import sys
  20. import matplotlib.pyplot as plt
  21. import numpy as np
  22. from numpy import datetime64 as dt64
  23. from tabulate import tabulate
  24. import munch
  25. import aux
  26. from aux import log
# Module-level configuration, loaded once at import time via aux.load_config().
# get_session() reads params.daq.* from it and the __main__ block reads
# params.classifier.n_classes; get_raw() loads its own fresh copy.
params = aux.load_config()
# def get_event_name(filename, events_file_names)
  29. def get_raw(verbose=0, n_triggers=2, exploration=False, feedback=False, triggers_all=False, fids=[], trigger_pos='start'):
  30. '''read raw data from one or more binary files
  31. trigger_pos: ['start, 'stop']'''
  32. params = aux.load_config()
  33. n_channels = params.daq.n_channels
  34. file_names = []
  35. # events_file_names = []
  36. print('Available binary data files:\n')
  37. for ii, file_name in enumerate(sorted(glob.iglob(params.file_handling.data_path + '**/*.bin', recursive=True))):
  38. print(f'{ii}, {file_name}, {os.path.getsize(file_name)//1000}K')
  39. file_names.append(file_name)
  40. # for ii, file_name in enumerate(sorted(glob.iglob(params.file_handling.data_path + '**/*.txt', recursive=True))):
  41. # events_file_names.append(file_name)
  42. if fids ==[]:
  43. fids = [int(x) for x in input('\nSelect file ids (separated by space): ').split()] # select file ids
  44. if fids == []:
  45. fids = [len(file_names)-1]
  46. print(f'Selected: {fids}')
  47. data_tot = np.empty((len(fids), 1), dtype=object)
  48. time_stamps_tot = np.empty((len(fids), 1), dtype=object)
  49. if triggers_all:
  50. triggers_tot = np.empty((len(fids), 6), dtype=object)
  51. elif exploration:
  52. triggers_tot = np.empty((len(fids), 1), dtype=object)
  53. elif feedback:
  54. triggers_tot = np.empty((len(fids), 2), dtype=object) # up and down
  55. else:
  56. triggers_tot = np.empty((len(fids), n_triggers), dtype=object)
  57. ch_rec_list_tot = np.empty((len(fids), 1), dtype=object)
  58. for ii, fid in enumerate(fids): # go through all sessions
  59. data, time_stamps, ch_rec_list = get_session(file_names[fid], verbose=verbose)
  60. data = np.delete(data, params.classifier.exclude_data_channels, axis=1)
  61. # triggers = get_triggers(os.path.dirname(file_names[fid])+'/events.txt', time_stamps, n_triggers)
  62. event_file_name = file_names[fid].replace('data_','events_').replace('.bin','.txt')
  63. info_file_name = file_names[fid].replace('data_','info_').replace('.bin','.log')
  64. if triggers_all:
  65. triggers = get_triggers_all(event_file_name, time_stamps, trigger_pos)
  66. elif exploration:
  67. triggers = get_triggers_exploration(event_file_name, time_stamps)
  68. elif feedback:
  69. triggers = get_triggers_feedback(event_file_name, time_stamps, trigger_pos, n_triggers)
  70. else:
  71. triggers = get_triggers(event_file_name, time_stamps, trigger_pos, n_triggers)
  72. # yes_mask, no_mask = get_accuracy(info_file_name,n_triggers)
  73. # triggers[0] = np.where(yes_mask,triggers[0],'')
  74. # triggers[1] = np.where(no_mask,triggers[1],'')
  75. data_tot[ii, 0] = data
  76. time_stamps_tot[ii, 0] = time_stamps
  77. if exploration:
  78. triggers_tot[ii,0] = triggers
  79. else:
  80. triggers_tot[ii, :] = triggers
  81. ch_rec_list_tot[ii, 0] = ch_rec_list
  82. # triggers_tot[ii,1] = triggers[1]
  83. print(f'\nRead binary neural data from file {file_names[fid]}')
  84. print(f'Read trigger info events from file {event_file_name}')
  85. print(f'Session {ii}, data shape: {data.shape}')
  86. config = read_config('paradigm.yaml')
  87. states = config.exploration.states
  88. for jj, _ in enumerate(fids):
  89. print()
  90. if triggers_all:
  91. for ii in range(6):
  92. print(f'Session {fids[jj]}, Class {ii}, trigger #: {triggers_tot[jj,ii].shape}')
  93. elif exploration:
  94. for ii in range(len(triggers_tot[jj,0][0])):
  95. print(f'Session {fids[jj]}, State {states[ii]}, trigger #: {triggers_tot[jj,0][0][ii].shape}')
  96. else:
  97. for ii in range(n_triggers):
  98. print(f'Session {fids[jj]}, Class {ii}, trigger #: {triggers_tot[jj,ii].shape}')
  99. file_names = [file_names[ii] for ii in fids]
  100. return data_tot, time_stamps_tot, triggers_tot, ch_rec_list_tot, file_names
  101. def get_session(file_name, verbose=0):
  102. ii = 0
  103. # data = np.empty((0, params.daq.n_channels-len(params.daq.exclude_channels)))
  104. data = np.empty((0, params.daq.n_channels_max))
  105. log.info(f'Data shape: {data.shape}')
  106. time_stamps = []
  107. with open(file_name, 'rb') as fh:
  108. print(f'\nReading binary file {file_name}...\n')
  109. while True:
  110. tmp = np.frombuffer(fh.read(8), dtype='datetime64[us]') # laptop timestamp
  111. if tmp.size == 0:
  112. print(f'Imported session {ii}')
  113. return data, np.array(time_stamps, dtype=np.uint), ch_rec_list
  114. else:
  115. t_now1 =tmp
  116. t_now2 = np.frombuffer(fh.read(8), dtype=np.int64)[0] # NSP timestamp
  117. n_bytes = int.from_bytes(fh.read(8), byteorder='little') # number of bytes
  118. n_samples = int.from_bytes(fh.read(8), byteorder='little') # number of samples
  119. n_ch = int.from_bytes(fh.read(8), byteorder='little') # number of channels
  120. ch_rec_list_len = int.from_bytes(fh.read(2), byteorder='little')
  121. ch_rec_list = np.frombuffer(fh.read(ch_rec_list_len), dtype=np.uint16) # detailed channel list
  122. # log.info(f'recorded channels: {ch_rec_list}')
  123. d = fh.read(n_bytes)
  124. d2 = np.frombuffer(d, dtype=np.float32)
  125. d3 = d2.reshape(d2.size // n_ch, n_ch) # data, shape : (n_samples, n_ch)
  126. print('\n', t_now1, t_now2, n_bytes, n_samples, n_ch, '\n')
  127. log.info(f'data shape: {d3.shape}')
  128. if data.size == 0 and data.shape[1] != d3.shape[1]:
  129. log.warning(f'Shape mismatch. {d3.shape} vs {data.shape[1]}. Using data shape from file: {d3.shape}')
  130. data = np.empty((0, d3.shape[1]))
  131. data = np.concatenate((data, d3))
  132. fct = params.daq.spike_rates.bin_width * 30000 # factor to get correct starting time in ticks
  133. # time_stamps.extend(np.arange(t_now2-d3.shape[0]*fct + 1, t_now2+1)) # check if +1 index is correct
  134. # time_stamps.extend(np.arange(t_now2-d3.shape[0]*fct + 1, t_now2+1, 3000)) # check if +1 index is correct
  135. ts = np.frombuffer(fh.read(8 * d3.shape[0]), dtype=np.uint64)
  136. time_stamps.extend(ts)
  137. log.info(f'ts size: {ts.size}, {ts[0]}, {ts[-1]}')
  138. # log.info(time_stamps)
  139. # if verbose:
  140. print(ii, t_now1[0], t_now2, n_bytes, n_samples, n_ch, d3[10:20, 0], np.any(d3))
  141. ii +=1
  142. return None
  143. def get_accuracy_question(fname, n_triggers=2):
  144. with open(fname, 'r') as fh:
  145. events = fh.read().splitlines()
  146. cl1 = []
  147. cl2 = []
  148. for ev in events:
  149. if 'Yes Question' in ev:
  150. if 'Decoder decision: yes' in ev:
  151. cl1.append(True)
  152. else:
  153. cl1.append(False)
  154. elif 'No Question' in ev:
  155. if 'Decoder decision: no' in ev:
  156. cl2.append(True)
  157. else:
  158. cl2.append(False)
  159. return [ind1, ind2]
  160. def get_triggers(fname, time_stamps, trigger_pos, n_triggers=2):
  161. with open(fname, 'r') as fh:
  162. # events = fh.readlines()
  163. events = fh.read().splitlines()
  164. cl1 = []
  165. cl2 = []
  166. cl3 = []
  167. tt1 = []
  168. tt2 = []
  169. tt3 = []
  170. for ev in events:
  171. if 'response' in ev and 'yes' in ev and trigger_pos in ev:
  172. cl1.append(int(ev.split(',')[0]))
  173. elif 'response' in ev and 'no' in ev and trigger_pos in ev:
  174. cl2.append(int(ev.split(',')[0]))
  175. elif 'baseline' in ev and 'start'in ev:
  176. cl3.append(int(ev.split(',')[0]))
  177. if n_triggers == 2:
  178. # cl2.extend(cl3) # add baseline to class 2
  179. cl3 = []
  180. for ev in events:
  181. if 'response' in ev and trigger_pos in ev:
  182. print(f'\033[91m{ev}\033[0m')
  183. else:
  184. print(ev)
  185. for ii in cl1:
  186. tt1.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  187. for ii in cl2:
  188. tt2.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  189. for ii in cl3:
  190. tt3.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  191. ind1 = np.where(np.in1d(time_stamps, tt1))[0][np.newaxis, :]
  192. ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
  193. ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
  194. print()
  195. print(cl1, tt1, ind1)
  196. print(cl2, tt2, ind2)
  197. print(cl3, tt3, ind3)
  198. if n_triggers == 1:
  199. res = [ind1]
  200. elif n_triggers == 2:
  201. res = [ind1, ind2]
  202. # res = [ind1, np.hstack((ind2, ind3))] # put no and baseline together
  203. elif n_triggers == 3:
  204. res = [ind1, ind2, ind3]
  205. return res
  206. def get_triggers_feedback(fname, time_stamps, trigger_pos, n_triggers=2):
  207. with open(fname, 'r') as fh:
  208. # events = fh.readlines()
  209. events = fh.read().splitlines()
  210. cl1 = []
  211. cl2 = []
  212. cl3 = []
  213. tt1 = []
  214. tt2 = []
  215. tt3 = []
  216. for ev in events:
  217. if 'response' in ev and 'down' in ev and trigger_pos in ev:
  218. cl2.append(int(ev.split(',')[0]))
  219. elif 'response' in ev and 'up' in ev and trigger_pos in ev:
  220. cl1.append(int(ev.split(',')[0]))
  221. elif 'baseline' in ev and 'start'in ev:
  222. cl3.append(int(ev.split(',')[0]))
  223. if n_triggers == 2:
  224. # cl2.extend(cl3) # add baseline to class 2
  225. cl3 = []
  226. for ev in events:
  227. if 'response' in ev and trigger_pos in ev:
  228. print(f'\033[91m{ev}\033[0m')
  229. else:
  230. print(ev)
  231. for ii in cl1:
  232. tt1.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  233. for ii in cl2:
  234. tt2.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  235. for ii in cl3:
  236. tt3.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  237. ind1 = np.where(np.in1d(time_stamps, tt1))[0][np.newaxis, :]
  238. ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
  239. ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
  240. print()
  241. print(cl1, tt1, ind1)
  242. print(cl2, tt2, ind2)
  243. print(cl3, tt3, ind3)
  244. if n_triggers == 1:
  245. res = [ind1]
  246. elif n_triggers == 2:
  247. res = [ind1, ind2]
  248. # res = [ind1, np.hstack((ind2, ind3))] # put no and baseline together
  249. elif n_triggers == 3:
  250. res = [ind1, ind2, ind3]
  251. return res
  252. def get_triggers_exploration(fname, time_stamps):
  253. with open(fname, 'r') as fh:
  254. # events = fh.readlines()
  255. events = fh.read().splitlines()
  256. config = read_config('paradigm.yaml')
  257. states = config.exploration.states
  258. cl1 = [[] for x in range(len(states))]
  259. cl2 = []
  260. # cl3 = []
  261. tt1 = [[] for x in range(len(states))]
  262. tt2 = []
  263. # tt3 = []
  264. for ev in events:
  265. for ii,state in enumerate(states):
  266. if 'response' in ev and state in ev and 'start' in ev:
  267. cl1[ii].append(int(ev.split(',')[0]))
  268. if 'baseline' in ev and 'start'in ev:
  269. cl2.append(int(ev.split(',')[0]))
  270. for ev in events:
  271. if 'response' in ev and 'start' in ev:
  272. print(f'\033[91m{ev}\033[0m')
  273. else:
  274. print(ev)
  275. for ii in range(len(cl1)):
  276. for jj in cl1[ii]:
  277. tt1[ii].append(time_stamps.flat[np.abs(time_stamps - jj).argmin()])
  278. for ii in cl2:
  279. tt2.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  280. ind1 = [[] for x in range(len(states))]
  281. for ii in range(len(tt1)):
  282. ind1[ii] = np.where(np.in1d(time_stamps, tt1[ii]))[0][np.newaxis, :]
  283. ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
  284. # ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
  285. print()
  286. print(cl1, tt1, ind1)
  287. print(cl2, tt2, ind2)
  288. # print(cl3, tt3, ind3)
  289. # xx
  290. res = [ind1,ind2]
  291. return res
  292. def get_triggers_all(fname, time_stamps, trigger_pos, n_triggers=2):
  293. with open(fname, 'r') as fh:
  294. # events = fh.readlines()
  295. events = fh.read().splitlines()
  296. cl1 = []
  297. cl2 = []
  298. cl3 = []
  299. cl4 = []
  300. cl5 = []
  301. cl6 = []
  302. tt1 = []
  303. tt2 = []
  304. tt3 = []
  305. tt4 = []
  306. tt5 = []
  307. tt6 = []
  308. for ev in events:
  309. if 'baseline' in ev and 'start'in ev:
  310. cl1.append(int(ev.split(',')[0]))
  311. elif 'baseline' in ev and 'stop'in ev:
  312. cl2.append(int(ev.split(',')[0]))
  313. elif 'stimulus' in ev and 'start'in ev:
  314. cl3.append(int(ev.split(',')[0]))
  315. elif 'stimulus' in ev and 'stop'in ev:
  316. cl4.append(int(ev.split(',')[0]))
  317. elif 'response' in ev and 'start' in ev:
  318. cl5.append(int(ev.split(',')[0]))
  319. elif 'response' in ev and 'stop' in ev:
  320. cl6.append(int(ev.split(',')[0]))
  321. for ev in events:
  322. if 'response' in ev and trigger_pos in ev:
  323. print(f'\033[91m{ev}\033[0m')
  324. else:
  325. print(ev)
  326. for ii in cl1:
  327. tt1.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  328. for ii in cl2:
  329. tt2.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  330. for ii in cl3:
  331. tt3.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  332. for ii in cl4:
  333. tt4.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  334. for ii in cl5:
  335. tt5.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  336. for ii in cl6:
  337. tt6.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  338. ind1 = np.where(np.in1d(time_stamps, tt1))[0][np.newaxis, :]
  339. ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
  340. ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
  341. ind4 = np.where(np.in1d(time_stamps, tt4))[0][np.newaxis, :]
  342. ind5 = np.where(np.in1d(time_stamps, tt5))[0][np.newaxis, :]
  343. ind6 = np.where(np.in1d(time_stamps, tt6))[0][np.newaxis, :]
  344. print('\nTriggers and timestamps')
  345. print(cl1, tt1, ind1)
  346. print(cl2, tt2, ind2)
  347. print(cl3, tt3, ind3)
  348. print(cl4, tt4, ind4)
  349. print(cl5, tt5, ind5)
  350. print(cl6, tt6, ind6)
  351. # if n_triggers == 1:
  352. # res = [ind1]
  353. # elif n_triggers == 2:
  354. # res = [ind1, ind2]
  355. # # res = [ind1, np.hstack((ind2, ind3))] # put no and baseline together
  356. # elif n_triggers == 3:
  357. # res = [ind1, ind2, ind3]
  358. res = [ind1, ind2, ind3, ind4, ind5, ind6]
  359. return res
  360. def read_config(file_name):
  361. try:
  362. with open(file_name) as stream:
  363. config = munch.fromYAML(stream)
  364. return config
  365. except Exception as e:
  366. raise e
  367. if __name__ == '__main__':
  368. print("\nto read binary data use: 'data = get_raw(verbose=1)'")
  369. print("\nto read log file use: 'log = read_log(date)'")
  370. if aux.args.speller == 'exploration':
  371. exploration = True
  372. else:
  373. exploration = False
  374. if aux.args.speller == 'feedback':
  375. feedback = True
  376. else:
  377. feedback = False
  378. col = ['b', 'r', 'g']
  379. data_tot, tt, triggers_tot, ch_rec_list, file_names = get_raw(n_triggers=params.classifier.n_classes, exploration=exploration, feedback=feedback)
  380. if not exploration:
  381. plt.figure(1)
  382. plt.clf()
  383. xx = np.arange(data_tot[0, 0].shape[0]) * 0.05
  384. for cl_id in range(triggers_tot.shape[1]):
  385. markerline, stemlines, baseline = plt.stem(triggers_tot[0, cl_id][0]*0.05, triggers_tot[0, cl_id][0] * 0 + 50, '--', basefmt=' ')
  386. plt.setp(stemlines, alpha=0.8, color=col[cl_id], lw=1)
  387. plt.setp(markerline, marker=None)
  388. plt.gca().set_prop_cycle(None)
  389. plt.plot(xx, data_tot[0, 0][:, :2], alpha=0.5)
  390. # plt.plot(cur_data[0, 0][:, :2])
  391. plt.show()