99_plot_ica_N_change_examples.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. # %%
  2. # import modules
  3. import os.path as op
  4. import glob
  5. import pandas as pd
  6. import numpy as np
  7. import re
  8. import mne
  9. from mne_icalabel import label_components
  10. import mne_bids
  11. # import matplotlib as mpl
  12. from string import ascii_lowercase
  13. import matplotlib.pyplot as plt
  14. chars = [*ascii_lowercase, 'ä', 'ö', 'ü', 'ß']
  15. # %matplotlib qt
  16. fig_font_size = 8
  17. plt.rcParams.update({
  18. "text.usetex": False,
  19. "font.family": "Helvetica",
  20. 'font.size': fig_font_size
  21. })
  22. example_sub_idx = 10 # will index sub-12 after subsetting data
  23. # function to plot sources and topographies side by side
  24. def custom_icaplot(ica, inst, top_n=20, start=0, stop=5.001, title=None):
  25. picks = [f'ICA{i:03d}' for i in range(top_n)]
  26. sources = ica.get_sources(inst, start=start, stop=stop)
  27. source_dat = sources.get_data(picks=picks)
  28. source_time = sources.times + start
  29. eog_dat = inst.get_data(tmin=start, tmax=stop, picks='eog')
  30. eog_names = inst.copy().pick('eog').info.ch_names
  31. nrows = int(np.ceil((top_n+2) / 2))
  32. fig = plt.figure(figsize=(6.5, 7), layout="constrained")
  33. spec = fig.add_gridspec(nrows, 2*4)
  34. for i in range(top_n):
  35. prop_explained_i = ica.get_explained_variance_ratio(inst, components=[i])
  36. if i > (nrows-1):
  37. fig_col = 1
  38. fig_row = i - nrows
  39. else:
  40. fig_col = 0
  41. fig_row = i
  42. if fig_col==0:
  43. ax_L = fig.add_subplot(spec[fig_row, 1:4])
  44. elif fig_col==1:
  45. ax_L = fig.add_subplot(spec[fig_row, 5:8])
  46. ax_L.plot(source_time, source_dat[i, :], color='k', linewidth=0.75)
  47. ax_L.yaxis.set_ticks_position('none')
  48. ax_L.yaxis.set_ticklabels('')
  49. ax_L.set_ylabel(f'{picks[i]}\n{round(prop_explained_i["eeg"]*100, 1)}%', rotation='horizontal', horizontalalignment='right', verticalalignment='center')
  50. ax_L.set_ylim(-np.abs(source_dat).max(), np.abs(source_dat).max())
  51. ax_L.spines['left'].set_visible(False)
  52. ax_L.spines['right'].set_visible(False)
  53. ax_L.spines['top'].set_visible(False)
  54. ax_L.spines['bottom'].set_visible(False)
  55. ax_L.margins(0)
  56. if i == nrows-1:
  57. ax_L.set_xlabel('Time (s)')
  58. else:
  59. ax_L.xaxis.set_ticks_position('none')
  60. ax_L.xaxis.set_ticklabels('')
  61. if fig_col==0:
  62. ax_R = fig.add_subplot(spec[fig_row, 0])
  63. elif fig_col==1:
  64. ax_R = fig.add_subplot(spec[fig_row, 4])
  65. ica.plot_components(picks=i, axes=ax_R, show=False, sensors=False, title=None)
  66. ax_R.set_title(None)
  67. for j, eog_lab in enumerate(eog_names):
  68. ax_L = fig.add_subplot(spec[fig_row+1+j, 5:8])
  69. ax_L.plot(source_time, eog_dat[j, :], color='k', linewidth=0.75)
  70. ax_L.yaxis.set_ticks_position('none')
  71. ax_L.yaxis.set_ticklabels('')
  72. if j < len(eog_names)-1:
  73. ax_L.xaxis.set_ticks_position('none')
  74. ax_L.xaxis.set_ticklabels('')
  75. else:
  76. ax_L.set_xlabel('Time (s)')
  77. ax_L.set_ylabel(eog_lab, rotation='horizontal', horizontalalignment='right', verticalalignment='center')
  78. ax_L.set_ylim(-np.abs(eog_dat).max(), np.abs(eog_dat).max())
  79. ax_L.spines['left'].set_visible(False)
  80. ax_L.spines['right'].set_visible(False)
  81. ax_L.spines['top'].set_visible(False)
  82. ax_L.spines['bottom'].set_visible(False)
  83. ax_L.margins(0)
  84. if title is not None:
  85. fig.suptitle(f'{title} (Total {round(ica.get_explained_variance_ratio(inst)["eeg"]*100, 1)}% Variance)')
  86. return fig
  87. # import the example participant's data and preprocess up to the ICA step
  88. # %%
  89. # settings
  90. # maximum number of mistakes (false positives + false negatives)
  91. max_mistakes = 50
  92. # max number of channels to interpolate
  93. max_chs_interp = 8
  94. # min number of trials in the cleaned dataset
  95. min_total_epochs = 0
  96. # minimum number of acceptable trials per character
  97. min_trials_per_char = 10
  98. # rejection criteria for trials and channels
  99. reject_criteria = dict(eeg=150e-6) # max peak-to-peak amplitude of 150 µV
  100. flat_criteria = dict(eeg=5e-7) # min peak-to-peak amplitude of 0.5 µV
  101. prop_bad_to_remove_ch = 0.5 # proportion of trials that a given channel needs to be rejected by the above criteria for the whole channel to just be interpolated
  102. # %%
  103. # import participant list
  104. p_list = pd.read_csv(op.join('eeg', 'participants.tsv'), delimiter='\t')
  105. # exclude the participant for whom the recording was restarted
  106. p_list = p_list.loc[p_list.recording_restarted == 0]
  107. i, p_row = list(p_list.iterrows())[example_sub_idx]
  108. subj_id = p_row.participant_id
  109. subj_nr = re.sub('^sub-', '', subj_id) # the number as a string, including leading zeroes but without thr 'sub-' prefix, as expected by mne_bids
  110. head_circum_cm = p_row.head_circum
  111. head_radius_m = (head_circum_cm / 2 / np.pi) / 100 # used in setting up montage - relevant to speherical interpolation
  112. subj_eeg_path = mne_bids.BIDSPath(
  113. root = 'eeg',
  114. subject = subj_nr,
  115. task = 'alphabeticdecision',
  116. suffix = 'eeg',
  117. datatype = 'eeg',
  118. extension = 'vhdr'
  119. )
  120. subj_events_tsv_path = mne_bids.BIDSPath(
  121. root = 'eeg',
  122. subject = subj_nr,
  123. task = 'alphabeticdecision',
  124. suffix = 'events',
  125. datatype = 'eeg',
  126. extension = 'tsv'
  127. )
  128. # import the events file tsv (behavioural data is stored in there)
  129. subj_events = pd.read_csv(subj_events_tsv_path, delimiter='\t')
  130. # get just the trial information
  131. subj_beh = subj_events.loc[(subj_events.trial_type =='target') | (subj_events.trial_type =='nontarget')]
  132. new_chs_to_interp = []
  133. print(f'\nPREPROCESSING {subj_id} up to ICA\n')
  134. # will contain metadata for this participant
  135. subj_metadata = {'subj_id': subj_id}
  136. # import EEG
  137. raw = mne_bids.read_raw_bids(subj_eeg_path)
  138. raw.load_data()
  139. # check sampling frequency
  140. assert raw.info['sfreq'] == 1000
  141. # set electrode locations =============
  142. mon = mne.channels.read_custom_montage('AC-64.bvef', head_size=head_radius_m)
  143. raw.set_montage(mon)
  144. # get event triggers ==================
  145. events, event_dict = mne.events_from_annotations(raw)
  146. # mne.viz.plot_events(events, sfreq=raw.info['sfreq'], first_samp=raw.first_samp, event_id=event_dict)
  147. # handling of bad channels ============
  148. # check number of interpolated channels is not too high
  149. assert len(raw.info['bads']) <= max_chs_interp
  150. # interpolate bad channels (spherical spline method) to avoid biasing average reference
  151. print('Interpolating electrodes: ', raw.info['bads'])
  152. raw = raw.interpolate_bads(method={'eeg': 'spline'})
  153. # raw.plot()
  154. # filter ==============================
  155. # use an average EEG reference
  156. raw.set_eeg_reference('average', ch_type='eeg')
  157. # filter the data =====================
  158. raw_ica = raw.copy() # create copy for ICA
  159. iir_params = {'order': 1, 'ftype': 'butter'}
  160. # bandpass filter with effective 4th order butterworth between .1 and 40 Hz
  161. raw.filter(0.1, 40, method='iir', phase='zero-double', iir_params=iir_params)
  162. # version of the data with a highpass of 1 Hz to improve ICA performance
  163. # raw_ica.filter(1, 40, method='iir', phase='zero-double', iir_params=iir_params)
  164. raw_ica.filter(1, 100, method='iir', phase='zero-double', iir_params=iir_params) # changed from preregistration to match iclabel dataset
  165. # example 5 seconds with two blinks
  166. plot_start = 1486.0
  167. plot_stop = 1491.001
  168. evoked_dict = {}
  169. for n_comp in [32, 64]:
  170. ica = mne.preprocessing.ICA(
  171. method='infomax',
  172. fit_params=dict(extended=True),
  173. n_components=n_comp,
  174. max_iter='auto',
  175. random_state=97)
  176. ica.fit(raw_ica)
  177. fig = custom_icaplot(ica, inst=raw, top_n=20, title=f'{n_comp} Components', start=plot_start, stop=plot_stop)
  178. fig.savefig(op.join('fig', 'ica', f'ica_N_{n_comp}_example.pdf'))
  179. # predict ICA labels from iclabel
  180. # extract the labels of each component
  181. ic_labels = label_components(raw_ica, ica, method='iclabel')
  182. print(ic_labels)
  183. # exclude components classified as eye blinks or muscle artefacts with at least 85% predicted probability
  184. exclude_idx = np.where(
  185. (~np.isin(np.array(ic_labels['labels']), ['brain', 'other'])) &
  186. (ic_labels['y_pred_proba']>0.85)
  187. )[0]
  188. print(f'Excluding these ICA components: {exclude_idx}')
  189. # apply ICA ===========================
  190. raw_clean = raw.copy()
  191. ica.apply(raw_clean, exclude=exclude_idx)
  192. # epoch and plot ERP
  193. # Note: nt stands for non-target
  194. beh_nt = subj_beh.drop(subj_beh[subj_beh.target==1].index).reset_index()
  195. events_nt = events[events[:, 2]==event_dict['nontarget'], :]
  196. # manually change the event IDs to have a unique ID for each character
  197. event_ids_char_ids = {c: i+1 for i, c in enumerate(chars)}
  198. events_nt[:, 2] = [event_ids_char_ids[c_x] for c_x in beh_nt.stimulus]
  199. # adjust for the 8 ms delay between trigger onset and stimuli appearing at the centre of the screen
  200. events_nt[:, 0] += np.round( 8 * (raw.info['sfreq'] / 1000) ).astype(int) # will be 8 samples, as we recorded at 1000 Hz
  201. epochs_nt = mne.Epochs(
  202. raw_clean,
  203. events=events_nt,
  204. # event_id={'nontarget': 1},
  205. event_id = event_ids_char_ids,
  206. tmin=-0.2,
  207. tmax=1.0,
  208. preload=True,
  209. reject_by_annotation=True,
  210. reject=None,
  211. flat=None
  212. )
  213. evoked_dict[str(n_comp)] = epochs_nt.average().copy()
  214. epochs_nt = mne.Epochs(
  215. raw,
  216. events=events_nt,
  217. # event_id={'nontarget': 1},
  218. event_id = event_ids_char_ids,
  219. tmin=-0.2,
  220. tmax=1.0,
  221. preload=True,
  222. reject_by_annotation=True,
  223. reject=None,
  224. flat=None
  225. )
  226. evoked_dict['pre_ica'] = epochs_nt.average().copy()
  227. # plot ERPs
  228. plot_labels = {'pre_ica': 'No ICA',
  229. '32': '32-Component\nICA + ICLabel',
  230. '64': '64-Component\nICA + ICLabel'}
  231. yextent = np.max([np.abs(x.get_data(units='uV')).max() for x in evoked_dict.values()])
  232. ylims = [-np.round(yextent + 0.05, 1), np.round(yextent + 0.05, 1)]
  233. # reduced ylims for 64-component results so it is visible
  234. yextent_64 = np.abs(evoked_dict['64'].get_data(units='uV')).max()
  235. ylims_64 = [-np.round(yextent_64 + 0.005, 1), np.round(yextent_64 + 0.005, 1)]
  236. fig = plt.figure(figsize=(4.5, 4.5), layout="constrained")
  237. spec = fig.add_gridspec(len(plot_labels), 5)
  238. for r, id in enumerate(plot_labels.keys()):
  239. label_ax = fig.add_subplot(spec[r, 0])
  240. label_ax.annotate(text=plot_labels[id], xy=(0.5,0.5), xycoords='axes fraction', horizontalalignment='center', verticalalignment='center')
  241. label_ax.axis('off')
  242. ax = fig.add_subplot(spec[r, 1:])
  243. pl = evoked_dict[id].plot(time_unit='ms', axes=ax, show=False, selectable=False)
  244. ax.set_title(None)
  245. ax.spines[['left', 'right', 'bottom', 'top']].set_visible(False)
  246. if id=='64':
  247. ax.set_ylim(ylims_64[0], ylims_64[1])
  248. else:
  249. ax.set_ylim(ylims[0], ylims[1])
  250. for L, ln in enumerate(ax.lines):
  251. ln.set_zorder(L+2)
  252. hl = ax.axhline(y=0, xmin=-200, xmax=1000, color='k')
  253. vl = ax.axvline(x=0, ymin=-100, ymax=100, color='k')
  254. hl.set_zorder = 0
  255. vl.set_zorder = 1
  256. ax.texts[1].remove()
  257. fig.savefig(op.join('fig', 'ica', 'ica_N_example_impact.pdf'))