02_preprocessing.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # -*- coding: utf-8 -*-
  2. # %%
  3. # import modules
  4. import os.path as op
  5. import glob
  6. import pandas as pd
  7. import numpy as np
  8. import re
  9. import mne
  10. from mne_icalabel import label_components
  11. import mne_bids
  12. # import matplotlib as mpl
  13. from string import ascii_lowercase
  14. chars = [*ascii_lowercase, 'ä', 'ö', 'ü', 'ß']
  15. # %matplotlib qt
  16. # %%
  17. # settings
  18. # maximum number of mistakes (false positives + false negatives)
  19. max_mistakes = 50
  20. # max number of channels to interpolate
  21. max_chs_interp = 8
  22. # min number of trials in the cleaned dataset
  23. min_total_epochs = 0
  24. # minimum number of acceptable trials per character
  25. min_trials_per_char = 10
  26. # rejection criteria for trials and channels
  27. reject_criteria = dict(eeg=150e-6) # max peak-to-peak amplitude of 150 µV
  28. flat_criteria = dict(eeg=5e-7) # min peak-to-peak amplitude of 0.5 µV
  29. prop_bad_to_remove_ch = 0.5 # proportion of trials that a given channel needs to be rejected by the above criteria for the whole channel to just be interpolated
  30. # %%
  31. # import participant list
  32. p_list = pd.read_csv(op.join('eeg', 'participants.tsv'), delimiter='\t')
  33. # exclude the participant for whom the recording was restarted
  34. p_list = p_list.loc[p_list.recording_restarted == 0]
  35. # store all metadata from preprocessing
  36. all_subj_metadata = []
  37. # %%
  38. # preprocess
  39. for i, p_row in p_list.iterrows():
  40. subj_id = p_row.participant_id
  41. subj_nr = re.sub('^sub-', '', subj_id) # the number as a string, including leading zeroes but without thr 'sub-' prefix, as expected by mne_bids
  42. head_circum_cm = p_row.head_circum
  43. head_radius_m = (head_circum_cm / 2 / np.pi) / 100 # used in setting up montage - relevant to speherical interpolation
  44. subj_eeg_path = mne_bids.BIDSPath(
  45. root = 'eeg',
  46. subject = subj_nr,
  47. task = 'alphabeticdecision',
  48. suffix = 'eeg',
  49. datatype = 'eeg',
  50. extension = 'vhdr'
  51. )
  52. subj_events_tsv_path = mne_bids.BIDSPath(
  53. root = 'eeg',
  54. subject = subj_nr,
  55. task = 'alphabeticdecision',
  56. suffix = 'events',
  57. datatype = 'eeg',
  58. extension = 'tsv'
  59. )
  60. # import the events file tsv (behavioural data is stored in there)
  61. subj_events = pd.read_csv(subj_events_tsv_path, delimiter='\t')
  62. # get just the trial information
  63. subj_beh = subj_events.loc[(subj_events.trial_type =='target') | (subj_events.trial_type =='nontarget')]
  64. n_attempts = 0
  65. do_restart = True
  66. new_chs_to_interp = []
  67. while do_restart:
  68. n_attempts += 1
  69. assert n_attempts <= 2 # should exclude if fails to work on second attempt
  70. print(f'\nPREPROCESSING {subj_id}, attempt {n_attempts}\n')
  71. # will contain metadata for this participant
  72. subj_metadata = {'subj_id': subj_id}
  73. # check the participant didn't make too many mistakes
  74. # (note, accuracy is stored as string representation of float)
  75. subj_metadata['n_false_neg_resp'] = np.sum((subj_beh['accuracy'] == 0) & (subj_beh['target'] == 1)) # includes RT timeout
  76. subj_metadata['n_false_pos_resp'] = np.sum((subj_beh['accuracy'] == 0) & (subj_beh['target'] == 0))
  77. # import EEG
  78. raw = mne_bids.read_raw_bids(subj_eeg_path)
  79. raw.load_data()
  80. # check sampling frequency
  81. assert raw.info['sfreq'] == 1000
  82. # set electrode locations =============
  83. mon = mne.channels.read_custom_montage('AC-64.bvef', head_size=head_radius_m)
  84. raw.set_montage(mon)
  85. # get event triggers ==================
  86. events, event_dict = mne.events_from_annotations(raw)
  87. # mne.viz.plot_events(events, sfreq=raw.info['sfreq'], first_samp=raw.first_samp, event_id=event_dict)
  88. # handling of bad channels ============
  89. # if any extra bad channels, add these
  90. raw.info['bads'].extend(new_chs_to_interp)
  91. # check number of interpolated channels is not too high
  92. assert len(raw.info['bads']) <= max_chs_interp
  93. # interpolate bad channels (spherical spline method) to avoid biasing average reference
  94. print('Interpolating electrodes: ', raw.info['bads'])
  95. n_chs_interpolated = len(raw.info['bads'])
  96. raw = raw.interpolate_bads(method={'eeg': 'spline'})
  97. # raw.plot()
  98. # filter ==============================
  99. # use an average EEG reference
  100. raw.set_eeg_reference('average', ch_type='eeg')
  101. # filter the data =====================
  102. raw_ica = raw.copy() # create copy for ICA
  103. iir_params = {'order': 1, 'ftype': 'butter'}
  104. # bandpass filter with effective 4th order butterworth between .1 and 40 Hz
  105. raw.filter(0.1, 40, method='iir', phase='zero-double', iir_params=iir_params)
  106. # version of the data with a highpass of 1 Hz to improve ICA performance
  107. # raw_ica.filter(1, 40, method='iir', phase='zero-double', iir_params=iir_params)
  108. raw_ica.filter(1, 100, method='iir', phase='zero-double', iir_params=iir_params) # changed from preregistration to match iclabel dataset
  109. # fit ICA =============================
  110. ica = mne.preprocessing.ICA(
  111. method='infomax',
  112. fit_params=dict(extended=True),
  113. # n_components=64,
  114. n_components=32, # changed from preregistration, as 64 components led to problems
  115. max_iter='auto',
  116. random_state=97)
  117. ica.fit(raw_ica, picks='eeg')
  118. # ica.plot_components();
  119. # ica.plot_sources(raw);
  120. # ICLabel =============================
  121. # predict ICA labels from iclabel
  122. # extract the labels of each component
  123. ic_labels = label_components(raw_ica, ica, method='iclabel')
  124. print(ic_labels)
  125. # exclude components classified as something other than "brain" or "other" with at least 85% predicted probability
  126. exclude_idx = np.where(
  127. (~np.isin(np.array(ic_labels['labels']), ['brain', 'other'])) &
  128. (ic_labels['y_pred_proba']>0.85)
  129. )[0]
  130. print(f'Excluding these ICA components: {exclude_idx}')
  131. # apply ICA ===========================
  132. ica.apply(raw, exclude=exclude_idx)
  133. subj_metadata['n_ica_excluded'] = len(exclude_idx)
  134. # epoch the data ======================
  135. # Note: nt stands for non-target
  136. beh_nt = subj_beh.drop(subj_beh[subj_beh.target==1].index).reset_index()
  137. events_nt = events[events[:, 2]==event_dict['nontarget'], :]
  138. assert( len(beh_nt) == events_nt.shape[0] )
  139. # manually change the event IDs to have a unique ID for each character
  140. event_ids_char_ids = {c: i+1 for i, c in enumerate(chars)}
  141. events_nt[:, 2] = [event_ids_char_ids[c_x] for c_x in beh_nt.stimulus]
  142. # adjust for the 8 ms delay between trigger onset and stimuli appearing at the centre of the screen
  143. events_nt[:, 0] += np.round( 8 * (raw.info['sfreq'] / 1000) ).astype(int) # will be 8 samples, as we recorded at 1000 Hz
  144. epochs_nt = mne.Epochs(
  145. raw,
  146. events=events_nt,
  147. # event_id={'nontarget': 1},
  148. event_id = event_ids_char_ids,
  149. tmin=-0.2,
  150. tmax=1.0,
  151. preload=True,
  152. reject_by_annotation=True,
  153. reject=None,
  154. flat=None
  155. )
  156. beh_nt['drop_log'] = [';'.join(x) for x in epochs_nt.drop_log]
  157. beh_nt_clean = beh_nt[beh_nt['drop_log'] == '']
  158. # for each trial, also store the IDs as metadata in the epochs
  159. epochs_nt.metadata = beh_nt_clean[['stimulus']]
  160. # now reject bad epochs
  161. epochs_nt.drop_bad(reject=reject_criteria, flat=flat_criteria)
  162. # look for consistently noisy channels
  163. # ...first, extract the drop log to count the number of trials for which each channel is marked as bad...
  164. bad_counts = {k: 0 for k in raw.info['ch_names']}
  165. for bad_i in epochs_nt.drop_log:
  166. for str_j in bad_i:
  167. if str_j in raw.info['ch_names']:
  168. bad_counts[str_j] += 1
  169. # ...then convert this to a proportion of available non-target trials
  170. bad_props = {k: v/len(epochs_nt) for k, v in bad_counts.items()}
  171. # ...make note of these bad channels to be interpolated in the second attempt
  172. new_chs_to_interp = [k for k, v in bad_props.items() if v > prop_bad_to_remove_ch]
  173. # ...and finally, restart if this list is not empty, with those new channels interpolated at the start of the pipeline
  174. if len(new_chs_to_interp)==0:
  175. do_restart = False
  176. print(epochs_nt)
  177. # check the number of trials retained per item
  178. stim_counts = epochs_nt.metadata.groupby(['stimulus']).size().reset_index(name='n')
  179. assert np.min(stim_counts.n) >= min_trials_per_char
  180. subj_metadata['min_trials_per_char'] = np.min(stim_counts.n)
  181. subj_metadata['mean_trials_per_char'] = np.mean(stim_counts.n)
  182. subj_metadata['max_trials_per_char'] = np.max(stim_counts.n)
  183. subj_metadata['n_epochs'] = len(epochs_nt)
  184. subj_metadata['n_epochs_excluded'] = 750 - len(epochs_nt)
  185. subj_metadata['n_extra_chans_excluded'] = len(new_chs_to_interp)
  186. subj_metadata['n_total_chans_interpolated'] = n_chs_interpolated
  187. subj_metadata['n_preprocessing_attempts'] = n_attempts
  188. # save epochs object
  189. epo_file = f'{subj_id}-epo.fif'
  190. epochs_nt.save(op.join('epo', epo_file), overwrite=True)
  191. all_subj_metadata.append(subj_metadata)
  192. metadata_df = pd.DataFrame(all_subj_metadata)
  193. metadata_df.to_csv('preprocessing_metadata.csv', index=False)