# -*- coding: utf-8 -*-
"""EEG representational similarity analysis (RSA) for single characters.

Loads every included participant's preprocessed epochs and saves illustrative
figures (grand-average ERP, per-character scalp patterns, example RDMs and
their colour bars). The RSA section builds correlation-distance RDMs from the
EEG, compares them with the model RDMs stored under ``stim_sim``, exports the
dissimilarities to CSV and estimates noise ceilings.
"""
# %% general setup
import os
import os.path as op
from string import ascii_lowercase
import re

import numpy as np
import pandas as pd
import mne
import rsatoolbox
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

fig_font_size = 8
plt.rcParams.update({
    "font.family": "Helvetica",
    'font.size': fig_font_size,
})

beh_path = op.join('beh', 'data')
eeg_path = 'eeg'
epo_path = 'epo'

# participant IDs come from participants.tsv; participants whose recording was
# restarted are excluded from the analysis
plist = pd.read_csv(op.join('eeg', 'participants.tsv'), delimiter='\t')
plist = plist.loc[plist.recording_restarted == 0]
all_subj_ids = plist.participant_id

# the order of rows/columns used for the similarity matrices
chars = [*ascii_lowercase, 'ä', 'ö', 'ü', 'ß']
n_chars = len(chars)

# %% import all data
all_epos = [mne.read_epochs(os.path.join(epo_path, f'{subj_id}-epo.fif'))
            for subj_id in all_subj_ids]
# keep only EEG channels (Epochs.pick works in place and returns the object)
all_epos = [e.pick(['eeg']) for e in all_epos]
# for each participant, the stimulus label of every epoch
all_epo_labs = [np.array(e.metadata.stimulus) for e in all_epos]

# posterior electrodes forming the region of interest
post_chs = ['TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10',
            'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8',
            'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'PO9', 'O1', 'Oz', 'O2', 'PO10']

epos_concat = mne.concatenate_epochs(all_epos)


# %% illustrative figures

def _save_colorbar(path, width, vmin, vmax, cmap, label, ticks):
    """Save a stand-alone horizontal colour bar, ``width`` inches wide, to ``path``."""
    cb_fig, cb_ax = plt.subplots(figsize=(width, .075), layout='constrained')
    cb_fig.colorbar(
        mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap),
        cax=cb_ax, orientation='horizontal', label=label, ticks=ticks)
    cb_fig.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close(cb_fig)


def _save_rdm_image(mat, path, **imshow_kwargs):
    """Save a 0.5 x 0.5 inch, tick-free image of the matrix ``mat`` to ``path``."""
    rdm_fig, rdm_ax = plt.subplots(1, 1, figsize=(.5, .5))
    rdm_ax.imshow(mat, interpolation='none', **imshow_kwargs)
    rdm_ax.set_xticks([])
    rdm_ax.set_yticks([])
    rdm_fig.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close(rdm_fig)


def _window_mean(epochs, picks):
    """Mean over epochs and over the 150-225 ms samples, one value per picked channel."""
    return epochs.get_data(picks=picks, tmin=.150, tmax=.225).mean(axis=(0, 2))


# grand-average ERP over the posterior channels with 150-225 ms highlighted
fig = epos_concat.average().plot(picks=post_chs, time_unit='ms', highlight=(150, 225),
                                 show=False, selectable=False)
fig.set_size_inches(2.4, 0.75)
fig.savefig(op.join('fig', 'ERP.svg'))
plt.close(fig)

# scalp patterns of one example participant
eg_chars = chars
eg_participant_idx = 1
eg_epos = all_epos[eg_participant_idx]

# colour scale symmetric around zero, set by the most extreme posterior-channel
# mean (in V; the colour bar below shows it in µV)
post_means = [_window_mean(eg_epos[c], post_chs) for c in eg_chars]  # computed once
min_avg = np.min(post_means)
max_avg = np.max(post_means)
abs_max = max(abs(min_avg), abs(max_avg))

n_chs = len(all_epos[0].info['ch_names'])
rdbu_cmap = mpl.colormaps['RdBu_r']
_save_colorbar(op.join('fig', 'pattern_examples', 'colorbar.svg'), 1.7,
               -abs_max * 1e6, abs_max * 1e6, rdbu_cmap, 'µV', [-8, -4, 0, 4, 8])

