# 03_get_neural_rdms.py (20 KB)
  1. # -*- coding: utf-8 -*-
  2. # %% general setup
  3. # import libraries
  4. import os
  5. import os.path as op
  6. from string import ascii_lowercase
  7. import re
  8. import numpy as np
  9. # import scipy
  10. import pandas as pd
  11. import mne
  12. import rsatoolbox
  13. import matplotlib.pyplot as plt
  14. import matplotlib as mpl
  15. from mpl_toolkits.axes_grid1 import make_axes_locatable
  # Configuration: figure style, data locations, participant list, stimulus order.
fig_font_size = 8
plt.rcParams.update({
    "font.family": "Helvetica",
    'font.size': fig_font_size,
})
beh_path = op.join('beh', 'data')
eeg_path = 'eeg'
epo_path = 'epo'
# Read the participant list from participants.tsv in the EEG folder and keep only
# rows with recording_restarted == 0.
plist = pd.read_csv(op.join(eeg_path, 'participants.tsv'), delimiter='\t')
plist = plist.loc[plist.recording_restarted == 0]
# Filtering leaves gaps in the row labels. Reset them so that lookups such as
# all_subj_ids[0] return the first *retained* participant instead of raising
# KeyError when row 0 was excluded.
all_subj_ids = plist.participant_id.reset_index(drop=True)
# Character order of the model (dis)similarity matrices; exported character-pair
# labels are expressed in this order too.
chars = [*ascii_lowercase, 'ä', 'ö', 'ü', 'ß']
  # Load every participant's epochs (EEG channels only) plus their stimulus labels.
# %% import all data
epo_files = [op.join(epo_path, f'{pid}-epo.fif') for pid in all_subj_ids]
all_epos = [mne.read_epochs(fname) for fname in epo_files]
# Epochs.pick removes the non-EEG channels in place and returns the same object.
all_epos = [inst.pick(['eeg']) for inst in all_epos]
# Stimulus label of every epoch, in epoch order: one array per participant.
all_epo_labs = [np.array(inst.metadata['stimulus']) for inst in all_epos]
# Posterior electrodes forming the region of interest.
post_chs = [
    'TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10',
    'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8',
    'PO7', 'PO3', 'POz', 'PO4', 'PO8',
    'PO9', 'O1', 'Oz', 'O2', 'PO10',
]
# All participants' epochs in a single object (grand average and sensor layout).
epos_concat = mne.concatenate_epochs(all_epos)
  # Illustration figures: grand-average ERP and per-character channel patterns.
# %% illustration figures
os.makedirs(op.join('fig', 'pattern_examples'), exist_ok=True)

# Grand-average ERP over the posterior ROI with the 150-225 ms window highlighted.
fig = epos_concat.average().plot(picks=post_chs, time_unit='ms', highlight=(150, 225),
                                 show=False, selectable=False)
fig.set_size_inches(2.4, 0.75)
fig.savefig(op.join('fig', 'ERP.svg'))
plt.close(fig)

# Channel patterns (mean over epochs and 150-225 ms) for one example participant.
eg_chars = chars
eg_participant_idx = 1
eg_epos = all_epos[eg_participant_idx]
# Posterior-channel means for every character, computed once and reused for both
# ends of the symmetric colour scale (values in V, as returned by get_data).
eg_post_means = np.array([
    eg_epos[c].get_data(picks=post_chs, tmin=.150, tmax=.225).mean(axis=(0, 2))
    for c in eg_chars
])
min_avg = eg_post_means.min()
max_avg = eg_post_means.max()
abs_max = max(abs(min_avg), abs(max_avg))
# The sensor plots are drawn on epos_concat.info, so size the groups to that info.
n_chs = len(epos_concat.ch_names)
rdbu_cmap = mpl.colormaps['RdBu_r']

# Stand-alone colour bar in µV shared by all pattern plots.
cb_fig, ax = plt.subplots(figsize=(1.7, .075), layout='constrained')
cb_fig.colorbar(
    mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=-abs_max * 1e6, vmax=abs_max * 1e6),
        cmap=rdbu_cmap),
    cax=ax, orientation='horizontal', label='µV',
    ticks=[-8, -4, 0, 4, 8],
)
cb_fig.savefig(op.join('fig', 'pattern_examples', 'colorbar.svg'), bbox_inches='tight', pad_inches=0)
plt.close(cb_fig)

