plot_figures_part_B.py

import re
from pathlib import Path

import matplotlib
# matplotlib.use('TkAgg')  # optional: select the backend before pyplot is imported
import matplotlib.pyplot as plt
import munch
import numpy as np
import pandas as pd
import scipy
import scipy.interpolate
import yaml
from scipy import stats

from helpers import data_management as dm
from helpers.data import DataNormalizer
from helpers.fbplot import plot_df_combined
from helpers.nsp import fix_timestamps
from basics import BASE_PATH, BASE_PATH_OUT
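
# Analysis window around the response-start event, in seconds
# (converted below to sample indices once the rate-loop interval is known).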
plot_t = munch.munchify({'response_start_samples': {'t': [-1.5, 3.0]}})
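
# Panel layout and trial-selection predicates for the combined plot:
# top row = trials where the decoder matched the target (variance shown),
# bottom row = mismatch (error) trials.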
plot_order_comb = munch.munchify({
    'p': [
        {'r': 0, 'c': 0, 'col': (0.635, 0.078, 0.184, .2), 'ls': '-', 'lw': 2, 'show_var': True,
         'sel': lambda df: ((df.target == 'up') | (df.target == 'yes')) & (df.decision == 'up'),
         'desc': 'Target: up'},
        {'r': 0, 'c': 1, 'col': (0, 0.447, 0.741, .2), 'ls': '-', 'lw': 2, 'show_var': True,
         'sel': lambda df: ((df.target == 'down') | (df.target == 'no')) & (df.decision == 'down'),
         'desc': 'Target: down'},
        {'r': 1, 'c': 0, 'col': (0.635, 0.078, 0.184, .2), 'ls': ':', 'lw': 2, 'show_var': False,
         'sel': lambda df: (df.target == 'up') & (df.decision == 'down'),
         'desc': 'Target: up, decision: down'},
        {'r': 1, 'c': 1, 'col': (0, 0.447, 0.741, .2), 'ls': ':', 'lw': 2, 'show_var': False,
         'sel': lambda df: (df.target == 'down') & (df.decision == 'up'),
         'desc': 'Target: down, decision: up'},
    ],
    's': [
        {'dcol': 'response_start_samples', 'title': 'Response Period', 'norm': False}],
    'state': {'down': {'c': (0, 0.447, 0.741, .2), 'd': 'Low Frequency Target Tone'},
              'up': {'c': (0.635, 0.078, 0.184, .2), 'd': 'High Frequency Target Tone'},
              'no': {'c': (0, 0.447, 0.741, .2), 'd': 'Low Frequency Target Tone'},
              'yes': {'c': (0.635, 0.078, 0.184, .2), 'd': 'High Frequency Target Tone'}},
})

SAVENAME = 'Figure_3A_FBTrials'

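
# Parse the NSP event log into a DataFrame with one row per trial.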
def extract_trials(filename, offsets=None):
    if offsets is None:
        offsets = []
    with open(filename, 'r') as f:
        evs = f.read().splitlines()
    # fix event timestamps: the counter can roll over, so detect negative jumps
    # and shift everything after each rollover by the corresponding offset
    tpat = re.compile(r"^(\d+)(, .*)$")
    stage_pat = re.compile(r"(\d+), b'(feedback|question), (\w+), (\w+), (\w+), (\w+)'$")
    decpat = re.compile(r".*\s(\w+), Decoder decision is: (.*)'$")
    evs_time = []
    for ev in evs:
        m = tpat.match(ev)
        evs_time.append(np.int64(m.group(1)))
    evs_time = np.asarray(evs_time, dtype=np.int64)
    rollev, = np.where(np.diff(evs_time) < 0)
    for j, i in enumerate(rollev):
        evs_time[(i + 1):] += offsets[j]
    trials = []
    this_trial = {}
    for ev, ev_time in zip(evs, evs_time):
        m = stage_pat.match(ev)
        if m is not None:
            if m.group(5) == 'Block':
                continue
            if m.group(5) == 'baseline' and m.group(6) == 'start':
                # a new baseline marks the start of the next trial
                if len(this_trial) > 0:
                    trials.append(this_trial)
                this_trial = {'baseline_start': -1, 'baseline_stop': -1, 'stimulus_start': -1,
                              'stimulus_stop': -1, 'response_start': -1, 'response_stop': -1,
                              'target': m.group(4),
                              'response_start_samples': None, 'response_start_samples_wnan': None,
                              'response_start_samples_norm': None, 'response_start_samples_good': None}
            this_trial[f'{m.group(5)}_{m.group(6)}'] = ev_time
            continue
        m = decpat.match(ev)
        if m is not None:
            # map the yes/no paradigm onto up/down decisions
            if m.group(2) == 'yes':
                this_trial['decision'] = 'up'
            elif m.group(2) == 'no':
                this_trial['decision'] = 'down'
            else:
                this_trial['decision'] = m.group(2)
    # append the final trial, guarding against one without a decoder decision
    if this_trial.get('decision'):
        trials.append(this_trial)
    return pd.DataFrame(trials)
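

# Session to analyze: recording day and config-dump time; the data file
# timestamp defaults to the config time unless 'data_t' is given.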
s = {'day': '2019-11-21', 'cfg_t': '15_23_18'}
day_str = s['day']
cfg_t_str = s['cfg_t']
data_t_str = s.get('data_t', s['cfg_t'])

fn_cfgdump = BASE_PATH / 'KIAP_BCI_neurofeedback' / day_str / f'config_dump_{cfg_t_str}.yaml'
try:
    with open(fn_cfgdump) as stream:
        params = yaml.load(stream, Loader=yaml.Loader)
except Exception as e:
    print(e)
    raise  # params is required below; do not continue without it

fb_states = params.paradigms.feedback.states

# Convert the plot windows from seconds to sample indices of the rate stream.
for k in plot_t:
    plot_t[k].s = [int(plot_t[k].t[0] * 1000 / params.daq.spike_rates.loop_interval),
                   int(plot_t[k].t[1] * 1000 / params.daq.spike_rates.loop_interval)]
    plot_t[k].swin = np.arange(plot_t[k].s[0], plot_t[k].s[1] + 1)
    plot_t[k].twin = plot_t[k].swin / 1000.0 * params.daq.spike_rates.loop_interval
# print(params)
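
# Load the session data, repair timestamp rollovers, and extract trials
# from the matching event log.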
fn_sess = BASE_PATH / 'KIAP_BCI_neurofeedback' / day_str / f'data_{data_t_str}.bin'
fn_evs = BASE_PATH / 'KIAP_BCI_neurofeedback' / day_str / f'events_{data_t_str}.txt'
datav, ts, ch_rec_list = dm.get_session(fn_sess, params=params)
ts, offsets, _ = fix_timestamps(ts)
tv = ts / 3e4  # timestamps are 30 kHz ticks; convert to seconds
trs = extract_trials(fn_evs, offsets=offsets)

# Normalized firing rate used for feedback control.
dn = DataNormalizer(params)
fr = dn.calculate_norm_rate(datav)
plot_data = np.reshape(fr, (-1, 1))
labels = ('fr',)
sess_info = ('Channels used for control: ['
             + ' '.join(f"{ch.id}" for ch in params.daq.normalization.channels) + ']')
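
# For each signal column, cut per-trial rate snippets around response start
# and attach them to the trials table, then plot the combined figure and
# save it as PDF/EPS/SVG.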
for label, dv in zip(labels, np.hsplit(plot_data, plot_data.shape[1])):
    dv = np.reshape(dv, (-1))
    for ii, row in trs.iterrows():
        t = row['stimulus_start']
        t_off = np.where(ts >= t)[0][0] - 1
        t_stop = row['response_stop']
        if len(np.where(ts >= t_stop)[0]) > 0:
            t_stop_off = np.where(ts >= t_stop)[0][0] - 1
        t_rstart = row['response_start']
        if len(np.where(ts >= t_rstart)[0]) > 0:
            t_rstart_off = np.where(ts >= t_rstart)[0][0] - 1
        resp_start_idx = plot_t.response_start_samples.swin + t_rstart_off
        if np.all(resp_start_idx < len(dv)):
            trs.at[ii, 'response_start_samples'] = dv[resp_start_idx]
        else:
            # window extends past the end of the recording: pad with NaN
            trs.at[ii, 'response_start_samples'] = np.empty(resp_start_idx.shape)
            trs.at[ii, 'response_start_samples'][:] = np.nan
            good_vals = resp_start_idx < len(dv)
            trs.at[ii, 'response_start_samples'][good_vals] = dv[resp_start_idx[good_vals]]
        trs.at[ii, 'response_start_offset'] = (ts[t_rstart_off] - t_rstart) / 3e4
        # samples count as 'good' while they fall before the response stop
        trs.at[ii, 'response_start_samples_good'] = resp_start_idx < t_stop_off - 1
        # trs.at[ii, 'response_start_samples_good'] = np.full(plot_t.response_start_samples.swin.shape, True)
        # copy, so masking with NaN does not mutate the original samples array
        trs.at[ii, 'response_start_samples_wnan'] = trs.at[ii, 'response_start_samples'].copy()
        trs.at[ii, 'response_start_samples_wnan'][~trs.at[ii, 'response_start_samples_good']] = np.nan
        fp = trs.at[ii, 'response_start_samples'][trs.at[ii, 'response_start_samples_good']]
        xp = plot_t.response_start_samples.twin[trs.at[ii, 'response_start_samples_good']]
        # resample the good samples onto the full time window (linear interpolation)
        x = np.linspace(xp[0], xp[-1], len(plot_t.response_start_samples.twin))
        trs.at[ii, 'response_start_samples_norm'] = np.interp(x, xp, fp)
    if np.all(np.vstack(trs['response_start_samples']) == 0):
        continue  # skip signals that carry no data
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    fig = plt.figure(35, figsize=(16, 9))
    fig.clf()
    plot_df_combined(trs, plot_order_comb, plot_t, fb_states, fig=fig,
                     options=munch.Munch({'show_thresholds': label == 'fr',
                                          'show_end': 0, 'show_median': False}))
    fig.suptitle(f"{day_str} {cfg_t_str} ch:{label}\n{sess_info} n={len(trs)}")
    fig.show()
    savename = BASE_PATH_OUT / SAVENAME
    fig.savefig(savename.with_suffix('.pdf'))
    fig.savefig(savename.with_suffix('.eps'))
    fig.savefig(savename.with_suffix('.svg'))
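
# Summarize decoder performance per target.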
n_correct_up = ((trs.target == trs.decision) & (trs.target == 'up')).sum()
n_correct_down = ((trs.target == trs.decision) & (trs.target == 'down')).sum()
n_error_up = ((trs.decision == 'down') & (trs.target == 'up')).sum()
n_error_down = ((trs.decision == 'up') & (trs.target == 'down')).sum()
n_timeout_up = ((trs.decision == 'unclassified') & (trs.target == 'up')).sum()
n_timeout_down = ((trs.decision == 'unclassified') & (trs.target == 'down')).sum()
n_total = len(trs)
print(f"Total trials: {n_total}\n"
      f"Target up, decision up: {n_correct_up} trials.\n"
      f"Target up, decision down: {n_error_up} trials.\n"
      f"Target down, decision up: {n_error_down} trials.\n"
      f"Target down, decision down: {n_correct_down} trials.\n"
      f"Correct: {(n_correct_up + n_correct_down) / n_total}. "
      f"Correct up: {n_correct_up / (n_correct_up + n_error_up)}\n"
      f"Correct down: {n_correct_down / (n_correct_down + n_error_down)}\n"
      f"Time-out up target: {n_timeout_up}; Time-out down target: {n_timeout_down}")

# Export as CSV
trs_export = pd.DataFrame()
trs_export['trial_data'] = [r.response_start_samples[r.response_start_samples_good]
                            for _, r in trs.iterrows()]
trs_export['trial_times'] = [plot_t['response_start_samples']['twin'][r.response_start_samples_good]
                             for _, r in trs.iterrows()]
trs_export['target'] = trs['target']
trs_export['decision'] = trs['decision']
trs_export.to_csv(savename.with_suffix('.csv'))