plot_figures_part_A.py

import matplotlib
matplotlib.use('TkAgg')
from datetime import datetime as dt
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import re
import helpers.sessions as hs
from pathlib import Path
import yaml
import munch
from basics import BASE_PATH, BASE_PATH_OUT, IMPLANT_DATE, FEEDBACK_CHANGE_DATE, ARRAY_MAPS
import logging
from logging.handlers import TimedRotatingFileHandler
from helpers.tsdumper import TSDumper
from matplotlib.colors import ListedColormap
import matplotlib as mpl
import itertools
import scipy.stats as stats
from scipy.interpolate import UnivariateSpline

logger = logging.getLogger("KIAP")
logger.handlers.clear()
logger.setLevel(logging.DEBUG)

ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)

fh = TimedRotatingFileHandler(BASE_PATH_OUT / 'kiap_figures.log', when='D', backupCount=5)
fh.setFormatter(formatter)
fh.setLevel(logging.INFO)
logger.addHandler(fh)

def string_count_eq(str1, str2):
    """Counts how many characters two strings share at the beginning."""
    n_match = 0
    for (c1, c2) in zip(str1, str2):
        if c1 == c2:
            n_match += 1
        else:
            break
    return n_match


def string_dist(str1, str2):
    """Helper function to calculate string distance.

    In our case, we want to know how many characters were added and removed between two strings:
    the one a speller session started with, and the one it ended with.
    If the two strings share n characters at the beginning, after which the first one has another m
    characters and the second one another k characters, then the number of character changes is (m + k).
    """
    return len(str1) + len(str2) - 2 * string_count_eq(str1, str2)

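# Worked example for string_dist (strings invented for illustration): "HALLO W" and "HALLO I"
# share their first 6 characters ("HALLO "), so
#   string_count_eq("HALLO W", "HALLO I") == 6
#   string_dist("HALLO W", "HALLO I") == 7 + 7 - 2 * 6 == 2
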
T_RE = re.compile(r"^(\d+-\d+-\d+T\d+:\d+:\d+\.\d+)\s+color.*\('(.*)', '(.*)'\)$")
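# T_RE is intended to match speller ('color') log lines of roughly the following shape
# (timestamp and phrases below are made up; only the structure matters):
#   2019-08-12T14:03:21.123456 color - ... - ('HALLO', 'HALLO W')
# group(1) is the timestamp, groups (2) and (3) are the two quoted phrases.
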
def extract_session_data(session_log_file_name):
    """
    given a session log file name for a colour speller session, of the format info_XX_XX_XX.log,
    this function returns a dictionary with the following keys:

    phrase_start: speller started with this phrase
    phrase: speller ended with this phrase
    n: number of characters in (phrase - phrase_start)
    ch_per_min: characters spelled per minute
    """
    with open(session_log_file_name, 'r') as f:
        evs = f.read().splitlines()
    d = {'n': 0, 'ch_per_min': 0, 'phrase_start': '', 'phrase': ''}
    m01 = T_RE.match(evs[0])
    m02 = T_RE.match(evs[-1])
    start_dt = dt.strptime(m01.group(1), '%Y-%m-%dT%H:%M:%S.%f')
    end_dt = dt.strptime(m02.group(1), '%Y-%m-%dT%H:%M:%S.%f')
    duration_min = (end_dt - start_dt).total_seconds() / 60.0
    d['phrase_start'] = m01.group(2)
    d['phrase'] = m02.group(2)
    d['n'] = string_dist(d['phrase_start'], d['phrase'])
    if duration_min > 0:
        d['ch_per_min'] = d['n'] / duration_min
    return d

def precompile_sessions(pth=BASE_PATH, start=0, n=None, use_cache=True):
    """
    Finds recorded sessions at pth and loads information into a pandas DataFrame
    with columns:

    mode: mode of BCI session (feedback, color, exploration, question)
    cfg: relative path to the configuration file
    events: relative path to the event log file
    data: relative path to the binary data file
    log: relative path to the debug log file
    duration_s: duration of session in seconds (last data timestamp - first data timestamp)
    duration_min: duration of session in minutes
    start_dt: start of session as datetime, from time in config file
    end_dt: end of session as datetime
    d_since_impl: days since implantation

    The session data will be saved in a cache file and reloaded from there by default,
    as it takes some time to read each individual data file.
    This reading is necessary to get the most accurate estimate for the length of a session.
    Since on some occasions the saving of data files was not ended automatically at the end
    of a session, the events file will also be read to determine the length.

    Params:
        pth: base path for data folders
        start: skip this many sessions at start
        n: read this many sessions
        use_cache: try to read cache file if True
    """
    cache_file_name = Path(pth, 'session_cache.pkl')
    if use_cache and cache_file_name.exists():
        logger.info(f"Using cache file {cache_file_name}")
        cfgs = pd.read_pickle(cache_file_name)
        return cfgs
    (cfgs, _) = hs.get_sessions(pth, start=start, n=n)
    n_total = len(cfgs)
    t_diffs = []
    start_dts = []
    for i, s in cfgs.iterrows():
        logger.debug(f"loading {i + 1} of {n_total}: {s['cfg']}")
        try:
            (ts, _, _, evts) = hs.get_session_data(pth, s)
            t_diff = evts[-1] - evts[0]
            logger.info(f"Loaded session of {t_diff} s.")
            t_diffs.append(t_diff)
        except FileNotFoundError:
            logger.warning(f"Data file not found for {pth}")
            t_diffs.append(0)
        cfgstr = s['cfg']
        cfgstr = cfgstr.replace('/', '_')
        cfgstr = cfgstr.replace('\\', '_')
        start_dt = dt.strptime(cfgstr, '%Y-%m-%d_config_dump_%H_%M_%S.yaml')
        start_dts.append(start_dt)
    cfgs['duration_s'] = t_diffs
    cfgs['duration_min'] = cfgs['duration_s'] / 60.0
    cfgs['start_dt'] = start_dts
    cfgs['end_dt'] = cfgs['start_dt'] + pd.to_timedelta(cfgs['duration_s'], unit='seconds')
    cfgs['d_since_impl'] = cfgs['start_dt'].apply(lambda x: (x.to_pydatetime() - IMPLANT_DATE).days)
    cfgs.to_pickle(cache_file_name)
    return cfgs

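# Illustration of the config-path convention assumed by the strptime call above
# (the path is invented; only the naming pattern matters):
#   'cfg' = "2019-08-12/config_dump_14_03_21.yaml"
#   -> after replacing path separators with "_": "2019-08-12_config_dump_14_03_21.yaml"
#   -> parsed with '%Y-%m-%d_config_dump_%H_%M_%S.yaml' into start_dt 2019-08-12 14:03:21
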
def add_session_info(d):
    """
    Compute particular bits of information given data sessions.

    Params:
        d: DataFrame containing session information, as generated by precompile_sessions()

    Adds to DataFrame:
        cum_dur: cumulative session duration per day
        ecol: edge color for plots
        col: color for plots based on session mode

    Returns: DataFrame
    """
    d.set_index(['d_since_impl'], inplace=True)
    d["cum_dur"] = d.groupby(level='d_since_impl')['duration_min'].cumsum() - d['duration_min']
    d.reset_index(inplace=True)
    # add colour information to session entries
    d['ecol'] = [(1, 1, 1, 1)] * len(d)
    d['col'] = [(0.1, 0.1, 0.1)] * len(d)
    idx = d['mode'] == 'feedback'
    d.loc[idx, 'col'] = pd.Series([(0, 0.447, 0.741)] * idx.sum(), index=d.loc[idx].index)
    idx = d['mode'] == 'question'
    d.loc[idx, 'col'] = pd.Series([(0.494, 0.184, 0.556)] * idx.sum(), index=d.loc[idx].index)
    idx = d['mode'] == 'color'
    d.loc[idx, 'col'] = pd.Series([(0.85, 0.325, 0.098)] * idx.sum(), index=d.loc[idx].index)
    idx = d['mode'] == 'exploration'
    d.loc[idx, 'col'] = pd.Series([(0.301, 0.745, 0.933)] * idx.sum(), index=d.loc[idx].index)
    return d

def add_speller_data(d, pth=BASE_PATH):
    """
    Reads speller results for every speller session and adds that information to
    the dataframe.

    Params:
        d: DataFrame containing session information, as generated by precompile_sessions()

    Adds to DataFrame:
        'n': length of communication for speller block
        'ch_per_min': number of characters spelled per minute
        'phrase_start': phrase at the beginning of speller block
        'phrase': phrase at the end of speller block
        'cum_n': cumulative length of phrases per day

    Returns: DataFrame
    """
    d['n'] = pd.array([None] * len(d), dtype='Int32')
    d['ch_per_min'] = None
    d['phrase_start'] = None
    d['phrase'] = None
    for ix, rw in d.loc[d['mode'] == 'color'].iterrows():
        sp_dict = extract_session_data(Path(pth, rw['log']))
        spr = pd.Series(sp_dict, name=ix)
        d.loc[ix, spr.index] = spr
    d.set_index(['d_since_impl'], inplace=True)
    ix = d['mode'] == 'color'
    d.loc[ix, "cum_n"] = d.loc[ix].groupby(level='d_since_impl')['n'].cumsum() - d.loc[ix, 'n']
    d.reset_index(inplace=True)
    return d

def extract_trials_old(filename):
    """
    filename - path to an info_*.log file

    'old': log pattern before 20 July 2019 ('up' / 'down' trials, 'yes' response for correct,
    'unclassified' for timeout)
    """
    with open(filename, 'r') as f:
        evs = f.read().splitlines()
    # fix event timestamps
    decpat = re.compile(r".*\sfeedback - Decoder decision: (\w*) - \('feedback', '(\w+)'\)$")
    imap = {'up': 'down', 'down': 'up'}
    conditions = []
    responses = []
    for ev in evs:
        m = decpat.match(ev)
        if m is not None:
            condition = m.group(2)
            response = m.group(1)
            if condition == 'baseline':
                continue
            if response == 'yes':
                response = condition
            elif response == 'unclassified':
                response = imap.get(condition, 'unclassified')
            conditions.append(condition)
            responses.append(response)
    all_condition = pd.Categorical(conditions, categories=['up', 'down'])
    all_response = pd.Categorical(responses, categories=['up', 'down', 'unclassified'])
    # print(evs)
    ct = pd.crosstab(all_response, all_condition, dropna=False, colnames=['condition'], rownames=['response'])
    return ct

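# The decoder pattern above is expected to match feedback log lines of roughly this shape
# (the example line is invented for illustration):
#   2019-08-12T14:05:00.123456 feedback - Decoder decision: yes - ('feedback', 'up')
# where group(1) is the decoded response ('yes') and group(2) is the trial condition ('up').
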
def extract_trials_new(filename):
    """
    filename - path to an info_*.log file

    'new': log pattern on and after 20 July 2019 ('up' / 'down' trials, 'yes' / 'no' / 'unclassified' response)
    """
    with open(filename, 'r') as f:
        evs = f.read().splitlines()
    # fix event timestamps
    decpat = re.compile(r".*\sfeedback - Decoder decision: (\w*) - \('feedback', '(\w+)'\)$")
    conditions = []
    responses = []
    for ev in evs:
        m = decpat.match(ev)
        if m is not None:
            condition = m.group(2)
            response = m.group(1)
            if response == 'yes':
                response = 'up'
            elif response == 'no':
                response = 'down'
            conditions.append(condition)
            responses.append(response)
    all_condition = pd.Categorical(conditions, categories=['up', 'down'])
    all_response = pd.Categorical(responses, categories=['up', 'down', 'unclassified'])
    # dft = pd.DataFrame({'condition': conditions, 'response': responses})
    # ct = pd.crosstab(dft.condition, dft.response)
    ct = pd.crosstab(all_response, all_condition, dropna=False, colnames=['condition'], rownames=['response'])
    return ct

def extract_trials_q(filename):
    """
    filename - path to an info_*.log file

    log pattern: 2019-10-02T21:20:05.053678 question - Decoder decision: no - ('No question', '002_11038.wav')
    """
    with open(filename, 'r') as f:
        evs = f.read().splitlines()
    # fix event timestamps
    decpat = re.compile(r".*\squestion - Decoder decision: (\w*) - \('(Yes|No) question', '(.+)'\)$")
    conditions = []
    responses = []
    for ev in evs:
        m = decpat.match(ev)
        if m is not None:
            condition = m.group(2)
            response = m.group(1)
            if response == 'yes':
                response = 'up'
            elif response == 'no':
                response = 'down'
            if condition == 'Yes':
                condition = 'up'
            elif condition == 'No':
                condition = 'down'
            conditions.append(condition)
            responses.append(response)
    all_condition = pd.Categorical(conditions, categories=['up', 'down'])
    all_response = pd.Categorical(responses, categories=['up', 'down', 'unclassified'])
    # dft = pd.DataFrame({'condition': conditions, 'response': responses})
    # ct = pd.crosstab(dft.condition, dft.response)
    ct = pd.crosstab(all_response, all_condition, dropna=False, colnames=['condition'], rownames=['response'])
    return ct

def add_feedback_info(d, pth=BASE_PATH):
    """
    For all feedback sessions, read log file and save contingency table.
    """
    if 'ct' not in d.columns:
        d['ct'] = [None for _ in range(len(d))]
    for ix, rw in d.iterrows():
        if rw['mode'] != 'feedback':
            continue
        if FEEDBACK_CHANGE_DATE > rw.start_dt:
            ct = extract_trials_old(pth / rw.log)
        else:
            ct = extract_trials_new(pth / rw.log)
        cond_sums = ct[ct.index != 'unclassified'].sum()
        tpr = ct.loc['up', 'up'] / cond_sums['up']
        fpr = ct.loc['up', 'down'] / cond_sums['down']
        acc = (ct.loc['up', 'up'] + ct.loc['down', 'down']) / (cond_sums['up'] + cond_sums['down'])
        d.at[ix, 'ct'] = ct
        d.at[ix, 'tpr'] = tpr
        d.at[ix, 'fpr'] = fpr
        d.at[ix, 'acc'] = acc
        d.at[ix, 'n_trials'] = ct.sum().sum()
    return d

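# Sketch of the per-block rates computed above, using an invented contingency table
# (rows = response, columns = condition, as returned by the extract_trials_* helpers):
#
#   condition      up  down
#   response
#   up             18     3
#   down            2    15
#   unclassified    1     2
#
# cond_sums counts only classified trials per condition (up: 20, down: 18), so
#   tpr = 18 / 20 = 0.90,  fpr = 3 / 18 ≈ 0.17,  acc = (18 + 15) / (20 + 18) ≈ 0.87,
# while n_trials = 41 also includes the unclassified (timeout) trials.
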
def calculate_feedback_before_speller(d):
    """For each day, find feedback sessions before the first speller session and add up trials."""
    u_dsi = d['d_since_impl'].unique()
    day_n_fb = [(di, d_day[(d_day['mode'] == 'feedback') & (d_day.index < ix)]['ct'].sum().sum().sum())
                for di in u_dsi
                for d_day in [d[d['d_since_impl'] == di]]
                for ix in [d_day[d_day['mode'] == 'color'].first_valid_index()]
                if ix is not None
                ]
    return day_n_fb

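# Note on the triple .sum() above: each 'ct' cell holds a contingency DataFrame, so the first
# .sum() adds the selected tables element-wise, the second reduces the summed table to
# per-condition totals, and the third adds those totals into a single trial count.
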
def add_question_info(d, pth=BASE_PATH):
    """
    For all question sessions, read log file and save contingency table.
    """
    if 'ct' not in d.columns:
        d['ct'] = [None for _ in range(len(d))]
    for ix, rw in d.iterrows():
        if rw['mode'] != 'question':
            continue
        ct = extract_trials_q(pth / rw.log)
        cond_sums = ct[ct.index != 'unclassified'].sum()
        tpr = ct.loc['up', 'up'] / cond_sums['up']
        fpr = ct.loc['up', 'down'] / cond_sums['down']
        acc = (ct.loc['up', 'up'] + ct.loc['down', 'down']) / (cond_sums['up'] + cond_sums['down'])
        d.at[ix, 'ct'] = ct
        d.at[ix, 'tpr'] = tpr
        d.at[ix, 'fpr'] = fpr
        d.at[ix, 'acc'] = acc
    return d

def prepare_for_annotation_export(d):
    """
    Prepare a yaml file for annotation of speller sessions.
    Will be written at [BASE_PATH_OUT]/records_for_annotation.yml
    """
    d2 = d.reset_index().set_index('start_dt', drop=False)
    d2 = d2.loc[d2['mode'] == 'color', ['data', 'd_since_impl', 'start_dt', 'phrase_start', 'phrase']]
    d2['intelligible'] = None
    d2['start'] = d2['start_dt'].map(lambda x: x.strftime('%Y-%m-%d %H:%M:%S'))
    d2 = d2[['data', 'start', 'd_since_impl', 'phrase_start', 'phrase', 'intelligible']]
    d_dict = d2.to_dict(orient='index')
    with open(BASE_PATH_OUT / 'records_for_annotation.yml', 'w') as f:
        f.write("""# Rating scheme for intelligibility of patient's communications.
#
# Instructions for filling this file:
#
# We consider the result of one total session / speller run.
#
# 1. Find the session in the log by the date of the data file (also given in plain text).
# 2. Look up the speller start (key 'phrase_start') and final output (key 'phrase'),
#    and look up speller output / session remarks.
# 3. Rate the speller output.
#    For copy spelling sessions:
#      0 - completely wrong
#      1 - up to 20% of characters wrong or missing
#      2 - no mistake
#    For free spelling:
#      0 - incomprehensible speller output
#      1 - partially understandable, but with doubts due to spelling mistakes
#      2 - unambiguous to family / experimenter (even if single letters are
#          wrong or missing; even if words are incomplete)
#    Where one session's output could be counted into several categories,
#    category 1 is likely appropriate.
# 4. Find the record for the session in the list below and replace the 'null' entry
#    under the 'intelligible' key with your rating.
""")
        yaml.dump(d_dict, f, default_flow_style=False, Dumper=TSDumper, sort_keys=False)

def add_annotation(d):
    """
    Read annotations and add them to DataFrame.

    Params:
        d: DataFrame containing session information, as generated by precompile_sessions()

    Adds to DataFrame:
        'intelligible': rating of intelligibility of a spelled session
        'col': updated color based on intelligibility

    Returns: DataFrame
    """
    # with open(Path('annotations', 'speller', 'records_for_annotation_full.yml'), 'r') as f:
    with open(Path('annotations', 'speller', 'records_for_annotation_consolidated.yml'), 'r') as f:
        anno = munch.Munch.fromYAML(f, Loader=yaml.Loader)
    adict = pd.DataFrame.from_dict(anno, orient='index')
    d.reset_index(inplace=True)
    d.set_index('start_dt', drop=False, inplace=True)
    d.loc[adict.index, 'intelligible'] = adict['intelligible']
    d.loc[adict.index, 'rating'] = adict['rating']
    d.loc[(d['intelligible'] == 2) & (d['mode'] == 'color'), 'col'] = [[(0.318, 0.039, 0.090)]]
    d.loc[(d['intelligible'] == 1) & (d['mode'] == 'color'), 'col'] = [[(0.635, 0.078, 0.184)]]
    d.loc[(d['intelligible'] == 0) & (d['mode'] == 'color'), 'col'] = [[(1.000, 0.737, 0.843)]]
    d.loc[(d['intelligible'] == 0) & (d['mode'] == 'color'), 'ecol'] = [[(.2, .2, .2)]]
    d.set_index('index', inplace=True)
    return d

def add_channel_info(d, pth=BASE_PATH):
    """
    For all sessions, load config files and extract information about channels
    being used in sessions. Channels are numbered 1-based, corresponding to Blackrock's
    numbering scheme.
    """
    d.loc[:, 'channels'] = None
    d.loc[:, 'use_all'] = None
    # d.loc[:, 'norm'] = None
    d.loc[:, 'submode'] = None
    for ix, rw in d.iterrows():
        with open(Path(pth, rw.cfg), 'r') as f:
            cfg = munch.Munch.fromYAML(f, Loader=yaml.Loader)
        d.loc[ix, 'channels'] = [[ch.id + 1] for ch in cfg.daq.normalization.channels]
        d.loc[ix, 'use_all'] = cfg.daq.normalization.use_all_channels
        # d.loc[ix, 'norm'] = cfg.daq.normalization
        # test if paradigms key in config file exists. if not, load paradigms file
        if cfg.get('paradigms') is None:
            p_fn = rw.cfg.replace('config_dump', 'paradigm')
            with open(Path(pth, p_fn), 'r') as f:
                cfg.paradigms = munch.Munch.fromYAML(f, Loader=yaml.Loader)
        try:
            if rw['mode'] == 'question':
                d.loc[ix, 'submode'] = cfg.paradigms.question.mode[cfg.paradigms.question.selected_mode]
            elif rw['mode'] == 'color':
                d.loc[ix, 'submode'] = cfg.paradigms.color.mode[cfg.paradigms.color.selected_mode]
            elif rw['mode'] == 'feedback':
                d.loc[ix, 'submode'] = cfg.paradigms.feedback.mode[cfg.paradigms.feedback.selected_mode]
        except Exception:
            logger.warning(f"Exception reading mode for row {ix}", exc_info=True)
        logger.debug(f"Loading normalization info for session {ix} ({d.loc[ix, 'mode']}, {d.loc[ix, 'submode']}):"
                     f" {d.loc[ix, 'channels']}")
    return d

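# Illustrative shape of the config entries accessed above (keys as used in the code, values
# invented; the real files may contain more fields):
#   daq:
#     normalization:
#       use_all_channels: false
#       channels:          # each entry has a 0-based 'id'; reported here as 1-based channel numbers
#         - id: 2
#         - id: 4
#   paradigms:
#     feedback:
#       selected_mode: 0
#       mode: [...]        # list of submodes, indexed by 'selected_mode'
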
def get_indexer_and_color(d, mode, intelligible=None):
    """
    Given a session DataFrame, the session mode, and optionally an intelligibility rating,
    return indexer for rows matching that criterion, colour, edge colour, and label for plotting.

    Params:
        d: session DataFrame
        mode: one of 'color', 'feedback', 'exploration', 'question'
        intelligible: if mode == 'color', this should be 0, 1, or 2
    """
    ix = d['mode'] == mode
    ecol = (1, 1, 1, 1)
    col = (.1, .1, .1)
    label = ""
    if mode == 'color':
        if intelligible is None:
            ix = ix & pd.isna(d['intelligible'])
        else:
            ix = ix & (d['intelligible'] == intelligible)
        label = "Speller"
        if intelligible == 2:
            col = (0.635, 0.078, 0.184)
            label = "Speller clear"
        elif intelligible == 1:
            col = (0.85, 0.325, 0.098)
            label = "Speller ambiguous"
        elif intelligible == 0:
            col = (0.929, 0.694, 0.125)
            # ecol = (.2, .2, .2)
            label = "Speller unintelligible"
        else:
            col = (0.929, 0.894, 0.825)
            # ecol = (.2, .2, .2)
            label = "Speller not rated"
    elif mode == 'feedback':
        col = (0.466, 0.674, 0.188)
        label = "Feedback Training"
    elif mode == 'exploration':
        col = (0, 0.447, 0.741)
        label = "Exploration"
    elif mode == 'question':
        col = (0.301, 0.745, 0.933)
        label = "Questions"
    return dict(ix=ix, col=col, ec=ecol, label=label)

def plot_bars_for_sel(ax, d, mode, intelligible=None):
    """
    Plot bars into axis.

    Params:
        ax: axis to plot into
        d: Pandas DataFrame containing session info
        mode: BCI mode, one of 'color', 'feedback', 'exploration', 'question'
        intelligible: if mode == 'color', this should be 0, 1, or 2
    """
    ic = get_indexer_and_color(d, mode, intelligible)
    return ax.bar(x=d.loc[ic['ix'], 'd_since_impl'], height=d.loc[ic['ix'], 'duration_min'], color=ic['col'],
                  ec=ic['ec'], label=ic['label'], bottom=d.loc[ic['ix'], 'cum_dur_min'], lw=.5, width=.8)

def prepare_sessions(pth=BASE_PATH):
    """Combine a few steps to load session data"""
    d = precompile_sessions(pth=pth)
    add_session_info(d)
    add_speller_data(d, pth=pth)
    return d

def plot_sessions(d):
    """Load sessions, plot, and print summary"""
    from brokenaxes import brokenaxes
    n_days = len(d['d_since_impl'].unique())
    x_ranges = ((105, 126), (146, 470))  # ((105, 126), (146, 163), (174, 212), (223, 227), (238, 465))
    d2 = d.set_index(['d_since_impl'])
    other_m_d = d2[d2['mode'] != 'color'].groupby(['d_since_impl', 'mode'])[['duration_min']].sum()
    other_m_d.reset_index('mode', inplace=True)
    sp_d = d2[d2['mode'] == 'color'][
        ['start_dt', 'mode', 'duration_min', 'cfg', 'events', 'data', 'log', 'phrase', 'intelligible', 'n',
         'ch_per_min']]
    df_dur_plot = pd.concat([other_m_d, sp_d])
    df_dur_plot.sort_index(inplace=True, kind='mergesort')
    df_dur_plot['cum_dur_min'] = df_dur_plot.groupby(level=0)['duration_min'].cumsum() - df_dur_plot['duration_min']
    df_dur_plot.reset_index(inplace=True)
    # save number of letters spelled in 'intelligible'==2 group, per day
    n_per_day_intell = d[(d['mode'] == 'color') & (d['intelligible'] == 2)].groupby('d_since_impl')['n'].sum()
    n_per_day_intell.name = 'n_per_day'
    cpm_intell = d[(d['mode'] == 'color') & (d['intelligible'] == 2)].groupby('d_since_impl').apply(
        lambda x: x['n'].sum() / (x['duration_s'].sum() / 60.0))
    cpm_intell.name = 'cpm_per_day'
    spl = UnivariateSpline(n_per_day_intell.index, n_per_day_intell.values, s=850000, k=3)
    tx = np.arange(n_per_day_intell.index[0], n_per_day_intell.index[-1])
    smoothv = spl(tx)
    spl_cpm = UnivariateSpline(cpm_intell.index, cpm_intell.values, s=850000, k=3)
    tx_cpm = np.arange(cpm_intell.index[0], cpm_intell.index[-1])
    smoothv_cpm = spl_cpm(tx_cpm)
    spell_int = d.groupby('d_since_impl').apply(lambda x: ((x['mode'] == 'color') & (x['intelligible'] >= 2)).any())
    spell_int_days = spell_int[spell_int].index
    spell_not_int = d.groupby('d_since_impl').apply(
        lambda x: ((x['mode'] == 'color') & (x['intelligible'] < 2)).any() & ~(
            (x['mode'] == 'color') & (x['intelligible'] >= 2)).any())
    spell_not_int_days = spell_not_int[spell_not_int].index
    no_spell = d.groupby('d_since_impl').apply(lambda x: ~(x['mode'] == 'color').any())
    no_spell_days = no_spell[no_spell].index
    save_name = BASE_PATH_OUT / "Figure_2_SessionSummary"
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    fig = plt.figure(2, figsize=(15, 6))
    fig.clf()
    gs = fig.add_gridspec(3, hspace=0, height_ratios=[4, 4, 1])
    axs = gs.subplots(sharex=True)
    no_spell_color = (0.93725, 0.92941, 0.96078)
    not_intell_color = (0.73725, 0.74118, 0.86275)
    intell_color = (0.45882, 0.41961, 0.69412)
    bax = axs[2]
    bax.bar(no_spell_days, 1, width=1, color=no_spell_color)
    bax.bar(spell_not_int_days, 1, width=1, color=not_intell_color)
    bax.bar(spell_int_days, 1, width=1, color=intell_color)
    bax.set_ylim([0, 1])
    bax.set_xlabel('days since implantation')
    bax.set_yticks([])
    bax = axs[0]
    bax.plot(n_per_day_intell.index, n_per_day_intell.values, color=intell_color, marker='o', linestyle='none')
    bax.plot(tx, smoothv, color=intell_color)
    bax.set_ylabel('number of characters')
    ax2 = axs[1]
    ax2.set_ylabel('characters per minute')  # we already handled the x-label with ax1
    ax2.plot(cpm_intell.index, cpm_intell.values, color=intell_color, marker='o', linestyle='none')
    ax2.plot(tx_cpm, smoothv_cpm, color=intell_color)
    ax2.tick_params(axis='y')
    fig.show()
    fig.savefig(save_name.with_suffix(".pdf"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.pdf')}>")
    fig.savefig(save_name.with_suffix(".svg"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.svg')}>")
    fig.savefig(save_name.with_suffix(".eps"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.eps')}>")

    ## Aggregate sessions.
    # First, for each day, how much time was spent with feedback, questions, exploration
    fbdesc = other_m_d[other_m_d['mode'] == 'feedback'].describe()
    qdesc = other_m_d[other_m_d['mode'] == 'question'].describe()
    exdesc = other_m_d[other_m_d['mode'] == 'exploration'].describe()
    # Sum up all speller sessions
    aggr_sp_d = sp_d.groupby(['d_since_impl'])[['duration_min']].sum()
    spdesc = aggr_sp_d.describe()
    # Only take speller sessions where the message could be understood at least partially.
    aggr_sp_int_d = sp_d[sp_d['intelligible'] > 0].groupby(['d_since_impl'])[['duration_min', 'n']].sum()
    int_spdesc = aggr_sp_int_d.describe()
    aggr_sum = aggr_sp_int_d.sum()
    aggr_sp_clear_d = sp_d[sp_d['intelligible'] > 1].groupby(['d_since_impl'])[['duration_min', 'n']].sum()
    clear_spdesc = aggr_sp_clear_d.describe()
    aggr_clear_sum = aggr_sp_clear_d.sum()
    ch_per_min = sp_d[(sp_d['mode'] == 'color') & (sp_d['intelligible'] == 2)]['ch_per_min']
    desc_str = f"""
This analysis covers visits on {n_days} days. On average, {fbdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent in feedback training (min/25%/50%/75%/max: {fbdesc.loc['min', 'duration_min']:4.1f}, {fbdesc.loc['25%', 'duration_min']:4.1f}, {fbdesc.loc['50%', 'duration_min']:4.1f}, {fbdesc.loc['75%', 'duration_min']:4.1f}, {fbdesc.loc['max', 'duration_min']:4.1f}).
On {qdesc.loc['count', 'duration_min']:n} days, the question paradigm was performed. On average, {qdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent in the question paradigm (min/25%/50%/75%/max: {qdesc.loc['min', 'duration_min']:4.1f}, {qdesc.loc['25%', 'duration_min']:4.1f}, {qdesc.loc['50%', 'duration_min']:4.1f}, {qdesc.loc['75%', 'duration_min']:4.1f}, {qdesc.loc['max', 'duration_min']:4.1f}).
On {exdesc.loc['count', 'duration_min']:n} days, the exploration paradigm was performed. On average, {exdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent in the exploration paradigm (min/25%/50%/75%/max: {exdesc.loc['min', 'duration_min']:4.1f}, {exdesc.loc['25%', 'duration_min']:4.1f}, {exdesc.loc['50%', 'duration_min']:4.1f}, {exdesc.loc['75%', 'duration_min']:4.1f}, {exdesc.loc['max', 'duration_min']:4.1f}).
On {spdesc.loc['count', 'duration_min']:n} days, we attempted to use the speller.
On {int_spdesc.loc['count', 'duration_min']:n} days, the patient used the speller to generate at least partially understandable output. On average, {int_spdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent spelling (min/25%/50%/75%/max: {int_spdesc.loc['min', 'duration_min']:4.1f}, {int_spdesc.loc['25%', 'duration_min']:4.1f}, {int_spdesc.loc['50%', 'duration_min']:4.1f}, {int_spdesc.loc['75%', 'duration_min']:4.1f}, {int_spdesc.loc['max', 'duration_min']:4.1f}).
On average, the daily output was {int_spdesc.loc['mean', 'n']:4.1f} characters (min/25%/50%/75%/max: {int_spdesc.loc['min', 'n']:n}, {int_spdesc.loc['25%', 'n']:n}, {int_spdesc.loc['50%', 'n']:n}, {int_spdesc.loc['75%', 'n']:n}, {int_spdesc.loc['max', 'n']:n}).
Overall, the patient's at least partially comprehensible utterances comprised {aggr_sum['n']:n} characters
produced over {aggr_sum['duration_min']:n} minutes, corresponding to a grand average rate of {aggr_sum['n'] / aggr_sum['duration_min']:4.2f} characters per minute.
On {clear_spdesc.loc['count', 'duration_min']:n} days, the patient used the speller to generate clearly understandable output. On average, {clear_spdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent spelling (min/25%/50%/75%/max: {clear_spdesc.loc['min', 'duration_min']:4.1f}, {clear_spdesc.loc['25%', 'duration_min']:4.1f}, {clear_spdesc.loc['50%', 'duration_min']:4.1f}, {clear_spdesc.loc['75%', 'duration_min']:4.1f}, {clear_spdesc.loc['max', 'duration_min']:4.1f}).
On average, the daily output was {clear_spdesc.loc['mean', 'n']:4.1f} characters (min/25%/50%/75%/max: {clear_spdesc.loc['min', 'n']:n}, {clear_spdesc.loc['25%', 'n']:n}, {clear_spdesc.loc['50%', 'n']:n}, {clear_spdesc.loc['75%', 'n']:n}, {clear_spdesc.loc['max', 'n']:n}).
Overall, the patient's clearly intelligible communications comprised {aggr_clear_sum['n']:n} characters
produced over {aggr_clear_sum['duration_min']:n} minutes, corresponding to a grand average rate of {aggr_clear_sum['n'] / aggr_clear_sum['duration_min']:4.2f} characters per minute.
Per-session spelling rate was min/median/max: {ch_per_min.min():.1f}/{ch_per_min.median():.1f}/{ch_per_min.max():.1f} characters per minute.
On {len(d.groupby('d_since_impl')) - len(d[d['mode'] == 'color'].groupby('d_since_impl'))} days, use of the speller was not attempted because the criterion was not reached or because of other circumstances.
"""
    logger.info(desc_str)
    days_df = pd.concat([pd.DataFrame(index=spell_int_days, columns=['intell'], data='intelligible'),
                         pd.DataFrame(index=spell_not_int_days, columns=['intell'], data='not_intelligible'),
                         pd.DataFrame(index=no_spell_days, columns=['intell'], data='no_speller')]).sort_index()
    n_df = pd.concat([n_per_day_intell, cpm_intell], axis=1)
    session_summary_df = pd.concat([n_df, days_df], axis=1)
    session_summary_df.to_csv(save_name.with_suffix(".csv"))
    return d, fig, bax

def get_fb_sessions_before_speller(d, ignore_sessions_before_fb_change=True):
    """
    Finds neurofeedback sessions before speller

    :param d: DataFrame of sessions
    :param ignore_sessions_before_fb_change: if True (default), ignore sessions before logging scheme was changed.
    :return: DataFrame, list of indices, list of pairs of indices of NF blocks and corresponding speller blocks.
    """
    if ignore_sessions_before_fb_change:
        color_ix = d[d['mode'].isin(['color']) & (d['start_dt'] >= FEEDBACK_CHANGE_DATE)].index
        fb_ix = d[(d['mode'] == 'feedback') & (d['start_dt'] >= FEEDBACK_CHANGE_DATE)].index
    else:
        color_ix = d[d['mode'].isin(['color'])].index
        fb_ix = d[(d['mode'] == 'feedback')].index
    fbix = np.unique([fb_ix[np.where(fb_ix < ci)[0][-1]] for ci in color_ix])
    fb_ci_pairs = [(fb_ix[np.where(fb_ix < ci)[0][-1]], ci) for ci in color_ix]
    return d, fbix, fb_ci_pairs

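# Illustration of the pairing logic (row indices invented): if feedback blocks sit at rows
# [3, 7, 12] and speller ('color') blocks at rows [9, 14], then each speller block is matched
# with the closest preceding feedback block, giving fb_ci_pairs = [(7, 9), (12, 14)] and
# fbix = [7, 12].
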
def generate_fb_session_list(d):
    CFG_RE = re.compile(r"^(\d+-\d+-\d+)/config_dump_(\d+_\d+_\d+)\.yaml$")
    DTA_RE = re.compile(r"^\d+-\d+-\d+/data_(\d+_\d+_\d+)\.bin$")
    (_, fbix, _) = get_fb_sessions_before_speller(d)
    fb_list_d = []
    for fbi in fbix:
        s = d.loc[fbi]
        cfm = CFG_RE.match(d.loc[fbi].cfg)
        dtm = DTA_RE.match(d.loc[fbi].data)
        line_d = {'day': cfm[1], 'cfg_t': cfm[2], 'data_t': dtm[1]}
        fb_list_d.append(line_d)
    return fb_list_d

def rand_jitter(arr):
    stdev = .005 * (max(arr) - min(arr))
    return arr + np.random.randn(len(arr)) * stdev

def speller_performance_by_nf(d):
    *_, fb_ci_pairs = get_fb_sessions_before_speller(d)
    # acc_n = list(map(lambda p: (d.iloc[p[0]]['acc'], d.iloc[p[1]]['n']), fb_ci_pairs))
    acc_i_n = [(d.iloc[p[0]]['d_since_impl'], d.iloc[p[0]]['acc'], int(d.iloc[p[1]]['intelligible']),
                d.iloc[p[1]]['n']) for p in fb_ci_pairs]
    df = pd.DataFrame(acc_i_n, columns=['DSI', 'Acc', 'Int', 'N'])
    spc_acc_int = stats.spearmanr(df['Acc'], df['Int'])
    spc_acc_n = stats.spearmanr(df['Acc'], df['N'])
    res_text = f"""
Correlation between neurofeedback task accuracy and subsequent spelling.
There were {len(df)} pairs of speller blocks and preceding neurofeedback blocks. The speller output was rated
0 for unintelligible, 1 for partially intelligible, 2 for intelligible. The Spearman correlation between neurofeedback
task accuracy and speller intelligibility was {spc_acc_int.correlation:4.3f} (p={spc_acc_int.pvalue:4.3e}).
The Spearman correlation between neurofeedback accuracy and number of letters spelled was
{spc_acc_n.correlation:4.3f} (p={spc_acc_n.pvalue:4.3e}).
"""
    logger.info(res_text)
    return res_text

def plot_audio_feedback_tpfp(d):
    """
    Finds all audio feedback sessions before speller.
    It then calculates true positive rates and false positive rates and plots
    each of these sessions in a scatter plot.
    """
    (_, fbix, _) = get_fb_sessions_before_speller(d)
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    save_name = BASE_PATH_OUT / "Figure_3B_TPFP"
    fig = plt.figure(2, figsize=(15, 6))
    fig.clf()
    ax = fig.subplots()
    ax.scatter(rand_jitter(d.loc[fbix, 'fpr']), rand_jitter(d.loc[fbix, 'tpr']))
    ax.plot([0, 1], [0, 1], 'k:')
    ax.set_ylabel('True Positive Rate')
    ax.set_xlabel('False Positive Rate')
    ax.set_aspect('equal')
    fig.savefig(save_name.with_suffix(".pdf"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.pdf')}>")
    fig.savefig(save_name.with_suffix(".eps"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.eps')}>")
    fig.savefig(save_name.with_suffix(".svg"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.svg')}>")
    ctsum = d.loc[fbix, 'ct'].sum()
    cond_sums = ctsum[ctsum.index != 'unclassified'].sum()
    n_trials = ctsum.sum().sum()
    n_timeout = ctsum[ctsum.index == 'unclassified'].sum().sum()
    n_correct = (ctsum.loc['up', 'up'] + ctsum.loc['down', 'down'])
    n_up_incorrect = ctsum.loc['down', 'up']
    n_down_incorrect = ctsum.loc['up', 'down']
    n_up = ctsum.sum()['up']
    n_down = ctsum.sum()['down']
    r_up_incorrect = n_up_incorrect / n_up
    r_down_incorrect = n_down_incorrect / n_down
    tpr = ctsum.loc['up', 'up'] / cond_sums['up']
    fpr = ctsum.loc['up', 'down'] / cond_sums['down']
    acc = n_correct / n_trials
    acc_thr = 0.8
    fraction_less_than_80 = (d.loc[fbix, 'acc'] < acc_thr).sum() / len(fbix)
    acc_min = 100.0 * d.loc[fbix, 'acc'].min()
    acc_max = 100.0 * d.loc[fbix, 'acc'].max()
    res_str = f"""Contingency table (columns = conditions, rows = observations):\n{ctsum}\n
There were {len(fbix)} sessions. In total, there were {n_trials} trials.
The accuracy over all trials was {100.0 * acc:4.1f}% (n={n_correct}). There were {n_timeout} timeout trials ({100.0 * n_timeout / n_trials:4.1f}%).
There were {n_up_incorrect} ({100.0 * r_up_incorrect:4.1f}%) incorrect 'up' trials and {n_down_incorrect} ({100.0 * r_down_incorrect:4.1f}%) incorrect 'down' trials.\n
In the last feedback sessions before speller sessions, the median accuracy was {100.0 * d.loc[fbix, 'acc'].median():4.1f}%,
the minimum was {acc_min:4.1f}%. In {100.0 * fraction_less_than_80:4.1f}% of the sessions, accuracy was below {100.0 * acc_thr:4.1f}%.
"""
    logger.info(res_str)
    # Export as CSV
    d_filt = d.loc[fbix].copy()
    for i, row in d_filt.iterrows():
        d_filt.at[i, 'up_up'] = row.ct.loc['up', 'up']
        d_filt.at[i, 'up_down'] = row.ct.loc['down', 'up']
        d_filt.at[i, 'up_unclassified'] = row.ct.loc['unclassified', 'up']
        d_filt.at[i, 'down_up'] = row.ct.loc['up', 'down']
        d_filt.at[i, 'down_down'] = row.ct.loc['down', 'down']
        d_filt.at[i, 'down_unclassified'] = row.ct.loc['unclassified', 'down']
    d_filt[
        ['d_since_impl', 'start_dt', 'duration_s', 'channels', 'mode', 'data', 'tpr', 'fpr', 'acc', 'up_up', 'up_down',
         'up_unclassified', 'down_up', 'down_down', 'down_unclassified']].to_csv(save_name.with_suffix(".csv"))

def plot_fb_accuracy(d):
    """
    Plots all feedback sessions' accuracy as function of day.
    This generates Supplementary Figure 2.
    """
    from brokenaxes import brokenaxes
    from matplotlib.gridspec import GridSpec
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    save_name = BASE_PATH_OUT / "Figure_S2_fb_acc"
    d_filt = d[d['mode'] == 'feedback'].copy()
    d_filt['ddiff'] = (d_filt['start_dt'] - IMPLANT_DATE) / pd.to_timedelta(1, unit='D')
    # accv = d_filt[['d_since_impl', 'acc', 'start_dt', 'ddiff']]
    n_fb_sessions = d_filt.groupby('d_since_impl').count()['acc']
    n_fb_sessions.name = 'n_fb'
    # Get fb before speller
    (_, fbix, _) = get_fb_sessions_before_speller(d, ignore_sessions_before_fb_change=False)
    # d_nf = d.loc[fbix].copy()
    # d_nf['ddiff'] = (d_nf['start_dt'] - IMPLANT_DATE) / pd.to_timedelta(1, unit='D')
    d_filt['b4sp'] = 0
    d_filt.loc[fbix, 'b4sp'] = 1
    fig = plt.figure(22, figsize=(15, 8), constrained_layout=False)
    fig.clf()
    gs = GridSpec(ncols=2, nrows=2, figure=fig, hspace=.25, top=.9, bottom=.1)
    bax = fig.add_subplot(gs[0, :])
    # bax = brokenaxes(xlims=x_ranges, hspace=.05, tilt=65, d=.005, subplot_spec=gs[0, :])
    bax.plot(d_filt['ddiff'] - .5, d_filt['acc'], 'o', color=(.5, .5, .5), ms=2, label="NF Trial sets")
    accvmax = d_filt.groupby('d_since_impl')[['acc']].max()
    # bax.plot(accvmax.index, accvmax['acc'], '-', color=(0, 0, 0), lw=2, label="Daily maximum")
    d_b4sp = d_filt[d_filt['b4sp'] > 0]
    bax.plot(d_b4sp['ddiff'] - .5, d_b4sp['acc'], 'o', color=(1, 0, 0), ms=3, lw=2, label="NF before speller")
    bax.set_ylim(-0.02, 1.02)
    bax.set_xlabel("Days after implantation")
    bax.set_ylabel("Accuracy")
    bax.set_title("Accuracy of Feedback trials")
    bax.legend(loc="lower right")
    dsi = 113
    d_sub = d[(d['d_since_impl'] == dsi) & (d['mode'] == 'feedback')]
    d_sub.reset_index(inplace=True)
    d_sub.index = d_sub.index + 1
    f_ax1 = fig.add_subplot(gs[1, 0])
    f_ax1.plot((d_sub['start_dt'] - d_sub.loc[1, 'start_dt']).astype('timedelta64[s]') / 60, d_sub['acc'], 'o-k')
    f_ax1.set_ylim(-0.05, 1.05)
    f_ax1.set_xlabel("Time since start [min]")
    f_ax1.set_ylabel("Accuracy")
    f_ax1.set_title(f"Feedback trial sets on day {dsi} post-implantation")
    dsi = 197
    d_sub = d[(d['d_since_impl'] == dsi) & (d['mode'] == 'feedback')]
    d_sub.reset_index(inplace=True)
    d_sub.index = d_sub.index + 1
    f_ax2 = fig.add_subplot(gs[1, 1])
    f_ax2.plot((d_sub['start_dt'] - d_sub.loc[1, 'start_dt']).astype('timedelta64[s]') / 60, d_sub['acc'], 'o-k')
    f_ax2.set_ylim(-0.05, 1.05)
    f_ax2.set_xlabel("Time since start [min]")
    f_ax2.set_ylabel("Accuracy")
    f_ax2.set_title(f"Feedback trial sets on day {dsi} post-implantation")
    fig.savefig(save_name.with_suffix(".pdf"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.pdf')}>")
    fig.savefig(save_name.with_suffix(".eps"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.eps')}>")
    fig.savefig(save_name.with_suffix(".svg"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.svg')}>")
    thr = .9
    descstr = f"""
The best run per day had an accuracy above {thr:.1f} on {(accvmax['acc'] > thr).sum() / len(accvmax) * 100.0:.1f}% of days.
Over the reported period, there were {n_fb_sessions.sum()} feedback sessions, min/median/maximum per day:
{n_fb_sessions.min()} / {n_fb_sessions.median()} / {n_fb_sessions.max()}.
Over all {d.loc[d['mode'] == 'feedback', 'acc'].count()} feedback sessions, the mean accuracy was {100 * d.loc[d['mode'] == 'feedback', 'acc'].mean():.1f}%.
"""
    logger.info(descstr)
    # Export as CSV
    d_filt.loc[:, 'channels1'] = None
    for i, row in d_filt.iterrows():
        d_filt.at[i, 'up_up'] = row.ct.loc['up', 'up']
        d_filt.at[i, 'up_down'] = row.ct.loc['down', 'up']
        d_filt.at[i, 'up_unclassified'] = row.ct.loc['unclassified', 'up']
        d_filt.at[i, 'down_up'] = row.ct.loc['up', 'down']
        d_filt.at[i, 'down_down'] = row.ct.loc['down', 'down']
        d_filt.at[i, 'down_unclassified'] = row.ct.loc['unclassified', 'down']
        d_filt.at[i, 'channels1'] = ", ".join([str(x) for x in row.channels])
    d_filt[
        ['d_since_impl', 'start_dt', 'channels1', 'data', 'acc', 'b4sp',
         'up_up', 'up_down', 'up_unclassified', 'down_up', 'down_down',
         'down_unclassified']].to_csv(save_name.with_suffix(".csv"))
    return fig, bax, f_ax1, f_ax2, d_filt

def plot_channels_used(d):
    """Plots for each speller block the channel that was used"""
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    save_name = BASE_PATH_OUT / "Figure_S_channel_use"
    color_list = [(0.8, .8, .8), (.55, .55, .55), (.3, .3, .3), (0, 0, 0)]
    cmap = ListedColormap(color_list, N=len(color_list))
    df_blocks = d[d.channels.notna() & (d['mode'] == 'color')]
    ch_list = list(df_blocks.channels)
    ch_list_merged = list(itertools.chain(*ch_list))
    unique_channels = sorted(list(set(ch_list_merged)))
    ch_use_matrix = np.empty((len(unique_channels), len(df_blocks)))
    ch_use_matrix[:] = np.nan
    ch_count = np.zeros((len(unique_channels), len(df_blocks)))
    c_idx = {r: i for i, r in enumerate(unique_channels)}
    day_blocks = pd.DataFrame(index=df_blocks.d_since_impl.unique(), columns=unique_channels)
    day_blocks.fillna(0, inplace=True)
    b_d_s_i = df_blocks.d_since_impl
    b_d_s_i_change = (b_d_s_i.diff()).isna() | (b_d_s_i.diff() > 0)
    b_d_s_i_change = b_d_s_i_change.reset_index()
    tick_loc = b_d_s_i_change.index[b_d_s_i_change.d_since_impl].to_list()
    tick_lab = list(b_d_s_i[list(b_d_s_i_change.d_since_impl)])
    j = 0
    for i, r in df_blocks.iterrows():
        col_val = len(r['channels'])
        # if r['mode'] == 'color':
        #     pass
        #     col_val += 5
        for ch in r['channels']:
            ch_use_matrix[c_idx[ch], j] = col_val
            day_blocks.loc[r['d_since_impl'], ch] = 1
            ch_count[c_idx[ch], j] = 1
        j += 1
    # This is sorting by number of blocks. For consistency, number of days is better (below)
    # sort_idx = np.argsort(ch_count.sum(1))[::-1]
    channel_days = day_blocks.sum()
    sort_idx = list(channel_days.argsort()[::-1])
    sorted_channels = np.asarray(unique_channels)[sort_idx]
    channel_days.sort_values(ascending=False, inplace=True)
    # create array map with channel use counts
    array_map = ARRAY_MAPS['K01']
    map_sma = array_map[0]
    amesh_sma = np.empty((np.max(map_sma['x']) + 1, np.max(map_sma['y']) + 1))
    amesh_sma[:] = np.nan
    for index, el_ix in enumerate(map_sma['ix']):
        if el_ix + 1 in channel_days.index:
            amesh_sma[7 - map_sma['y'][index], map_sma['x'][index]] = channel_days[el_ix + 1]
    norm = mpl.colors.BoundaryNorm(np.arange(-.5 + 1, cmap.N + 1), cmap.N)
    my_cmap = plt.get_cmap("plasma")
    rescale = lambda y: (y - np.min(y)) / (np.max(y) - np.min(y))
    # plot number of channels used per block, over days
    fig = plt.figure(constrained_layout=True)
    gs = fig.add_gridspec(2, 3)
    # fig, axs = plt.subplots(2, 1, constrained_layout=True)
    ax = fig.add_subplot(gs[0, :-1])
    imgplt = ax.pcolormesh(ch_use_matrix[sort_idx, :], cmap=cmap, norm=norm)
    ax.set_yticks(np.arange(0, len(unique_channels)) + .5)
    ax.set_yticklabels(sorted_channels)
    ax.set_ylabel('Channel ID')
    ax.set_xlabel('Days after implantation')
    ax.yaxis.set_inverted(True)
    ax.xaxis.set_ticks(tick_loc, minor=True)
    ax.set_xticks(tick_loc[::20])
    ax.set_xticklabels([f"{x}" for x in tick_lab[::20]])
    cbar = fig.colorbar(imgplt, ax=ax, location='right', ticks=range(1, cmap.N + 1))
    # Plotting the histogram for channels over days
    ax = fig.add_subplot(gs[1, :-1])
    ax.bar(range(len(channel_days)), channel_days, color=my_cmap(rescale(channel_days)), linewidth=.3, edgecolor='k')  # , color='k')
    ax.set_xticks(range(len(channel_days)))
    ax.set_xticklabels(channel_days.index)
    ax.set_ylabel('used on number of days')
    ax.set_xlabel('Channel ID')
    ax = fig.add_subplot(gs[-1, -1])
    rect = plt.Rectangle((0, 0), 8, 8, fill=True, facecolor=(.3, .3, .3), edgecolor=(0, 0, 0), alpha=0.2, zorder=-1)
    ax.add_patch(rect)
    ax.pcolormesh(amesh_sma, cmap=my_cmap)
    ax.set_xlim(-1, 9)
    ax.set_ylim(-1, 9)
    ax.set_aspect(1)
    ax.axis('off')
    fig.savefig(save_name.with_suffix(".pdf"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.pdf')}>")
    fig.savefig(save_name.with_suffix(".eps"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.eps')}>")
    fig.savefig(save_name.with_suffix(".svg"))
    logger.info(f"Plot saved at <{save_name.with_suffix('.svg')}>")
    # save as csv
    ch_exp = ch_use_matrix[sort_idx, :]
    ch_exp[ch_exp > 0] = 1
    df_ch_exp = pd.DataFrame(data=ch_exp.T, columns=[f"ch_{c}" for c in sorted_channels], index=df_blocks.index)
    df_ch_exp['n_ch_used'] = df_ch_exp.sum(axis=1)
    df_ch_exp['d_since_impl'] = df_blocks.d_since_impl
    df_ch_exp.reset_index(drop=True, inplace=True)
    df_ch_exp[df_ch_exp.isna()] = 0
    df_ch_exp = df_ch_exp.astype('int')
    print(f"Number of channels used in speller blocks:\n{df_ch_exp[['n_ch_used']].groupby(['n_ch_used']).apply(lambda x: len(x))}")
    df_ch_exp.to_csv(save_name.with_suffix('._panel_a.csv'))
    data_panel_b = pd.DataFrame.from_dict({'days_used': channel_days})
    data_panel_b.to_csv(save_name.with_suffix('._panel_b.csv'))

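# Assumed structure of ARRAY_MAPS (defined in basics, not in this file): ARRAY_MAPS['K01'] is a
# sequence of per-array electrode maps; the first entry is used here and provides parallel arrays
# 'x', 'y' (grid coordinates on the 8x8 array) and 'ix' (0-based channel indices), which the loop
# above uses to paint the per-channel day counts onto the array layout.
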
def calculate_ITR(d):
    # define ITR calculation helper
    def itr_helper(r):
        # number of selectable characters (26 letters, space, delete, question mark, end program)
        N = 30
        # remove substring from rating corresponding to the phrase_start
        rate_str = r.rating[len(r.phrase_start):]
        # calculate rate of correct characters
        P = rate_str.count('T') / (rate_str.count('T') + rate_str.count('F'))
        # calculate ITR in bits / min
        return (np.log2(N) + P * np.log2(P) + (1 - P) * np.log2((1 - P + 1e-6) / (N - 1))) * r.ch_per_min

    # select all rows with non-empty ratings
    dfwr = d.loc[~d.rating.isna()]
    itr = dfwr.apply(itr_helper, axis=1)
    res = f"""There were {len(itr)} rated speller sessions. The min/mean/median/max ITR was {itr.min():.3f} / {itr.mean():.3f} / {itr.median():.3f} / {itr.max():.3f} bit per minute.\n"""
    print(res)

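# The expression in itr_helper is the Wolpaw information-transfer rate: with N selectable symbols
# and per-character accuracy P, the bits per selection are
#   B = log2(N) + P * log2(P) + (1 - P) * log2((1 - P) / (N - 1))
# and B * ch_per_min gives bits per minute (the 1e-6 term only guards against log2(0) when P == 1).
# Worked example with made-up numbers: N = 30, P = 0.9 gives B ≈ 4.91 - 0.14 - 0.82 ≈ 3.95 bits
# per selection, i.e. about 3.95 bits/min at one character per minute.
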
if __name__ == '__main__':
    pth = BASE_PATH / 'KIAP_BCI_neurofeedback'
    df = prepare_sessions(pth=pth)
    df = add_feedback_info(df, pth=pth)
    df = add_channel_info(df, pth=pth)

    pths = BASE_PATH / 'KIAP_BCI_speller'
    ds = prepare_sessions(pth=pths)
    ds = add_feedback_info(ds, pth=pths)
    ds = add_channel_info(ds, pth=pths)

    d = pd.concat((df, ds), ignore_index=True)
    d.sort_values('start_dt', inplace=True)
    d.reset_index(drop=True, inplace=True)
    add_annotation(d)

    plot_sessions(d)
    plot_audio_feedback_tpfp(d)
    plot_fb_accuracy(d)
    plot_channels_used(d)
    speller_performance_by_nf(d)
    calculate_ITR(d)