# Independent of the character, so built once: one colour group per channel, and
# only posterior channels get a visible marker (size 0 hides the rest).
ch_grps = [[i] for i in range(n_chs)]
point_sizes = [125 if ch in post_chs else 0 for ch in epos_concat.ch_names]

for c in eg_chars:
    epo_dat = eg_epos[c].get_data(picks='eeg', tmin=.150, tmax=.225)
    epo_avg_per_ch = epo_dat.mean(axis=(0, 2))  # one value per EEG channel
    ch_vals = epo_avg_per_ch
    # Map [-abs_max, abs_max] onto [0, 1]. Out-of-range values take the colour
    # map's end colours (only non-posterior, hidden channels can exceed the range).
    ch_vals_scaled = (ch_vals + abs_max) / (abs_max * 2)
    # One colour per channel, so channel group i is drawn in colour i.
    cmap_custom = mpl.colors.ListedColormap(rdbu_cmap(ch_vals_scaled))
    ch_epo_fig = mne.viz.plot_sensors(epos_concat.info, ch_groups=ch_grps, cmap=cmap_custom,
                                      pointsize=point_sizes, linewidth=0, show=False)
    ch_epo_fig.set_size_inches((2, 2))
    ch_epo_fig.savefig(op.join('fig', 'pattern_examples', f'{c}.svg'), bbox_inches='tight', pad_inches=0)
    plt.close(ch_epo_fig)
  # Example correlation-distance RDMs (raw and rank-transformed) for every participant.
n_chars = len(chars)
tril_idx = np.tril_indices(n=n_chars, k=-1)
# Number of distinct character pairs (435 for 30 characters); the top of the rank scale.
n_pairs = len(tril_idx[0])
rdm_cmap = 'viridis'
for sub_dir in ('rdm_examples', 'rank_rdm_examples'):
    os.makedirs(op.join('fig', sub_dir), exist_ok=True)

eg_rdms = []
for epo_p in all_epos:
    # One mean pattern over all EEG channels per character (150-225 ms, all epochs).
    vecs = np.stack([epo_p[c].get_data(picks='eeg', tmin=.150, tmax=.225).mean(axis=(0, 2))
                     for c in chars])
    # 1 - Pearson r for every character pair in one call (rows are the variables).
    eg_rdms.append(1 - np.corrcoef(vecs))

rank_eg_rdms = []
for eg_rdm in eg_rdms:
    # Rank the unique (lower-triangle) dissimilarities 1..n_pairs, with ties broken
    # arbitrarily. Mirror them into the upper triangle and blank the diagonal.
    rank_eg_rdm = np.zeros((n_chars, n_chars))
    rank_eg_rdm[tril_idx] = eg_rdm[tril_idx].argsort().argsort() + 1
    rank_eg_rdm = rank_eg_rdm + rank_eg_rdm.T
    np.fill_diagonal(rank_eg_rdm, np.nan)
    rank_eg_rdms.append(rank_eg_rdm)


def _save_rdm_image(mat, path, **imshow_kwargs):
    """Save ``mat`` as a small, axis-free heat map in the RDM colour map."""
    rdm_fig, rdm_ax = plt.subplots(1, 1, figsize=(.5, .5))
    rdm_ax.imshow(mat, interpolation='none', cmap=rdm_cmap, **imshow_kwargs)
    rdm_ax.set_xticks([])
    rdm_ax.set_yticks([])
    rdm_fig.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close(rdm_fig)


def _save_colorbar(vmin, vmax, label, path):
    """Save a tick-less horizontal colour bar spanning [vmin, vmax] in the RDM colour map."""
    bar_fig, bar_ax = plt.subplots(figsize=(0.5, .075), layout='constrained')
    bar_fig.colorbar(
        mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax), cmap=rdm_cmap),
        cax=bar_ax, orientation='horizontal', label=label, ticks=[],
    )
    bar_fig.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close(bar_fig)


