Design-Lab 4 years ago
Parent
Commit
533d926461
3 changed files with 1527 additions and 0 deletions
  1. code/microstate.py (+1044 -0)
  2. code/prep_pipeline.py (+306 -0)
  3. code/task_related_power.py (+177 -0)

The diff for this file is not shown because of its large size
+ 1044 - 0
code/microstate.py


+ 306 - 0
code/prep_pipeline.py

@@ -0,0 +1,306 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# EEG signals respond differently to idea generation, idea evolution, and evaluation in a loosely controlled creativity experiment
+# Time    : 2019-10-10
+# Author  : Wenjun Jia
+# File    : prep_pipeline.py
+
+
+
+import numpy as np
+import codecs, json
+import mne
+from scipy import stats
+import matlab
+import matlab.engine
+from pyprep.noisy import Noisydata
+from eeg_tool.model.raw_data import RawData
+from collections import OrderedDict, Counter
+from openpyxl import Workbook, load_workbook
+from eeg_tool.utilis import read_subject_info
+import os
+from multiprocessing import Pool
+import warnings
+from sklearn.decomposition import FastICA
+
+
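+# Preprocessing pipeline for the creativity EEG data: concatenate the task
+# blocks into one continuous recording, band-pass filter, remove line noise,
+# detect globally bad channels with pyprep, remove artifacts with
+# wavelet-enhanced ICA (wICA, via the MATLAB engine), and reject or
+# interpolate bad epochs with FASTER-style criteria.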
+class PreProcessPipeline:
+    engine = matlab.engine.start_matlab()
+    def __init__(self, raw, tasks, trial_info):
+        self.raw = raw
+        self.tasks = tasks
+        self.trial_info = trial_info
+        self.trial_name = []
+        self.trial_duration = []
+        self.tasks_merged = None
+        self.tasks_cleaned = None
+        self.epochs_cleaned = None
+        self.global_bads = None
+        self.global_good_index = None
+        self.global_good_name = None
+        self.n_ch = None
+        self.n_t = None
+        self.fast_ica_convergence = None
+
+
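+    # Concatenates the per-task arrays into a single continuous RawArray,
+    # annotates each trial's onset and duration, and records the start/end
+    # sample index of every trial in self.trial_duration.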
+    def concatenate_tasks(self):
+        flag = True
+        duration = []
+        onset = [0.]
+        for key, value in self.tasks.items():
+            n_time = value.shape[1] / self.raw.info['sfreq']
+            self.trial_name.append(key)
+            # Start/end sample indices of this trial in the concatenated array.
+            self.trial_duration.append(0 if len(self.trial_duration) == 0 else self.trial_duration[-1])
+            self.trial_duration.append(self.trial_duration[-1] + value.shape[1])
+            if flag:
+                self.tasks_merged = value
+                flag = False
+            else:
+                onset.append(np.sum(duration))
+                self.tasks_merged = np.concatenate((self.tasks_merged, value), axis=1)
+            duration.append(n_time)
+
+        self.tasks_merged = mne.io.RawArray(self.tasks_merged, self.raw.info)
+        self.tasks_merged = self.tasks_merged.set_annotations(mne.Annotations(onset, duration, self.trial_name), False)
+        self.n_t = self.tasks_merged.n_times
+        self.n_ch = len(self.tasks_merged.ch_names)
+
+    def filter(self):
+        # Band-pass 1-50 Hz at high sampling rates; otherwise high-pass at 1 Hz only.
+        if self.tasks_merged.info['sfreq'] >= 500:
+            self.tasks_cleaned = self.tasks_merged.copy().filter(1., 50.)
+        else:
+            self.tasks_cleaned = self.tasks_merged.copy().filter(1., None)
+
+    def remove_line_noise(self):
+        self.tasks_cleaned = self.tasks_cleaned.notch_filter(np.arange(60, 241, 60), filter_length='auto', phase='zero')
+
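+    # Wraps pyprep's Noisydata detectors; returns the channel names flagged
+    # as bad for the given raw segment.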
+    @staticmethod
+    def bad_channel(data):
+        nd = Noisydata(data)
+        nd.find_all_bads()
+        bads = nd.get_bads(verbose=True)
+        return bads
+
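+    # Detects bad channels per trial in parallel; a channel is declared
+    # globally bad when it is flagged in more than a `threshold` fraction
+    # of the trials.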
+    def remove_bad_channel(self, thread=5, threshold=0.3):
+        pool = Pool(thread)
+        res = []
+        bads_list = []
+        bads = []
+        threshold = threshold * len(self.trial_duration) / 2
+        for i in range(0, len(self.trial_duration), 2):
+            start_index = self.trial_duration[i]
+            end_index = self.trial_duration[i + 1]
+            data_obj = mne.io.RawArray(self.tasks_cleaned._data[:, start_index:end_index], self.raw.info)
+            res.append(pool.apply_async(PreProcessPipeline.bad_channel, (data_obj,)))
+        pool.close()
+        pool.join()
+        for temp in res:
+            bads_list.extend(temp.get())
+        bads_set = Counter(bads_list)
+        for key, value in bads_set.items():
+            if value > threshold:
+                bads.append(key)
+        print(bads_set)
+        print(bads)
+        self.tasks_cleaned.info['bads'] = bads
+        self.global_bads = bads
+        self.global_good_index = np.asarray([i for i in range(self.n_ch) if self.tasks_cleaned.ch_names[i] not in bads])
+        self.global_good_name = [self.tasks_cleaned.ch_names[i] for i in range(self.n_ch) if self.tasks_cleaned.ch_names[i] not in bads]
+        del pool
+
+
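+    # Wavelet-enhanced ICA (wICA): unmix the good channels with FastICA,
+    # wavelet-threshold each source to isolate its artifactual part, project
+    # the artifacts back through the mixing matrix, and subtract them from
+    # the channel data. The tolerance is relaxed (x5) whenever FastICA fails
+    # to converge within fast_ica_iter attempts.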
+    def remove_artifact_wica(self, wave_name='coif5', level=5, multiplier=1, fast_ica_iter=3, tol=0.025):
+        with warnings.catch_warnings():
+            warnings.filterwarnings("error")
+            loop_break = False
+            while 1:
+                for i in range(fast_ica_iter):
+                    try:
+                        ica = FastICA(max_iter=200, whiten=True, tol=tol)
+                        sources = ica.fit_transform(self.tasks_cleaned.get_data(picks=self.global_good_name).T).T
+                        loop_break = True
+                        break
+                    except Warning:
+                        print('FastICA did not converge at round {} with tol = {}'.format(i, tol))
+                if loop_break:
+                    break
+                else:
+                    tol = tol * 5
+            self.fast_ica_convergence = 'FastICA converged at round {} with tol = {}.'.format(i, tol)
+
+        n_ch = sources.shape[0]
+        n_t = sources.shape[1]
+        artifacts = np.zeros((1, n_t))
+
+        pool = Pool(11)
+        multi_res = [pool.apply_async(PreProcessPipeline.get_artifact_wica, ([sources[i, :], n_t, wave_name, level, multiplier],)) for i in range(n_ch)]
+        pool.close()
+        pool.join()
+        for res in multi_res:
+            temp = np.asarray(res.get()).reshape(1, -1)
+            artifacts = np.concatenate((artifacts, temp), axis=0)
+
+        self.tasks_cleaned._data[self.global_good_index] = self.tasks_cleaned.get_data(picks=self.global_good_name) - (np.dot(artifacts[1::, :].T, ica.mixing_.T) + ica.mean_).T
+        del pool
+
+
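+    # Extracts the artifactual part of a single ICA source via the MATLAB
+    # Wavelet Toolbox: pad the signal to a multiple of 2**level (required by
+    # the stationary wavelet transform), estimate a default denoising
+    # threshold (ddencmp), decompose with swt, threshold the coefficients
+    # (wthresh), and reconstruct the artifact with iswt.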
+    @staticmethod
+    def get_artifact_wica(para):
+        sources = para[0]
+        n_t = para[1]
+        wave_name = para[2]
+        level = para[3]
+        multiplier = para[4]
+        modulus = 2 ** level - n_t % 2 ** level
+        sig = np.concatenate((sources, np.zeros(modulus))) if modulus != 0 else sources
+        sig = matlab.double(sig.tolist())
+        thresh, sorh, _ = PreProcessPipeline.engine.ddencmp('den', 'wv', sig, nargout=3)
+        thresh = thresh * multiplier
+        swc = PreProcessPipeline.engine.swt(sig, level, wave_name)
+        y = PreProcessPipeline.engine.wthresh(swc, sorh, thresh)
+        w_ic = PreProcessPipeline.engine.iswt(y, wave_name)
+
+        return w_ic[0][0:n_t]
+
+
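+    # Splits each trial into non-overlapping epochs of n_times seconds,
+    # flags bad channels per epoch with the FASTER criteria below, and
+    # interpolates them when fewer than a `drop_epoch` fraction of channels
+    # are bad; otherwise the whole epoch is dropped.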
+    def remove_bad_epochs(self, drop_epoch=0.25, n_times=2):
+        epochs_info = OrderedDict()
+        data = self.tasks_cleaned.get_data()
+        for i in range(0, len(self.trial_duration), 2):
+            task_data = np.zeros((self.n_ch, 1))
+            trial_name = self.trial_name[int(i/2)]
+            epochs_info[trial_name] = OrderedDict()
+
+            start = int(self.trial_duration[i])
+            end = int(self.trial_duration[i+1] - (self.trial_duration[i+1]-self.trial_duration[i]) % (n_times * self.tasks_cleaned.info['sfreq']))
+
+            n_epochs = int((end - start) / (n_times * self.tasks_cleaned.info['sfreq']))
+
+            data_cleaned_epochs = np.asarray(np.hsplit(data[self.global_good_index, start:end], n_epochs))
+            data_epochs = np.asarray(np.hsplit(data[:, start:end], n_epochs))
+
+            bad_channel_epochs = self.bad_epochs_faster(data_cleaned_epochs)
+
+            for j in range(n_epochs):
+                ch = self.get_ch_index(self.tasks_merged.ch_names, self.global_bads, bad_channel_epochs[j])
+                ch_bad_index = np.argwhere(ch == 1)
+                ch_bad_index = ch_bad_index.reshape(ch_bad_index.shape[0]).tolist()
+                ch_bad_name = [self.tasks_merged.ch_names[i] for i in range(self.n_ch) if i in ch_bad_index]
+                ratio = np.sum(ch) / len(ch)
+                temp = mne.io.RawArray(data_epochs[j], info=self.tasks_merged.info.copy())
+                temp.info['bads'] = ch_bad_name
+                if ratio < drop_epoch:
+                    temp = temp.interpolate_bads()
+                    drop = 0
+                    task_data = np.concatenate((task_data, temp.get_data()), axis=1)
+                else:
+                    drop = 1
+                epochs_info[trial_name][str(j)] = {'epoch_data': temp.get_data().tolist(), 'bad_channel': ch_bad_name, 'interpolate_ratio': ratio, 'drop': drop}
+            epochs_info[trial_name]['task_data'] = task_data[:, 1::].tolist()
+        self.epochs_cleaned = epochs_info
+
+    def save_epochs_data(self, data_name=None, info_name=None, sheet_name=None):
+        json.dump(self.epochs_cleaned, codecs.open(data_name, 'w', encoding='utf-8'), separators=(',', ':'), sort_keys=True, indent=4)
+        wb = load_workbook(info_name)
+        sheet = wb.create_sheet(sheet_name)
+        row = 1
+        column = 1
+        for task_name, task_data in self.epochs_cleaned.items():
+            sheet.cell(row=row, column=column).value = task_name
+            column += 1
+            sheet.cell(row=row, column=column).value = "_".join(str(x) for x in self.global_bads)
+            for data_type, data in task_data.items():
+                if data_type != 'task_data':
+                    bads = data['bad_channel']
+                    bads.insert(0, data['drop'])
+                    bads_str = "_".join(str(x) for x in bads)
+                    column += 1
+                    sheet.cell(row=row, column=column).value = bads_str
+
+            row += 1
+            column = 1
+
+        sheet.cell(row=row+1, column=1).value = self.fast_ica_convergence
+        wb.save(info_name)
+        wb.close()
+
+    @staticmethod
+    def read_epochs_data(fname=None):
+        # Loads the cleaned-epoch JSON written by save_epochs_data.
+        epochs_text = codecs.open(fname, 'r', encoding='utf-8').read()
+        return json.loads(epochs_text)
+
+    @staticmethod
+    def get_ch_index(ch_names, global_bads, local_bads):
+        ch = np.ones(len(ch_names))
+        good_ch_index = np.asarray([i for i in range(len(ch_names)) if ch_names[i] not in global_bads])
+        good_ch_index = good_ch_index[np.argwhere(local_bads == 0)]
+        ch[good_ch_index] = 0
+        return ch
+
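+    # FASTER-style epoch screening: per epoch and channel, compute the
+    # variance, median gradient, amplitude range, and deviation from the
+    # channel mean, z-score each criterion across epochs, and flag entries
+    # whose z-score exceeds 3.29053 (two-tailed p < 0.001 under a normal
+    # distribution).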
+    @staticmethod
+    def bad_epochs_faster(data, threshold=3.29053):
+        shape = data.shape
+        n_epochs = shape[0]
+        n_ch = shape[1]
+        n_times = shape[2]
+        criteria = []
+        criteria.append(np.var(data, axis=2))
+        criteria.append(np.median(np.gradient(data, axis=2), axis=2))
+        criteria.append(np.ptp(data, axis=2))
+        mean_epochs_channel = np.mean(data, axis=2)
+        mean_epochs = np.mean(mean_epochs_channel, axis=1)
+        criteria.append(mean_epochs_channel - mean_epochs.reshape(n_epochs, 1))
+        res = np.zeros((n_epochs, n_ch))
+        for i in range(len(criteria)):
+            zscore = stats.zscore(criteria[i], axis=0)
+            res += np.where(zscore > threshold, 1, 0)
+        res = np.where(res > 0, 1, 0)
+        return res
+
+
+    def reference(self):
+        pass
+
+if __name__ == '__main__':
+    clean_data_fname = r'D:\EEGdata\clean_data_six_problem'
+    clean_data_info = r'D:\EEGdata\clean_data_six_problem\clean_data_six_problem.xlsx'
+
+    raw_data_fname = r'D:\EEGdata\raw_data_six_problem'
+    subject_fname = r'D:\EEGdata\clean_data_six_problem\notebook.xlsx'
+    subjects = read_subject_info(input_fname=r'D:\EEGdata\clean_data_six_problem\subjects.xlsx')
+
+    # raw_data_fname = r'D:\EEGdata\raw_data_creativity'
+    # subject_fname = r'C:\Users\umroot\Desktop\creativity\notebook_res.xlsx'
+    # subjects = read_subject_info(input_fname='C:\\Users\\umroot\\Desktop\\creativity\\subject.xlsx')
+
+    # subjects = ['april_22']
+    subjects_finished = os.listdir(r'D:\EEGdata\clean_data_six_problem')
+
+    for i in subjects_finished:
+        name = i.split('.')[0]
+        if name in subjects:
+            subjects.remove(name)
+    # subjects = ['april_16(3)']
+
+    for subject in subjects:
+        print(subject)
+        data_fname = os.path.join(raw_data_fname, subject, subject + ".vhdr")
+        clean_data_save = os.path.join(clean_data_fname, subject + ".json")
+
+        raw = RawData()
+        raw.read_raw_data(fname=data_fname, montage='D:\\workspace\\eeg_tool\\Cap63.locs', preload=True, scale=1e3)
+        raw.read_trial_info(fname=subject_fname, sheet_name=subject)
+        raw.split_tasks()
+        prep = PreProcessPipeline(raw=raw.raw_data, tasks=raw.tasks_data, trial_info=raw.trial_info)
+        prep.concatenate_tasks()
+        prep.filter()
+        prep.remove_line_noise()
+        prep.remove_bad_channel()
+        prep.remove_artifact_wica()
+        prep.remove_bad_epochs()
+        prep.save_epochs_data(data_name=clean_data_save, info_name=clean_data_info, sheet_name=subject)
+
+        # PreProcessPipeline.read_epochs_data()

+ 177 - 0
code/task_related_power.py

@@ -0,0 +1,177 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# EEG signals respond differently to idea generation, idea evolution, and evaluation in a loosely controlled creativity experiment
+# Time    : 2019-10-10
+# Author  : Wenjun Jia
+# File    : task_related_power.py
+
+
+import matlab
+import matlab.engine
+import numpy as np
+import codecs, json
+from eeg_tool.utilis import read_subject_info
+from scipy.signal import welch
+from scipy.integrate import simps
+from eeg_tool.algorithm.statistic.spss import electrodes_spss
+from openpyxl import Workbook, load_workbook
+from collections import OrderedDict
+from mne.time_frequency import psd_array_multitaper
+
+
+class TaskRelatedPower:
+    def __init__(self, task_data=None, band_frequency=(8, 10, 10, 12), low_frequency=1, high_frequency=40):
+        # band_frequency is a flattened sequence of (low, high) pairs: 8-10 Hz and 10-12 Hz by default.
+        self.task_data = task_data
+        self.band_frequency = band_frequency
+        self.low_frequency = low_frequency
+        self.high_frequency = high_frequency
+
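+    # Computes per-channel band powers for every retained epoch and for the
+    # pooled task data: Welch PSD, then integration over the full
+    # [low_frequency, high_frequency] range and over each requested band,
+    # stored as absolute, log10, and relative power.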
+    def band_powers(self, fs=500, nperseg=1000, noverlap=500, nfft=1000):
+        res = {}
+        for task_name, task_data in self.task_data.items():
+            res[task_name] = {}
+            for epoch_name, epoch_data in task_data.items():
+                if epoch_name == 'task_data':
+                    sig = np.asarray(epoch_data)
+                    if sig.shape[1] == 0:
+                        continue
+                elif epoch_data['drop'] == 0:
+                    sig = np.asarray(epoch_data['epoch_data'])
+                else:
+                    continue
+                # power by Welch
+                freq, pxx = welch(sig, fs=fs, nperseg=nperseg, noverlap=noverlap, nfft=nfft)
+                total_power = TaskRelatedPower.band_power(freq, pxx, self.low_frequency, self.high_frequency)
+
+                # power by fft
+                # fft_values = np.power(np.absolute(np.fft.rfft(sig, axis=1)), 2)
+                # fft_freq = np.fft.rfftfreq(sig.shape[1], 1./fs)
+
+                # power by multitaper
+                # pxx, freq = psd_array_multitaper(sig, fs, fmin=self.low_frequency, fmax=self.high_frequency, adaptive=True, normalization='full', verbose=0)
+                # total_power = TaskRelatedPower.band_power(freq, pxx, self.low_frequency, self.high_frequency)
+                # total_power = TaskRelatedPower.band_power_fft(fft_freq, fft_values, self.low_frequency, self.high_frequency)
+                res[task_name][epoch_name] = {'total_power': total_power.tolist(), 'total_log_power': (np.log10(total_power)).tolist()}
+                for i in range(0, len(self.band_frequency), 2):
+                    # power = TaskRelatedPower.band_power(freq, pxx, self.band_frequency[i], self.band_frequency[i + 1])
+                    # power = TaskRelatedPower.band_power_fft(fft_freq, fft_values, self.band_frequency[i], self.band_frequency[i+1])
+                    power = TaskRelatedPower.band_power(freq, pxx, self.band_frequency[i], self.band_frequency[i + 1])
+                    freq_str = str(self.band_frequency[i]) + "_" + str(self.band_frequency[i+1])
+                    res[task_name][epoch_name][freq_str] = {'band_power': power.tolist(), 'band_log_power': (np.log10(power)).tolist(), 'band_relative_power': (abs(power/total_power)).tolist()}
+        return res
+
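+    # Task-related power (TRP) per electrode: the log band power of each
+    # active task minus that of its reference (rest) task,
+    # TRP_i = log10(Pow_active_i) - log10(Pow_reference_i); repetitions of
+    # the same task across problems are averaged at the end.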
+    def task_related_power(self, band_powers, active_task, reference_task, task_data_name='task_data', power_name='band_log_power'):
+        electrodes_position = electrodes_spss()
+        n_ch = len(electrodes_position)
+        res = {}
+        spss = {}
+        for k in range(len(active_task)):
+            name = active_task[k].split('_')[1]
+            if name not in res:
+                res[name] = {}
+            for i in range(0, len(self.band_frequency), 2):
+                band = str(self.band_frequency[i]) + "_" + str(self.band_frequency[i + 1])
+                trp = np.asarray(band_powers[active_task[k]][task_data_name][band][power_name]) - np.asarray(band_powers[reference_task[k]][task_data_name][band][power_name])
+                trp_spss = trp[electrodes_position].reshape(n_ch, -1)
+                if band not in res[name]:
+                    res[name][band] = trp_spss
+                else:
+                    res[name][band] = np.concatenate((res[name][band], trp_spss), axis=1)
+        for task_name, task_data in res.items():
+            spss[task_name] = {}
+            for band_name, band_data in task_data.items():
+                spss[task_name][band_name] = np.mean(band_data, axis=1)
+                # spss[task_name][band_name] = band_data
+
+        return spss
+
+
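+    # Writes one worksheet per frequency band, with one row per subject and
+    # columns holding the per-electrode TRP values for each task in task_name.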
+    def write_trp_excel(self, fname, data, task_name):
+        wb = load_workbook(fname)
+        for i in range(0, len(self.band_frequency), 2):
+            sheet_name = str(self.band_frequency[i]) + "_" + str(self.band_frequency[i+1])
+            sheet = wb.create_sheet(sheet_name)
+            row = 1
+            for subject_name, subject_data in data.items():
+                col = 1
+                for name in task_name:
+                    row_data = subject_data[name][sheet_name]
+                    for k in range(row_data.shape[0]):
+                        if row_data.ndim == 2:
+                            for k1 in range(row_data.shape[1]):
+                                sheet.cell(row, col).value = row_data[k, k1]
+                                col += 1
+                        else:
+                            sheet.cell(row, col).value = row_data[k]
+                            col += 1
+                row += 1
+        wb.save(fname)
+        wb.close()
+
+
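+    # Integrates the PSD between `low` and `high` with Simpson's rule, using
+    # the frequency resolution as the sample spacing.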
+    @staticmethod
+    def band_power(freq, pxx, low, high):
+        n_ch = pxx.shape[0]
+        idx = np.logical_and(freq >= low, freq <= high).reshape(1, -1)
+        idx = np.repeat(idx, pxx.shape[0], axis=0)
+        power = simps(pxx[idx].reshape(n_ch, -1), dx=freq[1] - freq[0], axis=1)
+        return power
+
+    @staticmethod
+    def band_power_fft(fft_freq, fft_values, low, high):
+        n_ch = fft_values.shape[0]
+        idx = np.logical_and(fft_freq >= low, fft_freq <= high).reshape(1, -1)
+        idx = np.repeat(idx, n_ch, axis=0)
+        # total_power = np.sum(fft_values[idx].reshape(n_ch, -1), axis=1)
+        return np.mean(fft_values[idx].reshape(n_ch, -1), axis=1)
+
+
+if __name__ == '__main__':
+    clean_data_fname = r'D:\EEGdata\clean_data_creativity\1_50_2s_multipaer'
+    # subjects = read_subject_info(input_fname='C:\\Users\\umroot\\Desktop\\creativity\\subject.xlsx')
+    subjects = read_subject_info(r'D:\EEGdata\clean_data_creativity\clean_subjects_28.xlsx')
+
+    trp_excel = r'D:\EEGdata\clean_data_creativity\1_50_2s_multipaer\trp_28_8_10_10_12hz.xlsx'
+    trp_res = OrderedDict()
+    # task_name = ['idea generation', 'idea evolution', 'idea rating']
+    # active_task = ['1_idea evolution', '2_idea evolution', '3_idea evolution']
+    # reference_task = ['1_rest', '1_rest', '1_rest']
+    active_task = []
+    reference_task = []
+
+    task_name = ['idea generation', 'idea evolution', 'idea rating']
+
+    # task_name = ['read problem', 'generate solution', 'rate generation', 'evaluate solution', 'type', 'rate evaluation']
+
+
+    for i in range(1, 4):
+        for j in range(len(task_name)):
+            active_task.append(str(i)+"_"+task_name[j])
+            reference_task.append('1_rest')
+
+    # task_name = ['1_idea generation', '2_idea generation', '3_idea generation',
+    #              '1_idea evolution', '2_idea evolution', '3_idea evolution',
+    #              '1_idea rating', '2_idea rating', '3_idea rating']
+    #
+    # rest_name = ['1_rest']
+
+    for subject in subjects:
+        print(subject)
+        # data_fname_save = clean_data_fname +"\\" + subject +"_epochs_power" +".json"
+        data_fname = clean_data_fname + "\\" +subject + "_power" + ".json"
+        # data_fname = r'D:\EEGdata\clean_data_six_problem\task_related_power' + "\\" +subject + "_power" + ".json"
+
+        # data_fname = clean_data_fname + "\\" +subject + ".json"
+
+        data_text = codecs.open(data_fname, 'r', encoding='utf-8').read()
+        data = json.loads(data_text)
+        # power = TaskRelatedPower(data)
+
+        # res = power.band_powers()
+        power = TaskRelatedPower()
+        # res = {'task': power.concatenate_tasks_powers(data, task_name), 'rest': power.concatenate_tasks_powers(data, ['1_rest'])}
+        # json.dump(res, codecs.open(data_fname_save, 'w', encoding='utf-8'), separators=(',', ':'), sort_keys=True, indent=4)
+
+        trp_res[subject] = power.task_related_power(data, active_task, reference_task)
+    power = TaskRelatedPower()
+    power.write_trp_excel(trp_excel, trp_res, task_name)