# -*- coding: utf-8 -*-
# %% general setup
# import libraries
import os
import os.path as op
from string import ascii_lowercase
import re
import numpy as np
# import scipy
import pandas as pd
import mne
import rsatoolbox
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
fig_font_size = 8
plt.rcParams.update({
    "font.family": "Helvetica",
    'font.size': fig_font_size
})
# from tqdm import tqdm
beh_path = op.join('beh', 'data')
eeg_path = 'eeg'
epo_path = 'epo'
# get the list of all subject IDs from the participants file
plist = pd.read_csv(op.join('eeg', 'participants.tsv'), delimiter='\t')
# keep only participants whose recording was not restarted
plist = plist.loc[plist.recording_restarted == 0]
all_subj_ids = plist.participant_id
# the list used to order the similarity matrices
chars = [*ascii_lowercase, 'ä', 'ö', 'ü', 'ß']
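# 26 ASCII letters plus ä, ö, ü, ß give 30 characters, so every RDM below is
# 30 x 30, with 30 * 29 / 2 = 435 unique (lower-triangle) pairs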
# %% import all data
# get all participants' data
all_epos = [mne.read_epochs(os.path.join(epo_path, f'{subj_id}-epo.fif')) for subj_id in all_subj_ids]
# remove non-eeg channels
all_epos = [e.pick(['eeg']) for e in all_epos]
# for each participant get an array with all stimulus IDs
all_epo_labs = [np.array(e.metadata.stimulus) for e in all_epos]
# list of all posterior electrodes
post_chs = ['TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10',
            'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8',
            'PO7', 'PO3', 'POz', 'PO4', 'PO8',
            'PO9', 'O1', 'Oz', 'O2', 'PO10']
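# note: this posterior channel set is the region of interest used for the main
# pattern and RDM analyses below; the '_all_chs' analyses repeat them on every EEG channel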
epos_concat = mne.concatenate_epochs(all_epos)
# %%
# grand-average ERP over the posterior channels, with the N1 window highlighted, for illustration
fig = epos_concat.average().plot(picks = post_chs, time_unit='ms', highlight=(150, 225), show=False, selectable=False)
fig.set_size_inches(2.4, 0.75)
fig.savefig(op.join('fig', 'ERP.svg'))
plt.close()
eg_chars = chars
# plot the patterns for one participant
eg_participant_idx = 1
# for scaling the colour bar
min_avg = np.min([all_epos[eg_participant_idx][c].get_data(picks=post_chs, tmin=.150, tmax=.225).mean(axis=(0,2)) for c in eg_chars])
max_avg = np.max([all_epos[eg_participant_idx][c].get_data(picks=post_chs, tmin=.150, tmax=.225).mean(axis=(0,2)) for c in eg_chars])
abs_max = max([abs(min_avg), abs(max_avg)])
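# abs_max gives a colour scale that is symmetric around 0 V, so positive and
# negative deflections of equal size get equally saturated colours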
n_chs = len(all_epos[0].info['ch_names'])
rdbu_cmap = mpl.colormaps['RdBu_r']
cb_fig, ax = plt.subplots(figsize=(1.7, .075), layout='constrained')
cb_fig.colorbar(mpl.cm.ScalarMappable(
    norm=mpl.colors.Normalize(vmin=-abs_max * 1e6, vmax=abs_max * 1e6),
    cmap=rdbu_cmap),
    cax=ax, orientation='horizontal', label='µV',
    # ticks=[-3,-2,-1,0,1,2,3]
    ticks=[-8,-4,0,4,8]
)
cb_fig.savefig(op.join('fig', 'pattern_examples', 'colorbar.svg'), bbox_inches='tight', pad_inches=0)
plt.close()
for c in eg_chars:
    epo_dat = all_epos[eg_participant_idx][c].get_data(picks='eeg', tmin=.150, tmax=.225)
    epo_avg_per_ch = epo_dat.mean(axis = (0,2))
    # ch_grps = [ch if epos_concat.info['ch_names'][ch] in post_chs
    #            else np.nan
    #            for ch in range(len(epos_concat.info['ch_names']))]
    ch_grps = [[i] for i in range(n_chs)]
    ch_vals = np.array([epo_avg_per_ch[i] for i in range(n_chs)])

    ch_vals_scaled = (ch_vals + abs_max) / (abs_max*2)
    cmap_custom = mpl.colors.ListedColormap([rdbu_cmap(x) for x in ch_vals_scaled])
    # twilight.set_bad('white', 0)
    point_sizes = [125 if epos_concat.info.ch_names[i] in post_chs else 0 for i in range(n_chs)]
    ch_epo_fig = mne.viz.plot_sensors(epos_concat.info, ch_groups=ch_grps, cmap=cmap_custom, pointsize=point_sizes, linewidth=0)
    ch_epo_fig.set_size_inches((2, 2))
    ch_epo_fig.savefig(op.join('fig', 'pattern_examples', f'{c}.svg'), bbox_inches='tight', pad_inches=0)
    plt.close()
# illustrate correlation distance RDMs for all participants
eg_rdms = []
for p in range(len(all_epos)):
    vecs = [all_epos[p][c].get_data(picks='eeg', tmin=.150, tmax=.225).mean(axis = (0,2)) for c in chars]
    eg_rdm = np.zeros((len(chars), len(chars)))
    for i in range(len(chars)):
        for j in range(len(chars)):
            eg_rdm[i, j] = 1 - np.corrcoef(x=vecs[i], y=vecs[j])[0, 1]
    eg_rdms.append(eg_rdm)
rank_eg_rdms = []
for p in range(len(all_epos)):
    eg_rdm = eg_rdms[p]
    eg_rdm_tril = eg_rdm[np.tril_indices(n=eg_rdm.shape[0], k=-1)]
    eg_rdm_tril_rank = eg_rdm_tril.argsort().argsort() + 1
    rank_eg_rdm = np.zeros((len(chars), len(chars)))
    rank_eg_rdm[np.tril_indices(n=eg_rdm.shape[0], k=-1)] = eg_rdm_tril_rank
    rank_eg_rdm[np.triu_indices(n=eg_rdm.shape[0], k=0)] = rank_eg_rdm.T[np.triu_indices(n=eg_rdm.shape[0], k=0)]
    rank_eg_rdm[np.diag_indices(n=rank_eg_rdm.shape[0])] = np.nan
    rank_eg_rdms.append(rank_eg_rdm)
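# rank-transforming the lower triangle (ranks 1-435) puts every participant's
# example RDM on the same scale for plotting below (vmin=1, vmax=435)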
rdm_cmap = 'viridis'
for p in range(len(all_epos)):
    eg_rdm = eg_rdms[p]
    fig, ax = plt.subplots(1, 1, figsize=(.5, .5))
    plt.imshow(eg_rdm, interpolation='none', cmap=rdm_cmap)
    # plt.imshow(eg_rdm, interpolation='none', cmap=rdm_cmap, vmin=0, vmax=2)
    ax.set_xticks([])
    ax.set_yticks([])
    fig.savefig(op.join('fig', 'rdm_examples', f'rdm_{p}.svg'), bbox_inches='tight', pad_inches=0)
    plt.close()
    rank_eg_rdm = rank_eg_rdms[p]
    fig, ax = plt.subplots(1, 1, figsize=(.5, .5))
    plt.imshow(rank_eg_rdm, interpolation='none', cmap=rdm_cmap, vmin=1, vmax=435)
    ax.set_xticks([])
    ax.set_yticks([])
    fig.savefig(op.join('fig', 'rank_rdm_examples', f'rank_rdm_{p}.svg'), bbox_inches='tight', pad_inches=0)
    plt.close()
cb_fig, ax = plt.subplots(figsize=(0.5, .075), layout='constrained')
cb_fig.colorbar(mpl.cm.ScalarMappable(
    norm=mpl.colors.Normalize(vmin=0, vmax=2),
    cmap=rdm_cmap),
    cax=ax, orientation='horizontal', label='1-r',
    # ticks=[-3,-2,-1,0,1,2,3]
    ticks=[]
)
cb_fig.savefig(op.join('fig', 'rdm_examples', 'colorbar.svg'), bbox_inches='tight', pad_inches=0)
plt.close()
cb_fig, ax = plt.subplots(figsize=(0.5, .075), layout='constrained')
cb_fig.colorbar(mpl.cm.ScalarMappable(
    norm=mpl.colors.Normalize(vmin=1, vmax=435),
    cmap=rdm_cmap),
    cax=ax, orientation='horizontal', label='Rank',
    ticks=[]
)
cb_fig.savefig(op.join('fig', 'rank_rdm_examples', 'colorbar.svg'), bbox_inches='tight', pad_inches=0)
plt.close()
# %% RSA toolbox method
r_comb_ot = []
r_comb_jacc = []
r_comb_complexity = []
r_comb_ot_poi = []
r_comb_jacc_poi = []
rdms_data_list = []
rdms_data_poi_list = []
rdms_data_p1_list = []
rdms_data_list_all_chs = []
rdms_data_poi_list_all_chs = []
rdms_data_p1_list_all_chs = []
sim_path = 'stim_sim'
# sim_path_jacc = op.join(sim_path, 'jacc')
# sim_path_ot = op.join(sim_path, 'ot')
sim_path_prereg = op.join(sim_path, 'preregistered')
jacc_dissim_mat = np.load(op.join(sim_path_prereg, 'jacc.npy'))
jacc_vec = jacc_dissim_mat[np.tril_indices(jacc_dissim_mat.shape[0], k=-1)]
jacc_mod = rsatoolbox.rdm.RDMs(
    dissimilarities = jacc_dissim_mat[np.newaxis, :, :],
    dissimilarity_measure = 'jaccard',
    descriptors = {
        'index': np.arange(len(chars)),
        'stimulus': chars
    }
)
ot_dissim_mat = np.load(op.join(sim_path_prereg, 'ot.npy'))
ot_vec = ot_dissim_mat[np.tril_indices(ot_dissim_mat.shape[0], k=-1)]
ot_mod = rsatoolbox.rdm.RDMs(
    dissimilarities = ot_dissim_mat[np.newaxis, :, :],
    dissimilarity_measure = 'optimal transport',
    descriptors = {
        'index': np.arange(len(chars)),
        'stimulus': chars
    }
)
sim_path_complexity = op.join(sim_path, 'complexity')
complexity_dissim_mat = np.load(op.join(sim_path_complexity, 'complexity.npy'))
complexity_vec = complexity_dissim_mat[np.tril_indices(complexity_dissim_mat.shape[0], k=-1)]
complexity_mod = rsatoolbox.rdm.RDMs(
    dissimilarities = complexity_dissim_mat[np.newaxis, :, :],
    dissimilarity_measure = 'complexity',
    descriptors = {
        'index': np.arange(len(chars)),
        'stimulus': chars
    }
)
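# the three model RDMs (Jaccard, optimal transport, complexity) are wrapped as
# rsatoolbox RDMs objects so they can be passed to rsatoolbox.rdm.compare below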
out_path = 'rdm_data'
n1_period = [0.15, 0.225]
p1_period = [0.080, 0.130]
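# analysis windows in seconds: the N1 period of interest (150-225 ms) and the
# exploratory P1 window (80-130 ms)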
for s, subj_id in enumerate(all_subj_ids):
    print(f'Participant {subj_id} ({s+1}/{len(all_subj_ids)})')
    epo = all_epos[s].pick('eeg')
    # epo = all_epos[s]
    epo_labs = all_epo_labs[s]
    # get the similarities for the period of interest
    epo_poi = epo.copy().get_data(picks = post_chs, tmin=n1_period[0], tmax=n1_period[1])
    # reshape into (n_epochs * n_timepoints) rows x n_channels columns
    epo_poi_resh = np.zeros((epo_poi.shape[0] * epo_poi.shape[2], epo_poi.shape[1]))
    epo_labs_resh = np.repeat(epo_labs, epo_poi.shape[2])
    for ch in range(epo_poi.shape[1]):
        # for each channel, get data from all timepoints for epoch 1, then for epoch 2, epoch 3, etc.
        epo_poi_resh[:, ch] = epo_poi[:, ch, :].flatten()
    data_poi = rsatoolbox.data.Dataset(
        epo_poi_resh,
        channel_descriptors={'names': epo.copy().pick(post_chs).info['ch_names']},
        obs_descriptors={'stimulus': epo_labs_resh}
    )
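    # each row of data_poi is one (trial, time point) observation; calc_rdm with
    # descriptor='stimulus' should average all rows sharing a stimulus label before
    # computing the correlation-distance RDM (an assumption about rsatoolbox's behaviour)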
    # get the equivalent for all EEG channels (not just the posterior region)
    epo_poi_all_chs = epo.copy().get_data(picks = 'eeg', tmin=n1_period[0], tmax=n1_period[1])
    epo_poi_all_chs_resh = np.zeros((epo_poi_all_chs.shape[0] * epo_poi_all_chs.shape[2], epo_poi_all_chs.shape[1]))
    for ch in range(epo_poi_all_chs.shape[1]):
        epo_poi_all_chs_resh[:, ch] = epo_poi_all_chs[:, ch, :].flatten()
    data_poi_all_chs = rsatoolbox.data.Dataset(
        epo_poi_all_chs_resh,
        channel_descriptors={'names': epo.info['ch_names']},
        obs_descriptors={'stimulus': epo_labs_resh}  # can reuse this, as the number of trials and time points is unchanged
    )
    # get the similarities for the exploratory P1 period
    epo_p1 = epo.copy().get_data(picks = post_chs, tmin=p1_period[0], tmax=p1_period[1])
    # reshape into (n_epochs * n_timepoints) rows x n_channels columns
    epo_p1_resh = np.zeros((epo_p1.shape[0] * epo_p1.shape[2], epo_p1.shape[1]))
    epo_labs_resh_p1 = np.repeat(epo_labs, epo_p1.shape[2])
    for ch in range(epo_p1.shape[1]):
        # for each channel, get data from all timepoints for epoch 1, then for epoch 2, epoch 3, etc.
        epo_p1_resh[:, ch] = epo_p1[:, ch, :].flatten()
    data_p1 = rsatoolbox.data.Dataset(
        epo_p1_resh,
        channel_descriptors={'names': epo.copy().pick(post_chs).info['ch_names']},
        obs_descriptors={'stimulus': epo_labs_resh_p1}
    )
    # calculate RDMs and model-data correlations
    rdms_data_poi = rsatoolbox.rdm.calc_rdm(data_poi, method='correlation', descriptor='stimulus')
    rdms_data_poi_all_chs = rsatoolbox.rdm.calc_rdm(data_poi_all_chs, method='correlation', descriptor='stimulus')
    rdms_data_p1 = rsatoolbox.rdm.calc_rdm(data_p1, method='correlation', descriptor='stimulus')
    jacc_r_poi = rsatoolbox.rdm.compare(jacc_mod, rdms_data_poi, method='rho-a')
    ot_r_poi = rsatoolbox.rdm.compare(ot_mod, rdms_data_poi, method='rho-a')
    jacc_r_p1 = rsatoolbox.rdm.compare(jacc_mod, rdms_data_p1, method='rho-a')
    ot_r_p1 = rsatoolbox.rdm.compare(ot_mod, rdms_data_p1, method='rho-a')
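    # rsatoolbox.rdm.compare returns one value per data RDM; 'rho-a' is a rank
    # correlation (a Spearman's rho variant), so the comparisons are insensitive
    # to the overall scale of the dissimilarities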
    r_comb_ot_poi.append(ot_r_poi)
    r_comb_jacc_poi.append(jacc_r_poi)
    # save to CSV for later modelling
    mat_poi = rdms_data_poi.get_matrices().squeeze()
    vec_poi = mat_poi[np.tril_indices(mat_poi.shape[0], k=-1)]
    char1_ids = np.array(chars)[np.tril_indices(jacc_dissim_mat.shape[0], k=-1)[0]]
    char2_ids = np.array(chars)[np.tril_indices(jacc_dissim_mat.shape[0], k=-1)[1]]
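    # char1_ids/char2_ids label each lower-triangle cell with its character pair,
    # in the same order as the vectorised dissimilarities above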
    poi_df = pd.DataFrame({
        'subj_id': subj_id,
        'char1': char1_ids,
        'char2': char2_ids,
        # 'ot_dist': ot_vec,
        # 'jacc_dissim': jacc_vec,
        'eeg_dissim': vec_poi
    })
    poi_data_filepath = op.join(out_path, 'period_of_interest', f'{subj_id}.csv')
    poi_df.to_csv(poi_data_filepath, index=False)
    # also save the results for all channels
    mat_poi_all_chs = rdms_data_poi_all_chs.get_matrices().squeeze()
    vec_poi_all_chs = mat_poi_all_chs[np.tril_indices(mat_poi_all_chs.shape[0], k=-1)]
    char1_ids = np.array(chars)[np.tril_indices(jacc_dissim_mat.shape[0], k=-1)[0]]
    char2_ids = np.array(chars)[np.tril_indices(jacc_dissim_mat.shape[0], k=-1)[1]]
    poi_all_chs_df = pd.DataFrame({
        'subj_id': subj_id,
        'char1': char1_ids,
        'char2': char2_ids,
        # 'ot_dist': ot_vec,
        # 'jacc_dissim': jacc_vec,
        'eeg_dissim': vec_poi_all_chs
    })
    poi_data_all_chs_filepath = op.join(out_path, 'period_of_interest_all_chs', f'{subj_id}.csv')
    poi_all_chs_df.to_csv(poi_data_all_chs_filepath, index=False)
    # also save the P1 results
    mat_p1 = rdms_data_p1.get_matrices().squeeze()
    vec_p1 = mat_p1[np.tril_indices(mat_p1.shape[0], k=-1)]
    char1_ids = np.array(chars)[np.tril_indices(jacc_dissim_mat.shape[0], k=-1)[0]]
    char2_ids = np.array(chars)[np.tril_indices(jacc_dissim_mat.shape[0], k=-1)[1]]
    p1_df = pd.DataFrame({
        'subj_id': subj_id,
        'char1': char1_ids,
        'char2': char2_ids,
        # 'ot_dist': ot_vec,
        # 'jacc_dissim': jacc_vec,
        'eeg_dissim': vec_p1
    })
    p1_data_filepath = op.join(out_path, 'p1_period', f'{subj_id}.csv')
    p1_df.to_csv(p1_data_filepath, index=False)

    # get the time resolved results
    data = rsatoolbox.data.TemporalDataset(
        epo.copy().get_data(picks = post_chs),
        channel_descriptors={'names': post_chs},
        obs_descriptors={'stimulus': epo_labs},
        time_descriptors={'time': epo.times}
    )
    # bins with bin_len time points per bin (the first bin, in the baseline, gets an extra time point because the number of samples is not an exact multiple of bin_len)
    # at a 1000 Hz sampling rate, samples and ms are equivalent, so bins are bin_len ms wide
    # bin_len = 4
    bin_len = 10
    times_binned = np.array_split(epo.times, len(epo.times)/bin_len)
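    # e.g. with bin_len = 10 and a 1000 Hz sampling rate each bin spans ~10 ms;
    # np.array_split assigns any leftover samples to the earliest bins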
    # rdms_data = rsatoolbox.rdm.calc_rdm_movie(data, method='correlation', descriptor='stimulus', bins=times_binned, noise=noise_prec_shrink)
    rdms_data = rsatoolbox.rdm.calc_rdm_movie(data, method='correlation', descriptor='stimulus', bins=times_binned)

    jacc_r = rsatoolbox.rdm.compare(jacc_mod, rdms_data, method='rho-a')

    ot_r = rsatoolbox.rdm.compare(ot_mod, rdms_data, method='rho-a')
    complexity_r = rsatoolbox.rdm.compare(complexity_mod, rdms_data, method='rho-a')
    # save to CSV for later modelling
    mats_per_tp = list(rdms_data.get_matrices())
    vecs_per_tp = [x[np.tril_indices(x.shape[0], k=-1)] for x in mats_per_tp]
    tp_dfs = [pd.DataFrame({
        'subj_id': subj_id,
        'char1': char1_ids,
        'char2': char2_ids,
        'time': t,
        # 'ot_dist': ot_vec,
        # 'jacc_dissim': jacc_vec,
        'eeg_dissim': vecs_per_tp[i]
    }) for i, t in enumerate(rdms_data.rdm_descriptors['time'])]
    subj_df = pd.concat(tp_dfs)
    rdm_data_filepath = op.join(out_path, 'time_resolved', f'{subj_id}.csv')
    subj_df.to_csv(rdm_data_filepath, index=False)
    # also save time-resolved data for all channels
    data_all_chs = rsatoolbox.data.TemporalDataset(
        epo.copy().get_data(picks = 'eeg'),
        channel_descriptors={'names': epo.info['ch_names']},
        obs_descriptors={'stimulus': epo_labs},
        time_descriptors={'time': epo.times}
    )
    rdms_data_all_chs = rsatoolbox.rdm.calc_rdm_movie(data_all_chs, method='correlation', descriptor='stimulus', bins=times_binned)
    mats_per_tp_all_chs = list(rdms_data_all_chs.get_matrices())
    vecs_per_tp_all_chs = [x[np.tril_indices(x.shape[0], k=-1)] for x in mats_per_tp_all_chs]
    tp_dfs_all_chs = [pd.DataFrame({
        'subj_id': subj_id,
        'char1': char1_ids,
        'char2': char2_ids,
        'time': t,
        'eeg_dissim': vecs_per_tp_all_chs[i]
    }) for i, t in enumerate(rdms_data_all_chs.rdm_descriptors['time'])]
    subj_df_all_chs = pd.concat(tp_dfs_all_chs)
    rdm_data_all_chs_filepath = op.join(out_path, 'time_resolved_all_chs', f'{subj_id}.csv')
    subj_df_all_chs.to_csv(rdm_data_all_chs_filepath, index=False)
    # # get Rs
    # r = [scipy.stats.spearmanr(jacc_vec, x).statistic for x in vecs_per_tp]
    # collect the model-data correlations (Spearman's rho variant, 'rho-a') and the RDMs
    r_comb_ot.append(ot_r)
    r_comb_jacc.append(jacc_r)
    r_comb_complexity.append(complexity_r)
    rdms_data_list.append(rdms_data)
    rdms_data_poi_list.append(rdms_data_poi)
    rdms_data_p1_list.append(rdms_data_p1)
    rdms_data_list_all_chs.append(rdms_data_all_chs)
    rdms_data_poi_list_all_chs.append(rdms_data_poi_all_chs)
# calculate the noise ceiling: the lower bound compares each participant's RDM with the average of the others (leave-one-out), the upper bound with the full-set average
rdms_data_poi_concat = rsatoolbox.rdm.concat(rdms_data_poi_list)
nc = rsatoolbox.inference.noise_ceiling.boot_noise_ceiling(rdms_data_poi_concat, method='rho-a', rdm_descriptor='index')  # the 'index' descriptor distinguishes participants
nc_df = pd.DataFrame({'lwr': [nc[0]], 'upr': [nc[1]]})
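# nc is a (lower, upper) pair; the same noise-ceiling procedure is repeated below
# for the P1 window, the all-channel RDMs, and each time bin separately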
# calculate the P1 noise ceiling
rdms_data_p1_concat = rsatoolbox.rdm.concat(rdms_data_p1_list)
nc_p1 = rsatoolbox.inference.noise_ceiling.boot_noise_ceiling(rdms_data_p1_concat, method='rho-a', rdm_descriptor='index')  # the 'index' descriptor distinguishes participants
nc_p1_df = pd.DataFrame({'lwr': [nc_p1[0]], 'upr': [nc_p1[1]]})
# calculate the period-of-interest noise ceiling for all channels
rdms_data_poi_all_chs_concat = rsatoolbox.rdm.concat(rdms_data_poi_list_all_chs)
nc_poi_all_chs = rsatoolbox.inference.noise_ceiling.boot_noise_ceiling(rdms_data_poi_all_chs_concat, method='rho-a', rdm_descriptor='index')  # the 'index' descriptor distinguishes participants
nc_all_chs_df = pd.DataFrame({'lwr': [nc_poi_all_chs[0]], 'upr': [nc_poi_all_chs[1]]})
rdms_by_time = [rsatoolbox.rdm.concat([rdms_s.subset('index', i)
                                       for rdms_s in rdms_data_list])
                for i in range(len(times_binned))]
nc_time = np.array([rsatoolbox.inference.noise_ceiling.boot_noise_ceiling(x, method='rho-a', rdm_descriptor='index') for x in rdms_by_time])
nc_df_time = pd.DataFrame({'time': rdms_data.rdm_descriptors['time'], 'lwr': nc_time[:, 0], 'upr': nc_time[:, 1]})
rdms_all_chs_by_time = [rsatoolbox.rdm.concat([rdms_s.subset('index', i)
                                               for rdms_s in rdms_data_list_all_chs])
                        for i in range(len(times_binned))]
nc_all_chs_time = np.array([rsatoolbox.inference.noise_ceiling.boot_noise_ceiling(x, method='rho-a', rdm_descriptor='index') for x in rdms_all_chs_by_time])
nc_all_chs_df_time = pd.DataFrame({'time': rdms_data_all_chs.rdm_descriptors['time'], 'lwr': nc_all_chs_time[:, 0], 'upr': nc_all_chs_time[:, 1]})
nc_df.to_csv(op.join('noise_ceiling', 'noise_ceiling_poi.csv'), index=False)
nc_p1_df.to_csv(op.join('noise_ceiling', 'noise_ceiling_p1.csv'), index=False)
nc_df_time.to_csv(op.join('noise_ceiling', 'noise_ceiling_time.csv'), index=False)
nc_all_chs_df.to_csv(op.join('noise_ceiling', 'noise_ceiling_poi_all_chs.csv'), index=False)
nc_all_chs_df_time.to_csv(op.join('noise_ceiling', 'noise_ceiling_time_all_chs.csv'), index=False)
# plt.plot(np.array([np.mean(x) for x in times_binned]), nc_time)
# plt.fill_between(np.array([np.mean(x) for x in times_binned]), nc_time[:, 0], nc_time[:, 1], facecolor='lightgrey')
Rs_ot = np.array(r_comb_ot).squeeze()
Rs_jacc = np.array(r_comb_jacc).squeeze()
Rs_complexity = np.array(r_comb_complexity).squeeze()
print('AVERAGE ESTIMATES')
print(f'Jaccard POI Rho: {np.array(r_comb_jacc_poi).squeeze().mean()}')
print(f'Optimal Transport POI Rho: {np.array(r_comb_ot_poi).squeeze().mean()}')
# %%
# quick figures
# save images of the channel locations
# load the first epochs file
tmp_epo = mne.read_epochs(os.path.join(epo_path, f'{str(all_subj_ids[0]).zfill(3)}-epo.fif'))
# get a list of all channels not in the posterior ROI (kind 2 is the FIFF code for EEG channels)
non_post_chs = [x['ch_name'] for x in tmp_epo.info['chs'] if (~np.isin(x['ch_name'], post_chs)) & (x['kind']==2)]
# get a list of lists containing indices for the plot colours
ch_groups = [[tmp_epo.ch_names.index(ch) for ch in g] for g in [non_post_chs, post_chs]]
# plot the ROI in red and all other channels in black
ch_pl = mne.viz.plot_sensors(tmp_epo.info, ch_groups=ch_groups, cmap=mpl.colors.ListedColormap(['black', 'red']), linewidth=0, pointsize=50, show_names=False, show=False)
# set the figure size
ch_pl.set_size_inches(2, 2)
# edit the width of the lines for the head
for line in ch_pl.axes[0].lines:
    line.set_linewidth(2)
# export as pdf
ch_pl.savefig(os.path.join('fig', 'channels.pdf'), bbox_inches='tight', pad_inches=0)
plt.close()