for p, (eg_rdm, rank_eg_rdm) in enumerate(zip(eg_rdms, rank_eg_rdms)):
    # Raw RDMs are auto-scaled per participant; rank RDMs share the fixed 1..n_pairs scale.
    _save_rdm_image(eg_rdm, op.join('fig', 'rdm_examples', f'rdm_{p}.svg'))
    _save_rdm_image(rank_eg_rdm, op.join('fig', 'rank_rdm_examples', f'rank_rdm_{p}.svg'),
                    vmin=1, vmax=n_pairs)

# The bars have no ticks, so the 0-2 range (full span of 1 - r) is nominal only.
_save_colorbar(0, 2, '1-r', op.join('fig', 'rdm_examples', 'colorbar.svg'))
_save_colorbar(1, n_pairs, 'Rank', op.join('fig', 'rank_rdm_examples', 'colorbar.svg'))
  # Result accumulators, model RDMs and analysis windows for the RSA.
# %% RSA toolbox method
# Accumulators filled once per participant.
r_comb_ot = []
r_comb_jacc = []
r_comb_complexity = []
r_comb_ot_poi = []
r_comb_jacc_poi = []
rdms_data_list = []
rdms_data_poi_list = []
rdms_data_p1_list = []
rdms_data_list_all_chs = []
rdms_data_poi_list_all_chs = []
rdms_data_p1_list_all_chs = []

sim_path = 'stim_sim'
sim_path_prereg = op.join(sim_path, 'preregistered')
sim_path_complexity = op.join(sim_path, 'complexity')


def _load_model_rdm(path, measure):
    """Load a stimulus dissimilarity matrix and wrap it as a single-RDM model.

    The matrix rows/columns are expected in ``chars`` order.

    Returns:
        (matrix, lower-triangle vector, rsatoolbox RDMs object)

    Raises:
        ValueError: if the matrix is not len(chars) x len(chars).
    """
    mat = np.load(path)
    if mat.shape != (len(chars), len(chars)):
        raise ValueError(f'{path}: expected a {len(chars)}x{len(chars)} matrix, got {mat.shape}')
    vec = mat[np.tril_indices(mat.shape[0], k=-1)]
    model = rsatoolbox.rdm.RDMs(
        dissimilarities=mat[np.newaxis, :, :],
        dissimilarity_measure=measure,
        # The labels describe the rows/columns (patterns), so they belong in
        # pattern_descriptors rather than the RDM-wide `descriptors`.
        pattern_descriptors={
            'index': np.arange(len(chars)),
            'stimulus': chars,
        },
    )
    return mat, vec, model


jacc_dissim_mat, jacc_vec, jacc_mod = _load_model_rdm(
    op.join(sim_path_prereg, 'jacc.npy'), 'jaccard')
ot_dissim_mat, ot_vec, ot_mod = _load_model_rdm(
    op.join(sim_path_prereg, 'ot.npy'), 'optimal transport')
complexity_dissim_mat, complexity_vec, complexity_mod = _load_model_rdm(
    op.join(sim_path_complexity, 'complexity.npy'), 'complexity')

out_path = 'rdm_data'
# Analysis windows in seconds: N1 period of interest and exploratory P1 period.
n1_period = [0.15, 0.225]
p1_period = [0.080, 0.130]
  # Per-participant RDMs, model comparisons and CSV exports.
# Lower-triangle (row > column) pair indices and their character labels. These are
# the same for every exported RDM and every participant, so they are computed once.
pair_idx = np.tril_indices(len(chars), k=-1)
char1_ids = np.array(chars)[pair_idx[0]]
char2_ids = np.array(chars)[pair_idx[1]]

# Time-resolved RDMs use bins of about bin_len samples. np.array_split gives any
# remainder to the first bins, one extra sample each, starting with the baseline
# bin. If the data are sampled at 1000 Hz, samples and ms coincide and bins are
# bin_len ms wide.
bin_len = 10

# Create the output folders up front so a missing folder cannot abort the loop part-way.
for sub_dir in ('period_of_interest', 'period_of_interest_all_chs', 'p1_period',
                'time_resolved', 'time_resolved_all_chs'):
    os.makedirs(op.join(out_path, sub_dir), exist_ok=True)