# one group per channel so each sensor gets its own colour; only posterior
# sensors are drawn (point size 0 hides the others)
ch_grps = [[i] for i in range(n_chs)]
point_sizes = [125 if epos_concat.info.ch_names[i] in post_chs else 0 for i in range(n_chs)]
for c in eg_chars:
    ch_vals = _window_mean(eg_epos[c], 'eeg')
    # map amplitudes onto [0, 1] of the symmetric scale; values beyond +/-abs_max
    # (possible for non-posterior channels) get the colormap's end colours
    ch_vals_scaled = (ch_vals + abs_max) / (abs_max * 2)
    cmap_custom = mpl.colors.ListedColormap([rdbu_cmap(x) for x in ch_vals_scaled])
    # show=False like the other figures here; the default would display every figure
    ch_epo_fig = mne.viz.plot_sensors(epos_concat.info, ch_groups=ch_grps, cmap=cmap_custom,
                                      pointsize=point_sizes, linewidth=0, show=False)
    ch_epo_fig.set_size_inches((2, 2))
    ch_epo_fig.savefig(op.join('fig', 'pattern_examples', f'{c}.svg'),
                       bbox_inches='tight', pad_inches=0)
    plt.close(ch_epo_fig)

# correlation-distance (1 - Pearson r) RDMs of the 150-225 ms all-channel
# scalp patterns, one per participant, rows/columns in `chars` order
eg_rdms = []
for p_epos in all_epos:
    vecs = [_window_mean(p_epos[c], 'eeg') for c in chars]
    eg_rdms.append(np.array([[1 - np.corrcoef(x=vi, y=vj)[0, 1] for vj in vecs]
                             for vi in vecs]))

# rank-transformed RDMs: strict-lower-triangle distances ranked 1..n_pairs,
# mirrored into the upper triangle, diagonal blanked with NaN
tril_idx = np.tril_indices(n=n_chars, k=-1)
triu_idx = np.triu_indices(n=n_chars, k=0)
n_pairs = len(tril_idx[0])  # number of character pairs (435 for 30 characters)
rank_eg_rdms = []
for eg_rdm in eg_rdms:
    rank_eg_rdm = np.zeros((n_chars, n_chars))
    rank_eg_rdm[tril_idx] = eg_rdm[tril_idx].argsort().argsort() + 1
    rank_eg_rdm[triu_idx] = rank_eg_rdm.T[triu_idx]
    rank_eg_rdm[np.diag_indices(n=n_chars)] = np.nan
    rank_eg_rdms.append(rank_eg_rdm)

rdm_cmap = 'viridis'
for p, (eg_rdm, rank_eg_rdm) in enumerate(zip(eg_rdms, rank_eg_rdms)):
    _save_rdm_image(eg_rdm, op.join('fig', 'rdm_examples', f'rdm_{p}.svg'), cmap=rdm_cmap)
    _save_rdm_image(rank_eg_rdm, op.join('fig', 'rank_rdm_examples', f'rank_rdm_{p}.svg'),
                    cmap=rdm_cmap, vmin=1, vmax=n_pairs)

_save_colorbar(op.join('fig', 'rdm_examples', 'colorbar.svg'), 0.5,
               0, 2, rdm_cmap, '1-r', [])
_save_colorbar(op.join('fig', 'rank_rdm_examples', 'colorbar.svg'), 0.5,
               1, n_pairs, rdm_cmap, 'Rank', [])

# %% RSA toolbox method
# per-participant model-comparison results, filled in the analysis loop
r_comb_ot = []
r_comb_jacc = []
r_comb_complexity = []
r_comb_ot_poi = []
r_comb_jacc_poi = []
rdms_data_list = []
rdms_data_poi_list = []
rdms_data_p1_list = []
rdms_data_list_all_chs = []
rdms_data_poi_list_all_chs = []
rdms_data_p1_list_all_chs = []  # not filled: no all-channel P1 RDMs are computed

sim_path = 'stim_sim'
sim_path_prereg = op.join(sim_path, 'preregistered')
sim_path_complexity = op.join(sim_path, 'complexity')


