123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- import neuroshare as ns
- from pylab import *
- import numpy as np
- import scipy as sp
- import scipy.ndimage
- import scipy.signal
- import datetime
- import time
- from os import listdir, makedirs
- from os.path import isfile, join, isdir
- import math
- import random
- from scipy.io import savemat
- from Helpers.burst_tools import find_bursts, extract_bursttimes_from_bursts, find_global_bursts
- from Helpers.sorter import sorting
- from Helpers.file_helpers import data_directory, data_dictionary_di
# --- Experiments to process ----------------------------------------------------
experiments = [4]  # experiment ids handed to data_directory / data_dictionary_di

# --- Pipeline switches -----------------------------------------------------------
extract_from_hd = True  # True: parse the raw files; False: reload savepath/rawdata_<ex>.npy
sorting_flag = True     # run the false-positive spike sorter before burst detection
plotanalog = False      # additionally plot the analog channel of the raw files
do_savefig = True       # write the figures into savepath/analysis_<ex>/

# --- Locations -------------------------------------------------------------------
HD = '/run/media/manuel/TOSHIBA EXT/'                           # root passed to data_directory
savepath = '/home/manuel/bla/OptoGeneticsData/Lightdisco_raw/'  # results and figures go below here

# --- Analysis parameters -----------------------------------------------------------
# A channel counts as active when it has more than min_rate spikes per unit of
# recording time (presumably spikes per second -- confirm the timestamp unit).
min_rate = 0.1
# Per-experiment pipeline: read spikes from the raw files (or reload an earlier
# result), run the false-positive sorter and burst detection, save the results
# as savepath/rawdata_<ex>.npy and plot spike trains, ISI histograms,
# waveforms and principal components into savepath/analysis_<ex>/.

# Two-digit MEA channel labels 11..88 that are read from the raw files.  The
# tens digit is drawn as the grid column, the units digit as the grid row
# (see _grid_index).
electrodes = [10 * column + row for column in range(1, 9) for row in range(1, 9)]

# Key of the accepted ("good") spike trains in sorted_dict; also passed to
# find_bursts to select which trains it uses.
spike_key = 'spiketimes_good'

# Assumed number of samples in every stored spike waveform (x-range of the
# waveform panels) -- TODO confirm against the recording settings.
waveformlength = 150


def _file_start_seconds(fd):
    """Return the recording start of an open neuroshare file in seconds since 1970-01-01.

    The start is built from the ``Time_*`` fields of ``fd.metadata_raw`` as a
    naive datetime (no time zone handling).
    """
    meta = fd.metadata_raw
    start = datetime.datetime(meta['Time_Year'], meta['Time_Month'], meta['Time_Day'],
                              meta['Time_Hour'], meta['Time_Min'], meta['Time_Sec'],
                              meta['Time_MilliSec'] * 1000)
    return (start - datetime.datetime(1970, 1, 1)).total_seconds()


def _read_spike_data(onlyfiles, pathtofiles, allowed_channels):
    """Collect spike waveforms and spike times of one experiment.

    Every file in ``onlyfiles`` is opened in turn.  The k-th spike entity
    (``entity_type == 3``) whose channel label (last two label characters) is
    in ``allowed_channels`` is stored in slot k, so spikes of the same slot
    from all files are concatenated.  NOTE(review): this assumes every file
    lists its spike entities in the same order -- confirm.

    Spike times are the entity timestamps plus the file start time returned
    by ``_file_start_seconds``.

    Returns ``(waveforms, spiketimes, elenumber)``: per-slot lists of
    waveforms (first row of each item's data) and spike times, and the
    channel label of each slot as seen in the last file read.  All three are
    empty when ``onlyfiles`` is empty.
    """
    if not onlyfiles:
        return [], [], []
    accepted = set(allowed_channels)

    # Size the per-slot lists by the number of spike entities in the first file.
    fd = ns.File(pathtofiles + onlyfiles[0])
    try:
        n_slots = len([entity for entity in fd.list_entities() if entity.entity_type == 3])
    finally:
        fd.close()
    waveforms = [[] for _ in range(n_slots)]
    spiketimes = [[] for _ in range(n_slots)]

    elenumber = []
    for fil in onlyfiles:
        fd = ns.File(pathtofiles + fil)
        try:
            print('..........................................................................')
            print('File: ' + fil + ' opened.')
            t0 = _file_start_seconds(fd)
            print('Current time: ' + str(t0) + '\n\n')
            elenumber = []
            for entity in fd.list_entities():
                if entity.entity_type != 3:
                    continue
                segment = fd.entities[entity.id]
                channel = int(segment.label[-2:])
                if channel not in accepted:
                    continue
                slot = len(elenumber)
                elenumber.append(channel)
                # Start at index 0: neuroshare item indices are 0-based, so
                # starting at 1 would drop each channel's first spike.
                for index in range(segment.item_count):
                    waveform, timestamp, _, _ = segment.get_data(index)
                    spiketimes[slot].append(timestamp + t0)
                    waveforms[slot].append(waveform[0])
        finally:
            fd.close()
    return waveforms, spiketimes, elenumber


