123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354 |
- # %%
- # import modules
- import os.path as op
- import glob
- import pandas as pd
- import numpy as np
- import re
- import mne
- from mne_icalabel import label_components
- import mne_bids
- # import matplotlib as mpl
- from string import ascii_lowercase
- import matplotlib.pyplot as plt
# stimulus alphabet: the 26 lowercase Latin letters followed by ä, ö, ü and ß
chars = list(ascii_lowercase) + ['ä', 'ö', 'ü', 'ß']
# %matplotlib qt
# figure defaults ------------------------------------------------------------
fig_font_size = 8
plt.rcParams['text.usetex'] = False
plt.rcParams['font.family'] = 'Helvetica'
plt.rcParams['font.size'] = fig_font_size
# positional row of the example participant; will index sub-12 after the
# participant list has been subset
example_sub_idx = 10
# function to plot sources and topographies side by side


def _icaplot_slots(top_n, n_eog):
    """Lay out ``top_n`` component panels followed by ``n_eog`` EOG panels.

    Panels fill the left column top to bottom, then the right column.

    Returns
    -------
    nrows : int
        Number of grid rows, ``ceil((top_n + n_eog) / 2)`` (at least 1).
    slots : list of (int, int)
        ``(column, row)`` of each panel: components first, then EOG channels.
    """
    n_panels = top_n + n_eog
    nrows = max(int(np.ceil(n_panels / 2)), 1)
    return nrows, [divmod(k, nrows) for k in range(n_panels)]