def _load_model_rdm(path, measure):
    """Load a square model dissimilarity matrix and wrap it as an rsatoolbox RDM.

    The matrix rows/columns are taken to follow ``chars`` (the order the
    similarity matrices were built in). Returns the matrix, its strict lower
    triangle as a vector, and the ``rsatoolbox.rdm.RDMs`` object.
    """
    mat = np.load(path)
    vec = mat[np.tril_indices(mat.shape[0], k=-1)]
    mod = rsatoolbox.rdm.RDMs(
        dissimilarities=mat[np.newaxis, :, :],
        dissimilarity_measure=measure,
        descriptors={'index': np.arange(len(chars)), 'stimulus': chars},
    )
    return mat, vec, mod


jacc_dissim_mat, jacc_vec, jacc_mod = _load_model_rdm(
    op.join(sim_path_prereg, 'jacc.npy'), 'jaccard')
ot_dissim_mat, ot_vec, ot_mod = _load_model_rdm(
    op.join(sim_path_prereg, 'ot.npy'), 'optimal transport')
complexity_dissim_mat, complexity_vec, complexity_mod = _load_model_rdm(
    op.join(sim_path_complexity, 'complexity.npy'), 'complexity')

out_path = 'rdm_data'
n1_period = [0.15, 0.225]
p1_period = [0.080, 0.130]
# samples per bin for the time-resolved RDMs (bins are 10 ms wide if the data
# are sampled at 1000 Hz -- NOTE(review): confirm sampling rate)
bin_len = 10

# strict lower-triangle indices of a chars x chars RDM and the character pair
# of each entry; EEG RDMs are aligned to `chars` before these labels are used
tril_idx = np.tril_indices(jacc_dissim_mat.shape[0], k=-1)
char1_ids = np.array(chars)[tril_idx[0]]
char2_ids = np.array(chars)[tril_idx[1]]


def _align_to_chars(rdms):
    """Reorder the patterns of an EEG ``RDMs`` object in place to follow ``chars``.

    rsatoolbox sets the pattern order of a data RDM from the 'stimulus' values
    it finds, not from ``chars`` (a code-point sort, for instance, puts 'ß'
    before 'ä'). The model RDMs and the exported character labels are in
    ``chars`` order, so EEG RDMs must be aligned before comparison or export.
    Returns ``rdms``.

    Raises:
        ValueError: if the RDM's stimuli are not exactly the characters in ``chars``.
    """
    stim = [str(x) for x in np.asarray(rdms.pattern_descriptors['stimulus'])]
    if sorted(stim) != sorted(chars):
        raise ValueError(f'EEG RDM stimuli {stim} do not match chars {chars}')
    rdms.reorder(np.array([stim.index(c) for c in chars]))
    return rdms


def _period_dataset(epochs, labels, picks, ch_names, period):
    """Dataset with one observation per (epoch, time point) inside ``period``.

    ``period`` is (tmin, tmax) in seconds. Each observation is the vector of the
    picked channels at one time point, labelled with its epoch's stimulus.
    """
    dat = epochs.get_data(picks=picks, tmin=period[0], tmax=period[1])
    n_epochs, n_ch, n_times = dat.shape
    # (epochs, channels, times) -> (epochs * times, channels): rows run over all
    # time points of the first epoch, then of the second epoch, and so on
    obs = dat.transpose(0, 2, 1).reshape(n_epochs * n_times, n_ch)
    return rsatoolbox.data.Dataset(
        obs,
        channel_descriptors={'names': ch_names},
        obs_descriptors={'stimulus': np.repeat(labels, n_times)},
    )


def _pairs_df(subj_id, matrix, **extra_cols):
    """Long-format DataFrame of one chars x chars RDM's strict-lower-triangle entries."""
    return pd.DataFrame({
        'subj_id': subj_id,
        'char1': char1_ids,
        'char2': char2_ids,
        **extra_cols,
        'eeg_dissim': matrix[tril_idx],
    })


