
Upload files to 'raw_code'

Manuel Schottdorf, 4 years ago
parent commit fb92595330
1 changed file with 307 additions and 0 deletions
  1. raw_code/extract_data.py  +307 −0

+ 307 - 0
raw_code/extract_data.py

@@ -0,0 +1,307 @@
+import neuroshare as ns
+from pylab import *
+import numpy as np
+import scipy as sp
+import scipy.ndimage
+import scipy.signal
+import datetime
+import time
+from os import listdir, makedirs
+from os.path import isfile, join, isdir
+import math
+import random
+from scipy.io import savemat
+from Helpers.burst_tools import find_bursts, extract_bursttimes_from_bursts, find_global_bursts
+from Helpers.sorter import sorting
+from Helpers.file_helpers import data_directory, data_dictionary_di
+
+
+
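+# Indices of the experiments to process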
+experiments = [4]
+
+
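+# Switches: re-extract from the raw files on the external drive, save the figures,
+# run the false-positive spike sorting, and optionally plot the analog stimulation channel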
+extract_from_hd = True
+do_savefig = True
+sorting_flag = True
+plotanalog = False
+
+
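+# External drive holding the raw recordings and the output directory for extracted data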
+HD = '/run/media/manuel/TOSHIBA EXT/'
+savepath = '/home/manuel/bla/OptoGeneticsData/Lightdisco_raw/'
+
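+# Minimum firing rate (Hz) a channel must exceed to enter the CV/rate statistics and ISI histograms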
+min_rate = 0.1
+
+for ex in experiments:
+    
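+    # Reload previously extracted results from disk instead of reading the raw files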
+    if not extract_from_hd:
+        di = data_dictionary_di(ex)
+        saveresults = savepath + 'rawdata_' + str(ex)
+        a = np.load(saveresults + '.npy')
+        final_results = a.item()
+        sorted_dict = final_results['sorted_dict']
+        data_dict = final_results['data_dict']
+        burst_dict = final_results['burst_dict']
+        savepath_results_main = savepath + 'analysis_'+ str(ex) + '/'
+        
+        if plotanalog:
+            # In case the analog stimulation channel should be plotted as well
+            onlyfiles, pathtofiles, di = data_directory(ex, HD)
+            analog_data_file = np.array([])
+            rawdata_stim = []
+            shrinkage = 20
+            for fil in onlyfiles:
+                fd = ns.File (pathtofiles + fil)
+                print "file:  " + fil + "   opened."
+                for entity in fd.list_entities():
+                    if entity.label[0:4]=='anlg':
+                        rawdata_stim = fd.entities[entity.id]
+                if not rawdata_stim == []:
+                    data_section, times, count = rawdata_stim.get_data()
+                    snippet_shrinked = np.array([ mean( data_section[shrinkage*i:shrinkage*(i+1)] ) for i in np.arange(0,len(data_section)/shrinkage) ])
+                    analog_data_file = np.append(analog_data_file,snippet_shrinked)
+                fd.close()
+            figure(3)
+            plot(analog_data_file[::10])
+
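+    # Extract spike waveforms and spike times from the raw files on the external drive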
+    if extract_from_hd:
+        onlyfiles, pathtofiles, di = data_directory(ex, HD)
+        savepath_results_main = savepath + 'analysis_'+ str(ex) + '/'
+        if not (isdir(savepath_results_main)):
+            print 'Folder ' + savepath_results_main + ' created!'
+            makedirs(savepath_results_main)
+        
+        ######################################################
+        #   extract data from the mcd files
+        ######################################################
+        fd = ns.File (pathtofiles + onlyfiles[0])
+        numberofspikeentities = len( [int(entity.label[-2:]) for entity in fd.list_entities() if entity.entity_type == 3] )
+        waveforms = [[] for j in range(0,numberofspikeentities)]
+        spiketimes = [[] for j in range(0,numberofspikeentities)]
+        fd.close()
+
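+        # Read every file; the recording start time from the file metadata (t0, seconds
+        # since the epoch) is added to the spike times so consecutive files line up in time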
+        for fil in onlyfiles:
+            fd = ns.File (pathtofiles + fil)
+            print ".........................................................................."
+            print "File:  " + fil + "   opened."
+            dt = datetime.datetime(fd.metadata_raw['Time_Year'], fd.metadata_raw['Time_Month'], fd.metadata_raw['Time_Day'], fd.metadata_raw['Time_Hour'], fd.metadata_raw['Time_Min'], fd.metadata_raw['Time_Sec'], fd.metadata_raw['Time_MilliSec']*1000, tzinfo = None)
+            epoch = datetime.datetime.utcfromtimestamp(0)
+            delta = dt - epoch
+            t0 = delta.total_seconds()
+            print 'Current time: ', t0, '\n\n'
+            
+            samplingrate = 25000.
+            elenumber = []
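+            # All electrode labels of the 8x8 grid (11..88); only these are extracted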
+            electrodes = \
+            [11, 12, 13, 14, 15, 16, 17, 18,\
+            21, 22, 23, 24, 25, 26, 27, 28,\
+            31, 32, 33, 34, 35, 36, 37, 38,\
+            41, 42, 43, 44, 45, 46, 47, 48,\
+            51, 52, 53, 54, 55, 56, 57, 58,\
+            61, 62, 63, 64, 65, 66, 67, 68,\
+            71, 72, 73, 74, 75, 76, 77, 78,\
+            81, 82, 83, 84, 85, 86, 87, 88 ]
+
+            spike_entities = [int(entity.label[-2:]) for entity in fd.list_entities() if entity.entity_type == 3]
+            numberofspikeentities = size(spike_entities)
+            elenumber = []
+            counter = 0  # counts how many spike entities have been read in from this file
+            for entity in fd.list_entities():
+                if entity.entity_type == 3:
+                    #print entity.label, entity.entity_type
+                    spikes1 = fd.entities[entity.id]
+                    if int(spikes1.label[-2:]) in electrodes:
+                        elenumber.append(int(spikes1.label[-2:]))
+                        for i in range(1, spikes1.item_count):
+                            waveform, time, a, b = spikes1.get_data(i)
+                            spiketimes[counter].append(time + t0)
+                            waveforms[counter].append(waveform[0])
+                        counter = counter + 1
+            fd.close()
+            
+        # Pool the spike times of all channels to get the first and last spike of the recording
+        allspikes = np.concatenate( [np.array(spiketimes[c]) for c in range(counter)] )
+        first_spike_in_data = np.min(allspikes)
+        last_spike_in_data = np.max(allspikes)
+        data_dict= {'waveforms': waveforms, 'spiketimes': spiketimes, 'elenumber': elenumber, 'length_recording': last_spike_in_data - first_spike_in_data}
+        
+        
+        ######################################################
+        #   Process the data: false-positive spike sorting
+        ######################################################
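+        # Note: sorting_flag is reused below as the dictionary key ('spiketimes_good') passed to find_bursts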
+        if sorting_flag:
+            sorting_flag = 'spiketimes_good'
+            sorted_dict = sorting(data_dict)
+        else:
+            sorting_flag = 'spiketimes_good'
+            sorted_dict= {'spiketimes_good': spiketimes}
+            
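+        # Burst detection: per-channel bursts first, then global (network-wide) bursts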
+        burst_dict = find_bursts(sorted_dict, data_dict, sorting_flag) 
+        burst_inf = extract_bursttimes_from_bursts( burst_dict, data_dict )
+        global_bursts = find_global_bursts( burst_inf, data_dict, acc_time = 1.25 )
+        
+        good_spikes = sorted_dict['spiketimes_good']
+        bad_spikes = sorted_dict['spiketimes_bad']
+        good_channels = electrodes 
+        
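+        # Per-channel statistics: coefficient of variation of the inter-spike intervals
+        # and firing rate, restricted to channels above min_rate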
+        counter = 0
+        cvs1 = []
+        rates = []
+        data = {}
+        for ele in range(len(elenumber)):
+            if (not spiketimes[ele] == []) and (elenumber[ele] in good_channels):
+                data['Electrode_' + str(elenumber[ele])] = good_spikes[ele] - first_spike_in_data
+                counter = counter + 1 
+                if len(good_spikes[ele]) > (last_spike_in_data - first_spike_in_data)*min_rate:
+                    mu = np.mean(np.diff(good_spikes[ele]))
+                    sigma = np.std(np.diff(good_spikes[ele]))
+                    cv = sigma/mu
+                    cvs1.append(cv)
+                    rates.append( len(good_spikes[ele])/(last_spike_in_data - first_spike_in_data) )
+
+
+        ######################################################
+        #   Finally, save data as numpy file
+        ######################################################
+        final_results = {'data_dict': data_dict, 'sorted_dict': sorted_dict, 'burst_dict': burst_dict, 'burst_inf': burst_inf, 'global_bursts':global_bursts, 'cvs': cvs1, 'rates': rates, 'electrode_data': data}
+        saveresults = savepath + 'rawdata_' + str(ex)
+        np.save(saveresults + '.npy', final_results)
+    
+    
+    
+    
+    
+    ######################################################
+    #   And plot the spike trains of active channels
+    ######################################################
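+    # Unpack the sorting results (spike trains, averaged waveforms, PCA) needed for the figures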
+    good_spikes = sorted_dict['spiketimes_good']
+    bad_spikes = sorted_dict['spiketimes_bad']
+    averaged_waveform_good = sorted_dict['mean_waveforms_good']
+    averaged_waveform_good_sc = sorted_dict['std_waveforms_good']
+    averaged_waveform_bad = sorted_dict['mean_waveforms_bad']
+    averaged_waveform_bad_sc = sorted_dict['std_waveforms_bad']
+    [gblist, principal_components, labels] = sorted_dict['principal_components']
+    good_channels = burst_dict['good_channels']  
+    good_channels_file = di[ex]['channels']
+
+
+    N_ele = len(data_dict['elenumber'])
+    elenumber = data_dict['elenumber']
+
+    averaged_waveform_all = [np.mean(data_dict['waveforms'][ele],axis=0) for ele in range(N_ele)]  # average over the spike waveforms of each channel, not across channels!
+    averaged_waveform_sc = [np.std(data_dict['waveforms'][ele],axis=0) for ele in range(N_ele)]
+    
+    
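+    # Constants: samples per spike waveform, shrinkage factor, and raw sampling rate (Hz)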
+    waveformlength = 150
+    shrinkage = 200
+    samplingrate = 25000
+
+    if sorting_flag:
+        spiketimes = good_spikes
+    else:
+        spiketimes = spiketimes  # keep the unsorted spike times
+        
+    allspikes = np.concatenate(   [ np.array(spiketimes[c]) for c in range(N_ele)]  )
+    first_spike_in_data = np.min(allspikes)
+    last_spike_in_data = np.max(allspikes)
+
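+    # Figure 1: spike raster (good spikes blue, rejected red); figure 2: ISI histograms;
+    # figures 111/112: averaged waveforms on the 8x8 grid; figure 11: PCA of the spike shapes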
+    for ele in range(N_ele):
+        print "plotting electrode index " + str(ele)
+        if (not spiketimes[ele] == []):
+            figure(1, figsize = (14,14))
+            plot(good_spikes[ele] - first_spike_in_data, ones(len(good_spikes[ele]))*elenumber[ele]/100,'b|') 
+            plot(bad_spikes[ele] - first_spike_in_data, ones(len(bad_spikes[ele]))*elenumber[ele]/100,'r|') 
+            if elenumber[ele] in good_channels_file:
+                text(0,elenumber[ele]/100.,str(elenumber[ele]))
+                
+            figure(2)
+            if len(good_spikes[ele]) > (last_spike_in_data - first_spike_in_data)*min_rate:
+                hist(np.diff(good_spikes[ele]),np.linspace(0,1,100), alpha=0.5,normed = True, label = str(elenumber[ele]))
+                xlim([0,1])
+
+        figure(111, figsize = (14,14))
+        xx = int(elenumber[ele]/10)
+        yy = int(elenumber[ele] - 10*xx)-1
+        if elenumber[ele] in good_channels_file:
+            ax = subplot(8,8,xx + 8*yy, axisbg='lightyellow')
+        else:
+            ax = subplot(8,8,xx + 8*yy)
+        setp( ax.get_xticklabels(), visible=False)
+        setp( ax.get_yticklabels(), visible=False)  
+        setp( ax.get_xticklines(), visible=False)
+        setp( ax.get_yticklines(), visible=False)
+        plot(averaged_waveform_good[ele],color = 'b')
+        fill_between(np.arange(waveformlength), averaged_waveform_good[ele]-averaged_waveform_good_sc[ele], averaged_waveform_good[ele]+averaged_waveform_good_sc[ele],color='b', alpha = 0.3)
+        plot(averaged_waveform_bad[ele],color = 'r')
+        fill_between(np.arange(waveformlength), averaged_waveform_bad[ele]-averaged_waveform_bad_sc[ele], averaged_waveform_bad[ele]+averaged_waveform_bad_sc[ele],color='r', alpha = 0.3)
+        plot([0,waveformlength],[-20*1e-6,-20*1e-6],'k--')
+        plot([0,waveformlength],[-10*1e-6,-10*1e-6],'k:')
+        plot([0,waveformlength],[0,0],'k-')
+        plot([0,waveformlength],[10*1e-6,10*1e-6],'k:')
+
+        figure(112, figsize = (14,14))
+        xx = int(elenumber[ele]/10)
+        yy = int(elenumber[ele] - 10*xx)-1
+        if elenumber[ele] in good_channels_file:
+            ax = subplot(8,8,xx + 8*yy, axisbg='lightyellow')
+        else:
+            ax = subplot(8,8,xx + 8*yy)
+        setp( ax.get_xticklabels(), visible=False)
+        setp( ax.get_yticklabels(), visible=False)  
+        setp( ax.get_xticklines(), visible=False)
+        setp( ax.get_yticklines(), visible=False)
+        plot(averaged_waveform_all[ele],color = 'k')
+        fill_between(np.arange(waveformlength), averaged_waveform_all[ele]-averaged_waveform_sc[ele], averaged_waveform_all[ele]+averaged_waveform_sc[ele],color='b', alpha = 0.3)
+        plot([0,waveformlength],[-20*1e-6,-20*1e-6],'k--')
+        plot([0,waveformlength],[-10*1e-6,-10*1e-6],'k:')
+        plot([0,waveformlength],[0,0],'k-')
+        plot([0,waveformlength],[10*1e-6,10*1e-6],'k:')
+        
+        if not principal_components[ele] == []:
+            pcs = principal_components[ele][0]
+            figure(11, figsize = (14,14))
+            subplot(8,8,xx + 8*yy)
+            for i in [ randint(0,len(principal_components[ele][0])-1) for i in np.arange(0,100) ]:
+                if gblist[ele][i] == 0:
+                    co = 'r'
+                else:
+                    co = 'b'
+                xs, ys = pcs[i]
+                plot(xs, ys, marker='.', color=co, markersize=1)
+                axis('off')
+            if elenumber[ele] in good_channels_file:
+                plot([np.nanmin(pcs[:,0]), np.nanmax(pcs[:,0]), np.nanmax(pcs[:,0]), np.nanmin(pcs[:,0]), np.nanmin(pcs[:,0])],[np.nanmin(pcs[:,1]), np.nanmin(pcs[:,1]), np.nanmax(pcs[:,1]), np.nanmax(pcs[:,1]), np.nanmin(pcs[:,1])], 'y-', linewidth = 5)
+
+
+        
+
+    if do_savefig:
+        figure(1)
+        savefig(savepath_results_main + 'spiketrain.png', bbox_inches='tight')
+        xlim([0,60])
+        savefig(savepath_results_main + 'spiketrain_zoom.png', bbox_inches='tight')
+        clf()
+        close()
+        figure(111)
+        savefig(savepath_results_main + 'waveforms_sorted.pdf', bbox_inches='tight')
+        clf()
+        close()
+        figure(112)
+        savefig(savepath_results_main + 'waveforms_all.pdf', bbox_inches='tight')
+        clf()
+        close()
+        figure(11)
+        savefig(savepath_results_main + 'principal_components.pdf', bbox_inches='tight')
+        clf()
+        close()
+        figure(2)
+        legend()
+        savefig(savepath_results_main + 'isi.pdf', bbox_inches='tight')
+        clf()
+        close()
+        if plotanalog:
+            figure(3)
+            legend()
+            savefig(savepath_results_main + 'analogcheck.png', bbox_inches='tight')
+            clf()
+            close()
+
+