def _style_trace_axis(ax, ylabel, ylim, is_bottom):
    """Apply the shared minimalist styling to one time-series panel.

    Hides the y ticks and all spines, writes ``ylabel`` horizontally to the
    left of the trace, fixes symmetric y-limits ``(-ylim, ylim)``, and shows
    the time axis only when ``is_bottom`` is true.
    """
    ax.yaxis.set_ticks_position('none')
    ax.yaxis.set_ticklabels('')
    ax.set_ylabel(ylabel, rotation='horizontal', horizontalalignment='right',
                  verticalalignment='center')
    ax.set_ylim(-ylim, ylim)
    for side in ('left', 'right', 'top', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.margins(0)
    if is_bottom:
        ax.set_xlabel('Time (s)')
    else:
        ax.xaxis.set_ticks_position('none')
        ax.xaxis.set_ticklabels('')


def custom_icaplot(ica, inst, top_n=20, start=0, stop=5.001, title=None):
    """Plot ICA source time courses beside their scalp topographies.

    The first ``top_n`` components are drawn as (topomap, trace) pairs in two
    columns, followed by the EOG channels of ``inst`` for reference. Panels
    fill the left column top to bottom, then the right column. The grid
    height adapts to the number of EOG channels present.

    Parameters
    ----------
    ica : mne.preprocessing.ICA
        Fitted ICA whose sources and topographies are plotted.
    inst : instance of Raw
        Continuous data. It is passed to ``ica.get_sources``, used for the
        explained-variance labels, and supplies the EOG traces. This script
        calls it with an ``mne.io.Raw``.
    top_n : int
        Number of leading components to show (``ICA000``, ``ICA001``, ...).
    start, stop : float
        Time window in seconds. It is passed to ``ica.get_sources`` and, as
        ``tmin``/``tmax``, to ``inst.get_data``.
    title : str | None
        If given, used as the figure suptitle together with the total
        percentage of EEG variance explained by the ICA.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    picks = [f'ICA{i:03d}' for i in range(top_n)]
    sources = ica.get_sources(inst, start=start, stop=stop)
    source_dat = sources.get_data(picks=picks)
    # offset by `start` so the x axis shows time within the recording
    source_time = sources.times + start

    # Find EOG channels from the channel types instead of copying the whole
    # recording. The same names select the data below, so labels and traces
    # always line up.
    eog_names = [name for name, ch_type in zip(inst.ch_names, inst.get_channel_types())
                 if ch_type == 'eog']
    n_eog = len(eog_names)

    nrows, slots = _icaplot_slots(top_n, n_eog)
    last_slot = top_n + n_eog - 1
    fig = plt.figure(figsize=(6.5, 7), layout="constrained")
    # each of the two panel columns is 4 grid units wide: topomap (1) + trace (3)
    spec = fig.add_gridspec(nrows, 2*4)

    # all component traces share one symmetric y-scale
    src_lim = np.abs(source_dat).max()
    for i in range(top_n):
        fig_col, fig_row = slots[i]
        left = 4 * fig_col  # first grid column of this panel column
        prop_explained_i = ica.get_explained_variance_ratio(inst, components=[i])
        ax_L = fig.add_subplot(spec[fig_row, left + 1:left + 4])
        ax_L.plot(source_time, source_dat[i, :], color='k', linewidth=0.75)
        _style_trace_axis(
            ax_L,
            f'{picks[i]}\n{round(prop_explained_i["eeg"]*100, 1)}%',
            src_lim,
            fig_row == nrows - 1 or i == last_slot,
        )
        ax_R = fig.add_subplot(spec[fig_row, left])
        ica.plot_components(picks=i, axes=ax_R, show=False, sensors=False, title=None)
        ax_R.set_title(None)

    if n_eog:
        eog_dat = inst.get_data(tmin=start, tmax=stop, picks=eog_names)
        # EOG traces share their own symmetric y-scale
        eog_lim = np.abs(eog_dat).max()
        for j, eog_lab in enumerate(eog_names):
            k = top_n + j
            fig_col, fig_row = slots[k]
            left = 4 * fig_col
            ax_L = fig.add_subplot(spec[fig_row, left + 1:left + 4])
            ax_L.plot(source_time, eog_dat[j, :], color='k', linewidth=0.75)
            _style_trace_axis(ax_L, eog_lab, eog_lim,
                              fig_row == nrows - 1 or k == last_slot)

    if title is not None:
        total_explained = ica.get_explained_variance_ratio(inst)["eeg"]
        fig.suptitle(f'{title} (Total {round(total_explained*100, 1)}% Variance)')

    return fig
# import the example participant's data and preprocess up to the ICA step
# %%
# preprocessing settings -------------------------------------------------------
max_mistakes = 50  # maximum number of mistakes (false positives + false negatives)
max_chs_interp = 8  # maximum number of channels that may be interpolated
min_total_epochs = 0  # minimum number of trials in the cleaned dataset
min_trials_per_char = 10  # minimum number of acceptable trials per character
# rejection thresholds for trials and channels (peak-to-peak amplitude, in V)
reject_criteria = {'eeg': 150e-6}  # reject above 150 µV
flat_criteria = {'eeg': 5e-7}  # flat below 0.5 µV
# proportion of trials on which a channel must be rejected by the criteria
# above before the whole channel is simply interpolated instead
prop_bad_to_remove_ch = 0.5
# %%
# import participant list
p_list = pd.read_csv(op.join('eeg', 'participants.tsv'), delimiter='\t')
# exclude the participant for whom the recording was restarted
p_list = p_list.loc[p_list.recording_restarted == 0]
# select the example participant by position, without building an
# (index, row) tuple for every participant
p_row = p_list.iloc[example_sub_idx]
subj_id = p_row.participant_id
# the number as a string, including leading zeroes but without the 'sub-'
# prefix, as expected by mne_bids
subj_nr = re.sub(r'^sub-', '', subj_id)
head_circum_cm = p_row.head_circum
# head radius in metres from the circumference in cm (C = 2 * pi * r); used in
# setting up the montage, which is relevant to spherical-spline interpolation
head_radius_m = (head_circum_cm / 2 / np.pi) / 100
# BIDS entities shared by the EEG recording and its events file
bids_entities = dict(
    root='eeg',
    subject=subj_nr,
    task='alphabeticdecision',
    datatype='eeg',
)
subj_eeg_path = mne_bids.BIDSPath(**bids_entities, suffix='eeg', extension='vhdr')
subj_events_tsv_path = mne_bids.BIDSPath(**bids_entities, suffix='events', extension='tsv')
# import the events file tsv (behavioural data is stored in there)
subj_events = pd.read_csv(subj_events_tsv_path, delimiter='\t')
# get just the trial information
subj_beh = subj_events.loc[subj_events.trial_type.isin(['target', 'nontarget'])]
new_chs_to_interp = []
print(f'\nPREPROCESSING {subj_id} up to ICA\n')
# will contain metadata for this participant
subj_metadata = {'subj_id': subj_id}
# import EEG ==========================
raw = mne_bids.read_raw_bids(subj_eeg_path)
raw.load_data()
# check sampling frequency; raise explicitly rather than `assert` so the check
# is not stripped when Python runs with -O
if raw.info['sfreq'] != 1000:
    raise ValueError(
        f"{subj_id}: expected a sampling rate of 1000 Hz, "
        f"got {raw.info['sfreq']} Hz"
    )
# set electrode locations =============
mon = mne.channels.read_custom_montage('AC-64.bvef', head_size=head_radius_m)
raw.set_montage(mon)
# get event triggers ==================
events, event_dict = mne.events_from_annotations(raw)
# handling of bad channels ============
# check number of interpolated channels is not too high
n_bads = len(raw.info['bads'])
if n_bads > max_chs_interp:
    raise ValueError(
        f'{subj_id}: {n_bads} bad channels marked, but at most '
        f'{max_chs_interp} may be interpolated'
    )
# interpolate bad channels (spherical spline method) to avoid biasing the
# average reference
print('Interpolating electrodes: ', raw.info['bads'])
raw = raw.interpolate_bads(method={'eeg': 'spline'})
# re-reference ========================
# use an average EEG reference
raw.set_eeg_reference('average', ch_type='eeg')
# filter the data =====================
raw_ica = raw.copy()  # create copy for ICA
iir_params = {'order': 1, 'ftype': 'butter'}
# bandpass filter with effective 4th order butterworth between .1 and 40 Hz
raw.filter(0.1, 40, method='iir', phase='zero-double', iir_params=iir_params)
# ICA copy with a 1 Hz high-pass to improve ICA performance. The
# preregistration specified 1-40 Hz; 1-100 Hz is used instead to match the
# dataset ICLabel was trained on.
raw_ica.filter(1, 100, method='iir', phase='zero-double', iir_params=iir_params)
# example 5 seconds with two blinks
plot_start = 1486.0
plot_stop = 1491.001

# Non-target events -----------------------------------------------------------
# Behaviour and triggers do not depend on the ICA, so they are built once here
# rather than on every loop iteration. Note: nt stands for non-target.
beh_nt = subj_beh.drop(subj_beh[subj_beh.target == 1].index).reset_index()
# boolean indexing returns a copy, so editing events_nt leaves `events` intact
events_nt = events[events[:, 2] == event_dict['nontarget'], :]
if len(events_nt) != len(beh_nt):
    raise ValueError(
        f'{subj_id}: {len(events_nt)} non-target triggers but '
        f'{len(beh_nt)} non-target trials in the events file'
    )
# manually change the event IDs to have a unique ID for each character
event_ids_char_ids = {c: i + 1 for i, c in enumerate(chars)}
events_nt[:, 2] = [event_ids_char_ids[c_x] for c_x in beh_nt.stimulus]
# adjust for the 8 ms delay between trigger onset and stimuli appearing at the
# centre of the screen (8 samples, as we recorded at 1000 Hz)
events_nt[:, 0] += np.round(8 * (raw.info['sfreq'] / 1000)).astype(int)


def _epoch_nontargets(inst):
    """Epoch ``inst`` from -0.2 to 1.0 s around each non-target event.

    Event IDs are the per-character codes in ``event_ids_char_ids``. Epochs are
    rejected by annotation only (``reject`` and ``flat`` are disabled).
    """
    return mne.Epochs(
        inst,
        events=events_nt,
        event_id=event_ids_char_ids,
        tmin=-0.2,
        tmax=1.0,
        preload=True,
        reject_by_annotation=True,
        reject=None,
        flat=None,
    )


evoked_dict = {}
for n_comp in [32, 64]:
    ica = mne.preprocessing.ICA(
        method='infomax',
        fit_params=dict(extended=True),
        n_components=n_comp,
        max_iter='auto',
        random_state=97)
    ica.fit(raw_ica)
    fig = custom_icaplot(ica, inst=raw, top_n=20, title=f'{n_comp} Components',
                         start=plot_start, stop=plot_stop)
    fig.savefig(op.join('fig', 'ica', f'ica_N_{n_comp}_example.pdf'))
    # predict a label for every component with ICLabel
    ic_labels = label_components(raw_ica, ica, method='iclabel')
    print(ic_labels)
    # exclude every component whose predicted label is neither 'brain' nor
    # 'other' and whose predicted probability exceeds 85%
    exclude_idx = np.where(
        (~np.isin(np.array(ic_labels['labels']), ['brain', 'other']))
        & (ic_labels['y_pred_proba'] > 0.85)
    )[0]
    print(f'Excluding these ICA components: {exclude_idx}')
    # apply ICA ===========================
    raw_clean = raw.copy()
    ica.apply(raw_clean, exclude=exclude_idx)
    epochs_nt = _epoch_nontargets(raw_clean)
    evoked_dict[str(n_comp)] = epochs_nt.average()
# reference ERP without any ICA cleaning
epochs_nt = _epoch_nontargets(raw)
evoked_dict['pre_ica'] = epochs_nt.average()
# plot ERPs
plot_labels = {'pre_ica': 'No ICA',
               '32': '32-Component\nICA + ICLabel',
               '64': '64-Component\nICA + ICLabel'}
# shared symmetric y-limits: the largest absolute amplitude (µV) across all
# conditions, padded by 0.05 and rounded to 0.1 µV
yextent = np.max([np.abs(x.get_data(units='uV')).max() for x in evoked_dict.values()])
ylims = [-np.round(yextent + 0.05, 1), np.round(yextent + 0.05, 1)]
# reduced ylims for the 64-component result so it is visible. Pad by 0.005 and
# round to 0.01 µV; rounding to 0.1 µV after so small a pad could land below
# the data extent and clip the traces.
yextent_64 = np.abs(evoked_dict['64'].get_data(units='uV')).max()
ylims_64 = [-np.round(yextent_64 + 0.005, 2), np.round(yextent_64 + 0.005, 2)]
fig = plt.figure(figsize=(4.5, 4.5), layout="constrained")
spec = fig.add_gridspec(len(plot_labels), 5)
for r, cond in enumerate(plot_labels):
    # row label in the first grid column
    label_ax = fig.add_subplot(spec[r, 0])
    label_ax.annotate(text=plot_labels[cond], xy=(0.5, 0.5), xycoords='axes fraction',
                      horizontalalignment='center', verticalalignment='center')
    label_ax.axis('off')
    # ERP butterfly plot in the remaining columns
    ax = fig.add_subplot(spec[r, 1:])
    evoked_dict[cond].plot(time_unit='ms', axes=ax, show=False, selectable=False)
    ax.set_title(None)
    ax.spines[['left', 'right', 'bottom', 'top']].set_visible(False)
    cond_ylims = ylims_64 if cond == '64' else ylims
    ax.set_ylim(cond_ylims[0], cond_ylims[1])
    # keep the ERP traces above the zero lines added below
    for L, ln in enumerate(ax.lines):
        ln.set_zorder(L + 2)
    # zero lines spanning the whole axes (axhline/axvline extents are axes
    # fractions, so the defaults 0..1 cover the full width/height)
    hl = ax.axhline(y=0, color='k')
    vl = ax.axvline(x=0, color='k')
    hl.set_zorder(0)
    vl.set_zorder(1)
    # NOTE(review): removes the second text artist Evoked.plot() adds; which
    # annotation that is depends on the MNE version -- confirm
    ax.texts[1].remove()
fig.savefig(op.join('fig', 'ica', 'ica_N_example_impact.pdf'))
|