for s, subj_id in enumerate(all_subj_ids):
    print(f'Participant {subj_id} ({s+1}/{len(all_subj_ids)})')
    epo = all_epos[s].pick('eeg')
    epo_labs = all_epo_labs[s]
    post_names = epo.copy().pick(post_chs).info['ch_names']

    # N1 period of interest (posterior and all channels) and exploratory P1 period
    data_poi = _period_dataset(epo, epo_labs, post_chs, post_names, n1_period)
    data_poi_all_chs = _period_dataset(epo, epo_labs, 'eeg', epo.info['ch_names'], n1_period)
    data_p1 = _period_dataset(epo, epo_labs, post_chs, post_names, p1_period)

    rdms_data_poi = _align_to_chars(
        rsatoolbox.rdm.calc_rdm(data_poi, method='correlation', descriptor='stimulus'))
    rdms_data_poi_all_chs = _align_to_chars(
        rsatoolbox.rdm.calc_rdm(data_poi_all_chs, method='correlation', descriptor='stimulus'))
    rdms_data_p1 = _align_to_chars(
        rsatoolbox.rdm.calc_rdm(data_p1, method='correlation', descriptor='stimulus'))

    jacc_r_poi = rsatoolbox.rdm.compare(jacc_mod, rdms_data_poi, method='rho-a')
    ot_r_poi = rsatoolbox.rdm.compare(ot_mod, rdms_data_poi, method='rho-a')
    jacc_r_p1 = rsatoolbox.rdm.compare(jacc_mod, rdms_data_p1, method='rho-a')
    ot_r_p1 = rsatoolbox.rdm.compare(ot_mod, rdms_data_p1, method='rho-a')
    r_comb_ot_poi.append(ot_r_poi)
    r_comb_jacc_poi.append(jacc_r_poi)

    # per-participant CSVs of the window RDMs, for later modelling
    for window_rdms, folder in [(rdms_data_poi, 'period_of_interest'),
                                (rdms_data_poi_all_chs, 'period_of_interest_all_chs'),
                                (rdms_data_p1, 'p1_period')]:
        _pairs_df(subj_id, window_rdms.get_matrices()[0]).to_csv(
            op.join(out_path, folder, f'{subj_id}.csv'), index=False)

    # time-resolved RDMs; np.array_split gives the first len(times) % n_bins
    # bins one extra sample
    times_binned = np.array_split(epo.times, len(epo.times) // bin_len)
    data = rsatoolbox.data.TemporalDataset(
        epo.get_data(picks=post_chs),
        channel_descriptors={'names': post_chs},
        obs_descriptors={'stimulus': epo_labs},
        time_descriptors={'time': epo.times},
    )
    rdms_data = _align_to_chars(rsatoolbox.rdm.calc_rdm_movie(
        data, method='correlation', descriptor='stimulus', bins=times_binned))
    jacc_r = rsatoolbox.rdm.compare(jacc_mod, rdms_data, method='rho-a')
    ot_r = rsatoolbox.rdm.compare(ot_mod, rdms_data, method='rho-a')
    complexity_r = rsatoolbox.rdm.compare(complexity_mod, rdms_data, method='rho-a')

    data_all_chs = rsatoolbox.data.TemporalDataset(
        epo.get_data(picks='eeg'),
        channel_descriptors={'names': epo.info['ch_names']},
        obs_descriptors={'stimulus': epo_labs},
        time_descriptors={'time': epo.times},
    )
    rdms_data_all_chs = _align_to_chars(rsatoolbox.rdm.calc_rdm_movie(
        data_all_chs, method='correlation', descriptor='stimulus', bins=times_binned))

    for movie_rdms, folder in [(rdms_data, 'time_resolved'),
                               (rdms_data_all_chs, 'time_resolved_all_chs')]:
        subj_df = pd.concat([
            _pairs_df(subj_id, mat, time=t)
            for mat, t in zip(movie_rdms.get_matrices(), movie_rdms.rdm_descriptors['time'])])
        subj_df.to_csv(op.join(out_path, folder, f'{subj_id}.csv'), index=False)

    r_comb_ot.append(ot_r)
    r_comb_jacc.append(jacc_r)
    r_comb_complexity.append(complexity_r)
    rdms_data_list.append(rdms_data)
    rdms_data_poi_list.append(rdms_data_poi)
    rdms_data_p1_list.append(rdms_data_p1)
    rdms_data_list_all_chs.append(rdms_data_all_chs)
    rdms_data_poi_list_all_chs.append(rdms_data_poi_all_chs)


def _noise_ceiling_df(rdms_list):
    """Noise ceiling (lower, upper) of one RDM per participant, as a one-row DataFrame."""
    # the 'index' rdm descriptor is expected to identify participants after concat
    nc = rsatoolbox.inference.noise_ceiling.boot_noise_ceiling(
        rsatoolbox.rdm.concat(rdms_list), method='rho-a', rdm_descriptor='index')
    return pd.DataFrame({'lwr': [nc[0]], 'upr': [nc[1]]})


def _noise_ceiling_time_df(rdms_list, times):
    """Noise ceiling per time bin; ``rdms_list`` holds one movie RDMs per participant."""
    nc_t = np.array([
        rsatoolbox.inference.noise_ceiling.boot_noise_ceiling(
            rsatoolbox.rdm.concat([r.subset('index', i) for r in rdms_list]),
            method='rho-a', rdm_descriptor='index')
        for i in range(len(times))])
    return pd.DataFrame({'time': times, 'lwr': nc_t[:, 0], 'upr': nc_t[:, 1]})


# same computation order as before (POI, P1, POI all channels, time, time all channels)
nc_df = _noise_ceiling_df(rdms_data_poi_list)
nc_p1_df = _noise_ceiling_df(rdms_data_p1_list)
nc_all_chs_df = _noise_ceiling_df(rdms_data_poi_list_all_chs)
nc_df_time = _noise_ceiling_time_df(rdms_data_list, rdms_data.rdm_descriptors['time'])
nc_all_chs_df_time = _noise_ceiling_time_df(rdms_data_list_all_chs,
                                            rdms_data_all_chs.rdm_descriptors['time'])

nc_df.to_csv(op.join('noise_ceiling', 'noise_ceiling_poi.csv'), index=False)
nc_p1_df.to_csv(op.join('noise_ceiling', 'noise_ceiling_p1.csv'), index=False)
nc_df_time.to_csv(op.join('noise_ceiling', 'noise_ceiling_time.csv'), index=False)
nc_all_chs_df.to_csv(op.join('noise_ceiling', 'noise_ceiling_poi_all_chs.csv'), index=False)
nc_all_chs_df_time.to_csv(op.join('noise_ceiling', 'noise_ceiling_time_all_chs.csv'),
                          index=False)

Rs_ot = np.array(r_comb_ot).squeeze()
Rs_jacc = np.array(r_comb_jacc).squeeze()
Rs_complexity = np.array(r_comb_complexity).squeeze()

print('AVERAGE ESTIMATES')
print(f'Jaccard POI Rho: {np.array(r_comb_jacc_poi).squeeze().mean()}')
print(f'Optimal Transport POI Rho: {np.array(r_comb_ot_poi).squeeze().mean()}')

# %% quick figures
# sensor layout: posterior ROI in red, all other EEG channels in black.
# Reuse the already-loaded first participant's epochs rather than re-reading a
# file: all_subj_ids[0] is label-based on a filtered Series (KeyError when row 0
# was excluded) and the zero-padded file name differed from the one loaded above.
tmp_epo = all_epos[0]
non_post_chs = [ch['ch_name'] for ch in tmp_epo.info['chs']
                if ch['ch_name'] not in post_chs and ch['kind'] == 2]  # 2 == FIFFV_EEG_CH
ch_groups = [[tmp_epo.ch_names.index(ch) for ch in group]
             for group in [non_post_chs, post_chs]]
ch_pl = mne.viz.plot_sensors(tmp_epo.info, ch_groups=ch_groups,
                             cmap=mpl.colors.ListedColormap(['black', 'red']),
                             linewidth=0, pointsize=50, show_names=False, show=False)
ch_pl.set_size_inches(2, 2)
# thicker head outline
for line in ch_pl.axes[0].lines:
    line.set_linewidth(2)
# export as pdf
ch_pl.savefig(os.path.join('fig', 'channels.pdf'), bbox_inches='tight', pad_inches=0)
plt.close(ch_pl)