def _sample_dataset(inst, labels, tmin, tmax):
    """Return one time window of ``inst`` as an rsatoolbox Dataset.

    Each time sample of each epoch becomes one observation (all samples of
    epoch 1, then epoch 2, ...). Each observation is labelled with its epoch's
    stimulus, giving an (n_epochs * n_times, n_channels) data matrix over all
    channels of ``inst``.
    """
    dat = inst.get_data(picks='all', tmin=tmin, tmax=tmax)  # (n_epochs, n_channels, n_times)
    n_epochs, n_channels, n_times = dat.shape
    obs = dat.transpose(0, 2, 1).reshape(n_epochs * n_times, n_channels)
    return rsatoolbox.data.Dataset(
        obs,
        channel_descriptors={'names': inst.ch_names},
        obs_descriptors={'stimulus': np.repeat(labels, n_times)},
    )


def _align_patterns(rdms, labels):
    """Reorder the patterns of ``rdms`` in place to follow ``labels``; return ``rdms``.

    A data RDM's pattern order comes from rsatoolbox's handling of the 'stimulus'
    descriptor. It may be sorted, e.g. by code point, which puts 'ß' before 'ä'.
    The model RDMs and the exported char1/char2 labels follow ``chars``, so the
    data RDMs are aligned explicitly. This is a no-op when the orders already agree.

    Raises:
        ValueError: if the RDM's stimuli differ from ``labels``.
    """
    current = [str(x) for x in rdms.pattern_descriptors['stimulus']]
    if sorted(current) != sorted(labels):
        raise ValueError(f'RDM stimuli {current} do not match the expected stimuli {labels}')
    rdms.reorder(np.array([current.index(lab) for lab in labels]))
    return rdms


def _pairs_frame(subj_id, mat, **extra):
    """One row per character pair (lower triangle of ``mat``); ``extra`` columns go before eeg_dissim."""
    return pd.DataFrame({
        'subj_id': subj_id,
        'char1': char1_ids,
        'char2': char2_ids,
        **extra,
        'eeg_dissim': mat[pair_idx],
    })


def _save_single(subj_id, rdms, sub_dir):
    """Write a single-RDM result to <out_path>/<sub_dir>/<subj_id>.csv."""
    mat = rdms.get_matrices().squeeze()
    _pairs_frame(subj_id, mat).to_csv(op.join(out_path, sub_dir, f'{subj_id}.csv'), index=False)


def _save_time_resolved(subj_id, rdms, sub_dir):
    """Write one block of pair rows per time bin to <out_path>/<sub_dir>/<subj_id>.csv."""
    frames = [_pairs_frame(subj_id, mat, time=t)
              for mat, t in zip(rdms.get_matrices(), rdms.rdm_descriptors['time'])]
    pd.concat(frames).to_csv(op.join(out_path, sub_dir, f'{subj_id}.csv'), index=False)