def _recording_span(spike_trains):
    """Return ``(first, last)`` spike time over all trains, or None if every train is empty."""
    trains = [np.asarray(train, dtype=float) for train in spike_trains if len(train) > 0]
    if not trains:
        return None
    allspikes = np.concatenate(trains)
    return allspikes.min(), allspikes.max()


def _channel_statistics(good_spikes, raw_spiketimes, elenumber, first, last, min_rate):
    """Summarise the accepted spike trains per channel.

    Returns ``(cvs, rates, electrode_data)``.  ``electrode_data`` maps
    ``'Electrode_<channel>'`` to the accepted spike times shifted by
    ``first``, for every channel with raw spikes.  ``cvs`` (coefficient of
    variation of the inter-spike intervals) and ``rates`` (accepted spike
    count divided by ``last - first``) are collected for channels with more
    than ``(last - first) * min_rate`` accepted spikes (and at least two, so
    an interval exists).
    """
    duration = last - first
    cvs, rates, electrode_data = [], [], {}
    for ele, channel in enumerate(elenumber):
        if len(raw_spiketimes[ele]) == 0:
            continue
        train = np.asarray(good_spikes[ele], dtype=float)
        electrode_data['Electrode_' + str(channel)] = train - first
        if len(train) > duration * min_rate and len(train) > 1:
            isi = np.diff(train)
            cvs.append(np.std(isi) / np.mean(isi))
            rates.append(len(train) / duration)
    return cvs, rates, electrode_data


def _plot_analog_channel(onlyfiles, pathtofiles, shrinkage=20):
    """Plot the analog entity (label starting with 'anlg') of all files into figure 3.

    Each file's samples are averaged over consecutive blocks of ``shrinkage``
    samples (a trailing partial block is dropped), the block means of all
    files are concatenated and every 10th value is plotted.
    """
    analog_data_file = np.array([])
    for fil in onlyfiles:
        fd = ns.File(pathtofiles + fil)
        try:
            print('file: ' + fil + ' opened.')
            # Reset per file so a file without an analog entity does not reuse
            # the (already closed) entity of the previous file.
            analog_entity = None
            for entity in fd.list_entities():
                if entity.label[0:4] == 'anlg':
                    analog_entity = fd.entities[entity.id]
            if analog_entity is not None:
                data_section, _, _ = analog_entity.get_data()
                n_blocks = len(data_section) // shrinkage
                # assumes get_data() returns a 1-D sample array -- TODO confirm
                blocks = np.asarray(data_section[:n_blocks * shrinkage]).reshape(n_blocks, shrinkage)
                analog_data_file = np.append(analog_data_file, blocks.mean(axis=1))
        finally:
            fd.close()
    figure(3)
    plot(analog_data_file[::10])


def _grid_index(channel):
    """Return the 1-based 8x8 subplot index of a two-digit channel label.

    The tens digit selects the column and the units digit the row,
    e.g. 11 -> 1, 81 -> 8, 18 -> 57, 88 -> 64.
    """
    column = channel // 10
    row = channel % 10 - 1
    return column + 8 * row


def _electrode_axes(channel, highlight):
    """Open the channel's panel on the current figure's 8x8 grid with ticks hidden.

    Panels of highlighted channels get a light-yellow background.
    """
    if highlight:
        ax = subplot(8, 8, _grid_index(channel), facecolor='lightyellow')
    else:
        ax = subplot(8, 8, _grid_index(channel))
    for artists in (ax.get_xticklabels(), ax.get_yticklabels(),
                    ax.get_xticklines(), ax.get_yticklines()):
        setp(artists, visible=False)
    return ax


def _plot_band(mean, std, line_color, band_color):
    """Plot a mean waveform with a shaded +-1 std band over ``waveformlength`` samples."""
    mean = np.asarray(mean)
    std = np.asarray(std)
    plot(mean, color=line_color)
    fill_between(np.arange(waveformlength), mean - std, mean + std, color=band_color, alpha=0.3)


