data_management.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. '''
  2. description: script to read header and data from data.bin
  3. author: Ioannis Vlachos
  4. date: 02.11.18
  5. Copyright (c) 2018 Ioannis Vlachos.
  6. All rights reserved.
  7. HEADER OF BINARY FILE
  8. ---------------------
  9. laptop timestamp np.datetime64 bytes: 8
  10. NSP timestamp np.int64 bytes: 8
  11. number of bytes np.int64 bytes: 8
  12. number of samples np.int64 bytes: 8
  13. number of channels np.int64 bytes: 8
  14. '''
  15. import csv
  16. import datetime as dt
  17. import glob
  18. import os
  19. import sys
  20. import matplotlib.pyplot as plt
  21. import numpy as np
  22. from numpy import datetime64 as dt64
  23. from tabulate import tabulate
  24. # import munch
  25. from . import kaux as aux
  26. from .kaux import log
  27. # params = aux.load_config()
  28. # def get_event_name(filename, events_file_names)
def get_raw(verbose=0, n_triggers=2, exploration=False, feedback=False, triggers_all=False, fids=[], trigger_pos='start'):
    '''read raw data from one or more binary files
    trigger_pos: ['start, 'stop']

    Lists all *.bin files under the configured data path, lets the user pick
    one or more sessions (or takes `fids` directly), reads each session with
    get_session() and extracts trigger indices from the companion events file.

    verbose: passed through to get_session().
    n_triggers: number of trigger classes for the default get_triggers() path.
    exploration / feedback / triggers_all: select which get_triggers_* variant
        is used and the column count of triggers_tot.
    fids: list of file indices to read; [] prompts interactively.
        NOTE(review): mutable default argument - harmless here because it is
        only compared against [] and rebound, never mutated in place.
    trigger_pos: which event edge ('start' or 'stop') to align triggers to.

    Returns (data_tot, time_stamps_tot, triggers_tot, ch_rec_list_tot,
    file_names); the *_tot values are object arrays with one row per
    selected session, file_names is the list of selected .bin paths.
    '''
    params = aux.load_config()
    n_channels = params.daq.n_channels  # NOTE(review): unused in this function
    file_names = []
    # events_file_names = []
    print('Available binary data files:\n')
    # enumerate every .bin under the configured data path (recursive), in
    # sorted full-path order, so the printed index can be typed at the prompt
    for ii, file_name in enumerate(sorted(glob.iglob(params.file_handling.data_path + '**/*.bin', recursive=True))):
        print(f'{ii}, {file_name}, {os.path.getsize(file_name)//1000}K')
        file_names.append(file_name)
    # for ii, file_name in enumerate(sorted(glob.iglob(params.file_handling.data_path + '**/*.txt', recursive=True))):
    #     events_file_names.append(file_name)
    if fids ==[]:
        fids = [int(x) for x in input('\nSelect file ids (separated by space): ').split()] # select file ids
        if fids == []:
            # empty input: default to the last (sorted-last) file
            fids = [len(file_names)-1]
    print(f'Selected: {fids}')
    # object arrays: one row per selected session
    data_tot = np.empty((len(fids), 1), dtype=object)
    time_stamps_tot = np.empty((len(fids), 1), dtype=object)
    if triggers_all:
        triggers_tot = np.empty((len(fids), 6), dtype=object)  # six event edges, see get_triggers_all()
    elif exploration:
        triggers_tot = np.empty((len(fids), 1), dtype=object)  # nested per-state list kept in one cell
    elif feedback:
        triggers_tot = np.empty((len(fids), 2), dtype=object) # up and down
    else:
        triggers_tot = np.empty((len(fids), n_triggers), dtype=object)
    ch_rec_list_tot = np.empty((len(fids), 1), dtype=object)
    for ii, fid in enumerate(fids): # go through all sessions
        data, time_stamps, ch_rec_list = get_session(file_names[fid], verbose=verbose)
        data = np.delete(data, params.classifier.exclude_data_channels, axis=1)
        # triggers = get_triggers(os.path.dirname(file_names[fid])+'/events.txt', time_stamps, n_triggers)
        # companion files live next to the .bin and share its name stem
        event_file_name = file_names[fid].replace('data_','events_').replace('.bin','.txt')
        info_file_name = file_names[fid].replace('data_','info_').replace('.bin','.log')
        if triggers_all:
            triggers = get_triggers_all(event_file_name, time_stamps, trigger_pos)
        elif exploration:
            triggers = get_triggers_exploration(event_file_name, time_stamps)
        elif feedback:
            triggers = get_triggers_feedback(event_file_name, time_stamps, trigger_pos, n_triggers)
        else:
            triggers = get_triggers(event_file_name, time_stamps, trigger_pos, n_triggers)
        # yes_mask, no_mask = get_accuracy(info_file_name,n_triggers)
        # triggers[0] = np.where(yes_mask,triggers[0],'')
        # triggers[1] = np.where(no_mask,triggers[1],'')
        data_tot[ii, 0] = data
        time_stamps_tot[ii, 0] = time_stamps
        if exploration:
            # exploration returns a nested [per-state inds, baseline inds]
            triggers_tot[ii,0] = triggers
        else:
            triggers_tot[ii, :] = triggers
        ch_rec_list_tot[ii, 0] = ch_rec_list
        # triggers_tot[ii,1] = triggers[1]
        print(f'\nRead binary neural data from file {file_names[fid]}')
        print(f'Read trigger info events from file {event_file_name}')
        print(f'Session {ii}, data shape: {data.shape}')
    # NOTE(review): paradigm.yaml is read unconditionally although `states`
    # is only used in the exploration branch below - confirm the file always
    # exists in the other modes (read_config also needs the munch package,
    # whose import is commented out at the top of this module).
    config = read_config('paradigm.yaml')
    states = config.exploration.states
    # per-session summary of trigger counts per class/state
    for jj, _ in enumerate(fids):
        print()
        if triggers_all:
            for ii in range(6):
                print(f'Session {fids[jj]}, Class {ii}, trigger #: {triggers_tot[jj,ii].shape}')
        elif exploration:
            for ii in range(len(triggers_tot[jj,0][0])):
                print(f'Session {fids[jj]}, State {states[ii]}, trigger #: {triggers_tot[jj,0][0][ii].shape}')
        else:
            for ii in range(n_triggers):
                print(f'Session {fids[jj]}, Class {ii}, trigger #: {triggers_tot[jj,ii].shape}')
    file_names = [file_names[ii] for ii in fids]  # keep only the selected paths
    return data_tot, time_stamps_tot, triggers_tot, ch_rec_list_tot, file_names
  101. def get_session(file_name, verbose=0, t_lim_start=None, t_lim_end=None, params=None):
  102. ii = 0
  103. # data = np.empty((0, params.daq.n_channels-len(params.daq.exclude_channels)))
  104. data = np.empty((0, params.daq.n_channels_max))
  105. log.info(f'Data shape: {data.shape}')
  106. time_stamps = []
  107. date_times = []
  108. time_stamps_rcv = []
  109. with open(file_name, 'rb') as fh:
  110. log.info(f'\nReading binary file {file_name}...\n')
  111. while True:
  112. tmp = np.frombuffer(fh.read(8), dtype='datetime64[us]') # laptop timestamp
  113. if tmp.size == 0:
  114. break
  115. # return data, np.array(time_stamps, dtype=np.uint), ch_rec_list
  116. else:
  117. t_now1 = tmp
  118. if (t_lim_end is not None) and (t_lim_end < t_now1[0]):
  119. break
  120. t_now2 = np.frombuffer(fh.read(8), dtype=np.int64)[0] # NSP timestamp
  121. n_bytes = int.from_bytes(fh.read(8), byteorder='little') # number of bytes
  122. n_samples = int.from_bytes(fh.read(8), byteorder='little') # number of samples
  123. n_ch = int.from_bytes(fh.read(8), byteorder='little') # number of channels
  124. ch_rec_list_len = int.from_bytes(fh.read(2), byteorder='little')
  125. ch_rec_list = np.frombuffer(fh.read(ch_rec_list_len), dtype=np.uint16) # detailed channel list
  126. # log.info(f'recorded channels: {ch_rec_list}')
  127. d = fh.read(n_bytes)
  128. d2 = np.frombuffer(d, dtype=np.float32)
  129. d3 = d2.reshape(d2.size // n_ch, n_ch) # data, shape : (n_samples, n_ch)
  130. log.info(f'data shape: {d3.shape}')
  131. if data.size == 0 and data.shape[1] != d3.shape[1]:
  132. log.warning(f'Shape mismatch. {d3.shape} vs {data.shape[1]}. Using data shape from file: {d3.shape}')
  133. data = np.empty((0, d3.shape[1]))
  134. # fct = params.daq.spike_rates.bin_width * 30000 # factor to get correct starting time in ticks
  135. # time_stamps.extend(np.arange(t_now2-d3.shape[0]*fct + 1, t_now2+1)) # check if +1 index is correct
  136. # time_stamps.extend(np.arange(t_now2-d3.shape[0]*fct + 1, t_now2+1, 3000)) # check if +1 index is correct
  137. ts = np.frombuffer(fh.read(8 * d3.shape[0]), dtype=np.uint64)
  138. if (t_lim_start is None) or (t_lim_start <= t_now1[0] - np.timedelta64(int((t_now2 - ts[0]) / 3e4), 's')):
  139. data = np.concatenate((data, d3))
  140. time_stamps.extend(ts)
  141. date_times.append(t_now1[0])
  142. time_stamps_rcv.append(t_now2)
  143. else:
  144. log.info(f"Skipped set {t_now1[0]} | {ts[0]} n_samples: {n_samples}")
  145. # if ts.size[0] > 0:
  146. log.info(f'ts size: {ts.size}')
  147. # log.info(time_stamps)
  148. # if verbose:
  149. # print(ii, t_now1[0], t_now2, n_bytes, n_samples, n_ch, d3[10:20, 0], np.any(d3))
  150. ii += 1
  151. return data, np.array(time_stamps, dtype=np.uint), ch_rec_list
  152. # return data, np.array(time_stamps, dtype=np.uint), ch_rec_list, np.array(date_times, dtype='datetime64[us]'), \
  153. # np.array(time_stamps_rcv, dtype=np.uint64)
  154. def get_accuracy_question(fname, n_triggers=2):
  155. with open(fname, 'r') as fh:
  156. events = fh.read().splitlines()
  157. cl1 = []
  158. cl2 = []
  159. for ev in events:
  160. if 'Yes Question' in ev:
  161. if 'Decoder decision: yes' in ev:
  162. cl1.append(True)
  163. else:
  164. cl1.append(False)
  165. elif 'No Question' in ev:
  166. if 'Decoder decision: no' in ev:
  167. cl2.append(True)
  168. else:
  169. cl2.append(False)
  170. return [ind1, ind2]
  171. def get_triggers(fname, time_stamps, trigger_pos, n_triggers=2):
  172. with open(fname, 'r') as fh:
  173. # events = fh.readlines()
  174. events = fh.read().splitlines()
  175. cl1 = []
  176. cl2 = []
  177. cl3 = []
  178. tt1 = []
  179. tt2 = []
  180. tt3 = []
  181. for ev in events:
  182. if 'response' in ev and 'yes' in ev and trigger_pos in ev:
  183. cl1.append(int(ev.split(',')[0]))
  184. elif 'response' in ev and 'no' in ev and trigger_pos in ev:
  185. cl2.append(int(ev.split(',')[0]))
  186. elif 'baseline' in ev and 'start'in ev:
  187. cl3.append(int(ev.split(',')[0]))
  188. if n_triggers == 2:
  189. # cl2.extend(cl3) # add baseline to class 2
  190. cl3 = []
  191. for ev in events:
  192. if 'response' in ev and trigger_pos in ev:
  193. print(f'\033[91m{ev}\033[0m')
  194. else:
  195. print(ev)
  196. for ii in cl1:
  197. tt1.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  198. for ii in cl2:
  199. tt2.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  200. for ii in cl3:
  201. tt3.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  202. ind1 = np.where(np.in1d(time_stamps, tt1))[0][np.newaxis, :]
  203. ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
  204. ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
  205. print()
  206. print(cl1, tt1, ind1)
  207. print(cl2, tt2, ind2)
  208. print(cl3, tt3, ind3)
  209. if n_triggers == 1:
  210. res = [ind1]
  211. elif n_triggers == 2:
  212. res = [ind1, ind2]
  213. # res = [ind1, np.hstack((ind2, ind3))] # put no and baseline together
  214. elif n_triggers == 3:
  215. res = [ind1, ind2, ind3]
  216. return res
  217. def get_triggers_feedback(fname, time_stamps, trigger_pos, n_triggers=2):
  218. with open(fname, 'r') as fh:
  219. # events = fh.readlines()
  220. events = fh.read().splitlines()
  221. cl1 = []
  222. cl2 = []
  223. cl3 = []
  224. tt1 = []
  225. tt2 = []
  226. tt3 = []
  227. for ev in events:
  228. if 'response' in ev and 'down' in ev and trigger_pos in ev:
  229. cl2.append(int(ev.split(',')[0]))
  230. elif 'response' in ev and 'up' in ev and trigger_pos in ev:
  231. cl1.append(int(ev.split(',')[0]))
  232. elif 'baseline' in ev and 'start'in ev:
  233. cl3.append(int(ev.split(',')[0]))
  234. if n_triggers == 2:
  235. # cl2.extend(cl3) # add baseline to class 2
  236. cl3 = []
  237. for ev in events:
  238. if 'response' in ev and trigger_pos in ev:
  239. print(f'\033[91m{ev}\033[0m')
  240. else:
  241. print(ev)
  242. for ii in cl1:
  243. tt1.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  244. for ii in cl2:
  245. tt2.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  246. for ii in cl3:
  247. tt3.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  248. ind1 = np.where(np.in1d(time_stamps, tt1))[0][np.newaxis, :]
  249. ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
  250. ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
  251. print()
  252. print(cl1, tt1, ind1)
  253. print(cl2, tt2, ind2)
  254. print(cl3, tt3, ind3)
  255. if n_triggers == 1:
  256. res = [ind1]
  257. elif n_triggers == 2:
  258. res = [ind1, ind2]
  259. # res = [ind1, np.hstack((ind2, ind3))] # put no and baseline together
  260. elif n_triggers == 3:
  261. res = [ind1, ind2, ind3]
  262. return res
  263. def get_triggers_exploration(fname, time_stamps):
  264. with open(fname, 'r') as fh:
  265. # events = fh.readlines()
  266. events = fh.read().splitlines()
  267. config = read_config('paradigm.yaml')
  268. states = config.exploration.states
  269. cl1 = [[] for x in range(len(states))]
  270. cl2 = []
  271. # cl3 = []
  272. tt1 = [[] for x in range(len(states))]
  273. tt2 = []
  274. # tt3 = []
  275. for ev in events:
  276. for ii,state in enumerate(states):
  277. if 'response' in ev and state in ev and 'start' in ev:
  278. cl1[ii].append(int(ev.split(',')[0]))
  279. if 'baseline' in ev and 'start'in ev:
  280. cl2.append(int(ev.split(',')[0]))
  281. for ev in events:
  282. if 'response' in ev and 'start' in ev:
  283. print(f'\033[91m{ev}\033[0m')
  284. else:
  285. print(ev)
  286. for ii in range(len(cl1)):
  287. for jj in cl1[ii]:
  288. tt1[ii].append(time_stamps.flat[np.abs(time_stamps - jj).argmin()])
  289. for ii in cl2:
  290. tt2.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
  291. ind1 = [[] for x in range(len(states))]
  292. for ii in range(len(tt1)):
  293. ind1[ii] = np.where(np.in1d(time_stamps, tt1[ii]))[0][np.newaxis, :]
  294. ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
  295. # ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
  296. print()
  297. print(cl1, tt1, ind1)
  298. print(cl2, tt2, ind2)
  299. # print(cl3, tt3, ind3)
  300. # xx
  301. res = [ind1,ind2]
  302. return res
  303. def get_triggers_all(fname, time_stamps, trigger_pos, n_triggers=2):
  304. with open(fname, 'r') as fh:
  305. # events = fh.readlines()
  306. events = fh.read().splitlines()
  307. cl1 = []
  308. cl2 = []
  309. cl3 = []
  310. cl4 = []
  311. cl5 = []
  312. cl6 = []
  313. tt1 = []
  314. tt2 = []
  315. tt3 = []
  316. tt4 = []
  317. tt5 = []
  318. tt6 = []
  319. for ev in events:
  320. if 'baseline' in ev and 'start'in ev:
  321. cl1.append(int(ev.split(',')[0]))
  322. elif 'baseline' in ev and 'stop'in ev:
  323. cl2.append(int(ev.split(',')[0]))
  324. elif 'stimulus' in ev and 'start'in ev:
  325. cl3.append(int(ev.split(',')[0]))
  326. elif 'stimulus' in ev and 'stop'in ev:
  327. cl4.append(int(ev.split(',')[0]))
  328. elif 'response' in ev and 'start' in ev:
  329. cl5.append(int(ev.split(',')[0]))
  330. elif 'response' in ev and 'stop' in ev:
  331. cl6.append(int(ev.split(',')[0]))
  332. for ev in events:
  333. if 'response' in ev and trigger_pos in ev:
  334. print(f'\033[91m{ev}\033[0m')
  335. else:
  336. print(ev)
  337. for ii in cl1:
  338. tt1.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  339. for ii in cl2:
  340. tt2.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  341. for ii in cl3:
  342. tt3.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  343. for ii in cl4:
  344. tt4.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  345. for ii in cl5:
  346. tt5.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  347. for ii in cl6:
  348. tt6.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
  349. ind1 = np.where(np.in1d(time_stamps, tt1))[0][np.newaxis, :]
  350. ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
  351. ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
  352. ind4 = np.where(np.in1d(time_stamps, tt4))[0][np.newaxis, :]
  353. ind5 = np.where(np.in1d(time_stamps, tt5))[0][np.newaxis, :]
  354. ind6 = np.where(np.in1d(time_stamps, tt6))[0][np.newaxis, :]
  355. print('\nTriggers and timestamps')
  356. print(cl1, tt1, ind1)
  357. print(cl2, tt2, ind2)
  358. print(cl3, tt3, ind3)
  359. print(cl4, tt4, ind4)
  360. print(cl5, tt5, ind5)
  361. print(cl6, tt6, ind6)
  362. # if n_triggers == 1:
  363. # res = [ind1]
  364. # elif n_triggers == 2:
  365. # res = [ind1, ind2]
  366. # # res = [ind1, np.hstack((ind2, ind3))] # put no and baseline together
  367. # elif n_triggers == 3:
  368. # res = [ind1, ind2, ind3]
  369. res = [ind1, ind2, ind3, ind4, ind5, ind6]
  370. return res
def read_config(file_name):
    '''Load a YAML configuration file (e.g. paradigm.yaml) and return it as
    an attribute-accessible object via munch.fromYAML.

    NOTE(review): `import munch` is commented out at the top of this module,
    so calling this function currently raises NameError - confirm and
    restore the import.
    NOTE(review): the try/except only re-raises and adds nothing; kept
    byte-identical in this documentation-only pass.
    '''
    try:
        with open(file_name) as stream:
            config = munch.fromYAML(stream)
            return config
    except Exception as e:
        raise e
  378. if __name__ == '__main__':
  379. print("\nto read binary data use: 'data = get_raw(verbose=1)'")
  380. print("\nto read log file use: 'log = read_log(date)'")
  381. if aux.args.speller == 'exploration':
  382. exploration = True
  383. else:
  384. exploration = False
  385. if aux.args.speller == 'feedback':
  386. feedback = True
  387. else:
  388. feedback = False
  389. col = ['b', 'r', 'g']
  390. data_tot, tt, triggers_tot, ch_rec_list, file_names = get_raw(n_triggers=params.classifier.n_classes, exploration=exploration, feedback=feedback)
  391. if not exploration:
  392. plt.figure(1)
  393. plt.clf()
  394. xx = np.arange(data_tot[0, 0].shape[0]) * 0.05
  395. for cl_id in range(triggers_tot.shape[1]):
  396. markerline, stemlines, baseline = plt.stem(triggers_tot[0, cl_id][0]*0.05, triggers_tot[0, cl_id][0] * 0 + 50, '--', basefmt=' ')
  397. plt.setp(stemlines, alpha=0.8, color=col[cl_id], lw=1)
  398. plt.setp(markerline, marker=None)
  399. plt.gca().set_prop_cycle(None)
  400. plt.plot(xx, data_tot[0, 0][:, :2], alpha=0.5)
  401. # plt.plot(cur_data[0, 0][:, :2])
  402. plt.show()