for s, subj_id in enumerate(all_subj_ids):
    print(f'Participant {subj_id} ({s+1}/{len(all_subj_ids)})')
    # all_epos are already EEG-only, so this in-place pick is a no-op safeguard.
    epo = all_epos[s].pick('eeg')
    epo_labs = all_epo_labs[s]
    # A single posterior-ROI copy reused by every posterior analysis. get_data never
    # modifies the epochs, so no further copies are needed.
    epo_post = epo.copy().pick(post_chs)

    # --- N1 period of interest (ROI and all channels) and exploratory P1 period ---
    data_poi = _sample_dataset(epo_post, epo_labs, *n1_period)
    data_poi_all_chs = _sample_dataset(epo, epo_labs, *n1_period)
    data_p1 = _sample_dataset(epo_post, epo_labs, *p1_period)
    rdms_data_poi = _align_patterns(
        rsatoolbox.rdm.calc_rdm(data_poi, method='correlation', descriptor='stimulus'), chars)
    rdms_data_poi_all_chs = _align_patterns(
        rsatoolbox.rdm.calc_rdm(data_poi_all_chs, method='correlation', descriptor='stimulus'), chars)
    rdms_data_p1 = _align_patterns(
        rsatoolbox.rdm.calc_rdm(data_p1, method='correlation', descriptor='stimulus'), chars)
    # Spearman rho-a between each model RDM and the data RDM.
    jacc_r_poi = rsatoolbox.rdm.compare(jacc_mod, rdms_data_poi, method='rho-a')
    ot_r_poi = rsatoolbox.rdm.compare(ot_mod, rdms_data_poi, method='rho-a')
    # P1 comparisons are kept only for the current participant (not accumulated).
    jacc_r_p1 = rsatoolbox.rdm.compare(jacc_mod, rdms_data_p1, method='rho-a')
    ot_r_p1 = rsatoolbox.rdm.compare(ot_mod, rdms_data_p1, method='rho-a')
    r_comb_ot_poi.append(ot_r_poi)
    r_comb_jacc_poi.append(jacc_r_poi)
    # One CSV per participant and analysis, one row per character pair.
    _save_single(subj_id, rdms_data_poi, 'period_of_interest')
    _save_single(subj_id, rdms_data_poi_all_chs, 'period_of_interest_all_chs')
    _save_single(subj_id, rdms_data_p1, 'p1_period')

    # --- time-resolved (binned) RDMs for the ROI and for all channels ---
    times_binned = np.array_split(epo.times, len(epo.times) // bin_len)
    data = rsatoolbox.data.TemporalDataset(
        epo_post.get_data(picks='all'),
        channel_descriptors={'names': epo_post.ch_names},
        obs_descriptors={'stimulus': epo_labs},
        time_descriptors={'time': epo.times},
    )
    rdms_data = _align_patterns(rsatoolbox.rdm.calc_rdm_movie(
        data, method='correlation', descriptor='stimulus', bins=times_binned), chars)
    jacc_r = rsatoolbox.rdm.compare(jacc_mod, rdms_data, method='rho-a')
    ot_r = rsatoolbox.rdm.compare(ot_mod, rdms_data, method='rho-a')
    complexity_r = rsatoolbox.rdm.compare(complexity_mod, rdms_data, method='rho-a')
    _save_time_resolved(subj_id, rdms_data, 'time_resolved')

    data_all_chs = rsatoolbox.data.TemporalDataset(
        epo.get_data(picks='all'),
        channel_descriptors={'names': epo.ch_names},
        obs_descriptors={'stimulus': epo_labs},
        time_descriptors={'time': epo.times},
    )
    rdms_data_all_chs = _align_patterns(rsatoolbox.rdm.calc_rdm_movie(
        data_all_chs, method='correlation', descriptor='stimulus', bins=times_binned), chars)
    _save_time_resolved(subj_id, rdms_data_all_chs, 'time_resolved_all_chs')

    r_comb_ot.append(ot_r)
    r_comb_jacc.append(jacc_r)
    r_comb_complexity.append(complexity_r)
    rdms_data_list.append(rdms_data)
    rdms_data_poi_list.append(rdms_data_poi)
    rdms_data_p1_list.append(rdms_data_p1)
    rdms_data_list_all_chs.append(rdms_data_all_chs)
    rdms_data_poi_list_all_chs.append(rdms_data_poi_all_chs)
  # Noise ceilings for every RDM family, followed by summary estimates.
# Noise ceiling: average of per-participant leave-one-out cross-validation and
# full-set comparison (stored as lower / upper bound). The 'index' RDM descriptor
# specifies participants.
os.makedirs('noise_ceiling', exist_ok=True)
# boot_noise_ceiling resamples at random. Assuming it draws from NumPy's global RNG
# (confirm for the installed rsatoolbox version), seeding makes the saved ceilings
# reproducible between runs.
nc_seed = 0
np.random.seed(nc_seed)


def _noise_ceiling(rdms):
    """Return the (lower, upper) bootstrap noise ceiling of ``rdms`` under rho-a."""
    return rsatoolbox.inference.noise_ceiling.boot_noise_ceiling(rdms, method='rho-a', rdm_descriptor='index')


def _by_time(rdms_list):
    """Regroup per-participant time-resolved RDMs into one RDMs object per time bin."""
    return [rsatoolbox.rdm.concat([rdms_s.subset('index', i) for rdms_s in rdms_list])
            for i in range(len(times_binned))]


# N1 period of interest (posterior ROI)
rdms_data_poi_concat = rsatoolbox.rdm.concat(rdms_data_poi_list)
nc = _noise_ceiling(rdms_data_poi_concat)
nc_df = pd.DataFrame({'lwr': [nc[0]], 'upr': [nc[1]]})
# P1 period
rdms_data_p1_concat = rsatoolbox.rdm.concat(rdms_data_p1_list)
nc_p1 = _noise_ceiling(rdms_data_p1_concat)
nc_p1_df = pd.DataFrame({'lwr': [nc_p1[0]], 'upr': [nc_p1[1]]})
# N1 period of interest, all channels
rdms_data_poi_all_chs_concat = rsatoolbox.rdm.concat(rdms_data_poi_list_all_chs)
nc_poi_all_chs = _noise_ceiling(rdms_data_poi_all_chs_concat)
nc_all_chs_df = pd.DataFrame({'lwr': [nc_poi_all_chs[0]], 'upr': [nc_poi_all_chs[1]]})
# Time-resolved, posterior ROI
rdms_by_time = _by_time(rdms_data_list)
nc_time = np.array([_noise_ceiling(x) for x in rdms_by_time])
nc_df_time = pd.DataFrame({'time': rdms_data.rdm_descriptors['time'], 'lwr': nc_time[:, 0], 'upr': nc_time[:, 1]})
# Time-resolved, all channels
rdms_all_chs_by_time = _by_time(rdms_data_list_all_chs)
nc_all_chs_time = np.array([_noise_ceiling(x) for x in rdms_all_chs_by_time])
nc_all_chs_df_time = pd.DataFrame({'time': rdms_data_all_chs.rdm_descriptors['time'],
                                   'lwr': nc_all_chs_time[:, 0], 'upr': nc_all_chs_time[:, 1]})

nc_df.to_csv(op.join('noise_ceiling', 'noise_ceiling_poi.csv'), index=False)
nc_p1_df.to_csv(op.join('noise_ceiling', 'noise_ceiling_p1.csv'), index=False)
nc_df_time.to_csv(op.join('noise_ceiling', 'noise_ceiling_time.csv'), index=False)
nc_all_chs_df.to_csv(op.join('noise_ceiling', 'noise_ceiling_poi_all_chs.csv'), index=False)
nc_all_chs_df_time.to_csv(op.join('noise_ceiling', 'noise_ceiling_time_all_chs.csv'), index=False)

# Per-participant, per-time-bin model correlations as arrays.
Rs_ot = np.array(r_comb_ot).squeeze()
Rs_jacc = np.array(r_comb_jacc).squeeze()
Rs_complexity = np.array(r_comb_complexity).squeeze()
print('AVERAGE ESTIMATES')
print(f'Jaccard POI Rho: {np.array(r_comb_jacc_poi).squeeze().mean()}')
print(f'Optimal Transport POI Rho: {np.array(r_comb_ot_poi).squeeze().mean()}')
  # Channel-location figure: posterior ROI in red, remaining EEG channels in black.
# %% quick figures
# Re-read the first retained participant's epochs so the full channel set is available.
# .iloc is positional: all_subj_ids may carry non-contiguous row labels from the
# participants.tsv filter, so label 0 need not exist. The file name follows the
# same pattern used when all epochs were loaded.
first_subj_id = all_subj_ids.iloc[0]
tmp_epo = mne.read_epochs(op.join(epo_path, f'{first_subj_id}-epo.fif'))
# EEG channels that are not part of the posterior ROI.
non_post_chs = [ch for ch, ch_type in zip(tmp_epo.ch_names, tmp_epo.get_channel_types())
                if ch_type == 'eeg' and ch not in post_chs]
# One list of channel indices per colour group: [non-posterior (black), posterior (red)].
ch_groups = [[tmp_epo.ch_names.index(ch) for ch in g] for g in [non_post_chs, post_chs]]
ch_pl = mne.viz.plot_sensors(tmp_epo.info, ch_groups=ch_groups,
                             cmap=mpl.colors.ListedColormap(['black', 'red']),
                             linewidth=0, pointsize=50, show_names=False, show=False)
ch_pl.set_size_inches(2, 2)
# Thicken the lines drawn for the head outline.
for line in ch_pl.axes[0].lines:
    line.set_linewidth(2)
# Export as PDF.
ch_pl.savefig(op.join('fig', 'channels.pdf'), bbox_inches='tight', pad_inches=0)
plt.close(ch_pl)