def _plot_reference_lines():
    """Draw horizontal guides at -20e-6 (dashed), -10e-6 and 10e-6 (dotted) and 0 (solid).

    Presumably -20/-10/+10 microvolt levels if the waveforms are in volts -- confirm.
    """
    for level, style in ((-20 * 1e-6, 'k--'), (-10 * 1e-6, 'k:'), (0, 'k-'), (10 * 1e-6, 'k:')):
        plot([0, waveformlength], [level, level], style)


def _plot_pc_panel(channel, pcs, good_flags, highlight, n_points=100):
    """Scatter ``n_points`` randomly drawn rows of ``pcs`` (first two columns) into figure 11.

    Rows whose ``good_flags`` entry is 0 are drawn red, all others blue.  If
    ``highlight`` is set, a thick yellow box marks the NaN-ignoring extent of
    the first two components.
    """
    figure(11, figsize=(14, 14))
    subplot(8, 8, _grid_index(channel))
    # randint's upper bound is exclusive, so every row can be drawn.
    picks = np.random.randint(0, len(pcs), n_points)
    is_good = np.asarray(good_flags)[picks] != 0
    for mask, colour in ((~is_good, 'r'), (is_good, 'b')):
        chosen = pcs[picks[mask]]
        plot(chosen[:, 0], chosen[:, 1], linestyle='none', marker='.', color=colour, markersize=1)
    axis('off')
    if highlight:
        x_lo, x_hi = np.nanmin(pcs[:, 0]), np.nanmax(pcs[:, 0])
        y_lo, y_hi = np.nanmin(pcs[:, 1]), np.nanmax(pcs[:, 1])
        plot([x_lo, x_hi, x_hi, x_lo, x_lo], [y_lo, y_lo, y_hi, y_hi, y_lo], 'y-', linewidth=5)


def _save_and_close(fig_num, path, add_legend=False):
    """Activate figure ``fig_num``, optionally add a legend, save it to ``path`` and close it."""
    figure(fig_num)
    if add_legend:
        legend()
    savefig(path, bbox_inches='tight')
    clf()
    close()


for ex in experiments:
    savepath_results_main = savepath + 'analysis_' + str(ex) + '/'
    saveresults = savepath + 'rawdata_' + str(ex)
    # Figures are written here in both modes, so make sure the folder exists.
    if not isdir(savepath_results_main):
        print('Folder ' + savepath_results_main + ' created!')
        makedirs(savepath_results_main)

    if not extract_from_hd:
        # Reuse an earlier extraction run.  The .npy file holds a pickled dict
        # written by this script (hence allow_pickle); never point this at
        # untrusted files.
        di = data_dictionary_di(ex)
        final_results = np.load(saveresults + '.npy', allow_pickle=True).item()
        sorted_dict = final_results['sorted_dict']
        data_dict = final_results['data_dict']
        burst_dict = final_results['burst_dict']

    if plotanalog:
        # Optional check of the analog channel (figure 3).
        onlyfiles, pathtofiles, di = data_directory(ex, HD)
        _plot_analog_channel(onlyfiles, pathtofiles)

    if extract_from_hd:
        ######################################################
        # extract data from the mcd files
        ######################################################
        onlyfiles, pathtofiles, di = data_directory(ex, HD)
        waveforms, spiketimes, elenumber = _read_spike_data(onlyfiles, pathtofiles, electrodes)
        span = _recording_span(spiketimes[:len(elenumber)])
        if span is None:
            print('Experiment ' + str(ex) + ': no spikes found, skipped.')
            continue
        first_spike_in_data, last_spike_in_data = span
        data_dict = {'waveforms': waveforms, 'spiketimes': spiketimes, 'elenumber': elenumber,
                     'length_recording': last_spike_in_data - first_spike_in_data}

        ######################################################
        # Process the data, False positive Sorting
        ######################################################
        # sorting_flag stays a boolean; the dict key lives in spike_key.
        if sorting_flag:
            sorted_dict = sorting(data_dict)
        else:
            sorted_dict = {spike_key: spiketimes}
        burst_dict = find_bursts(sorted_dict, data_dict, spike_key)
        burst_inf = extract_bursttimes_from_bursts(burst_dict, data_dict)
        global_bursts = find_global_bursts(burst_inf, data_dict, acc_time=1.25)
        cvs1, rates, data = _channel_statistics(sorted_dict[spike_key], spiketimes, elenumber,
                                                first_spike_in_data, last_spike_in_data, min_rate)

        ######################################################
        # Finally, save data as numpy file
        ######################################################
        final_results = {'data_dict': data_dict, 'sorted_dict': sorted_dict,
                         'burst_dict': burst_dict, 'burst_inf': burst_inf,
                         'global_bursts': global_bursts, 'cvs': cvs1, 'rates': rates,
                         'electrode_data': data}
        np.save(saveresults + '.npy', final_results)

    ######################################################
    # And plot the spike trains of active channels
    ######################################################
    good_spikes = sorted_dict[spike_key]
    # Rejected spikes and the waveform / PCA summaries only exist after sorting.
    bad_spikes = sorted_dict.get('spiketimes_bad')
    is_sorted = 'principal_components' in sorted_dict
    if is_sorted:
        averaged_waveform_good = sorted_dict['mean_waveforms_good']
        averaged_waveform_good_sc = sorted_dict['std_waveforms_good']
        averaged_waveform_bad = sorted_dict['mean_waveforms_bad']
        averaged_waveform_bad_sc = sorted_dict['std_waveforms_bad']
        gblist, principal_components, _labels = sorted_dict['principal_components']
    good_channels_file = di[ex]['channels']
    elenumber = data_dict['elenumber']

    span = _recording_span(good_spikes[:len(elenumber)])
    if span is None:
        print('Experiment ' + str(ex) + ': no accepted spikes, nothing to plot.')
        continue
    first_spike_in_data, last_spike_in_data = span
    min_spike_count = (last_spike_in_data - first_spike_in_data) * min_rate

    for ele, channel in enumerate(elenumber):
        print('printing... ' + str(ele))
        train = np.asarray(good_spikes[ele], dtype=float)
        if bad_spikes is None:
            rejected = np.array([])
        else:
            rejected = np.asarray(bad_spikes[ele], dtype=float)
        if len(train) == 0 and len(rejected) == 0:
            continue
        highlight = channel in good_channels_file
        y_level = channel / 100.

        # Raster: accepted (blue) and rejected (red) spikes, one row per channel.
        figure(1, figsize=(14, 14))
        plot(train - first_spike_in_data, np.full(len(train), y_level), 'b|')
        plot(rejected - first_spike_in_data, np.full(len(rejected), y_level), 'r|')
        if highlight:
            text(0, y_level, str(channel))

        # Inter-spike-interval histogram of sufficiently active channels.
        figure(2)
        if len(train) > min_spike_count:
            hist(np.diff(train), np.linspace(0, 1, 100), alpha=0.5, density=True,
                 label=str(channel))
            xlim([0, 1])

        if is_sorted:
            # Mean +- std of accepted (blue) and rejected (red) waveforms.
            figure(111, figsize=(14, 14))
            _electrode_axes(channel, highlight)
            _plot_band(averaged_waveform_good[ele], averaged_waveform_good_sc[ele], 'b', 'b')
            _plot_band(averaged_waveform_bad[ele], averaged_waveform_bad_sc[ele], 'r', 'r')
            _plot_reference_lines()

        # Mean +- std over all waveforms of the channel, before sorting.
        raw_waveforms = data_dict['waveforms'][ele]
        if len(raw_waveforms) > 0:
            figure(112, figsize=(14, 14))
            _electrode_axes(channel, highlight)
            _plot_band(np.mean(raw_waveforms, axis=0), np.std(raw_waveforms, axis=0), 'k', 'b')
            _plot_reference_lines()

        if is_sorted and len(principal_components[ele]) > 0:
            pcs = np.asarray(principal_components[ele][0])
            if len(pcs) > 0:
                _plot_pc_panel(channel, pcs, gblist[ele], highlight)

    if do_savefig:
        figure(1)
        savefig(savepath_results_main + 'spiketrain.png', bbox_inches='tight')
        xlim([0, 60])
        savefig(savepath_results_main + 'spiketrain_zoom.png', bbox_inches='tight')
        clf()
        close()
        _save_and_close(111, savepath_results_main + 'waveforms_sorted.pdf')
        _save_and_close(112, savepath_results_main + 'waveforms_all.pdf')
        _save_and_close(11, savepath_results_main + 'principal_components.pdf')
        _save_and_close(2, savepath_results_main + 'isi.pdf', add_legend=True)
        if plotanalog:
            _save_and_close(3, savepath_results_main + 'analogcheck.png')
|