uploaded helpers

Manuel Schottdorf 4 years ago
commit e372436b19
4 changed files with 419 additions and 0 deletions
  1. Helpers/__init__.py (+1 -0)
  2. Helpers/burst_tools.py (+134 -0)
  3. Helpers/file_helpers.py (+191 -0)
  4. Helpers/sorter.py (+93 -0)

+ 1 - 0
Helpers/__init__.py

@@ -0,0 +1 @@
+ 

+ 134 - 0
Helpers/burst_tools.py

@@ -0,0 +1,134 @@
+import numpy as np
+
+def find_bursts(sorted_dict, data_dict, sorting_flag = 'spiketimes_good'):
+  # For every electrode, split its spikes into in-burst and out-of-burst spikes,
+  # using an activity-dependent inter-spike-interval threshold (the "burst time constant").
+
+  ##### Find good channels #####
+  good_channels = []
+  for ele in range(len(data_dict['elenumber'])):
+    if len(sorted_dict['spiketimes_good'][ele]) > 1:
+      good_channels.append(ele)
+      
+      
+  ##### Calculate time constant for burst detection #####
+  spikecount = 0
+  for i in good_channels:
+    spikecount = spikecount + len(sorted_dict['spiketimes_good'][i])
+  if good_channels:
+    average_rate_active_electrodes = spikecount / (data_dict['length_recording'] * len(good_channels) )
+    th = 0.25 / average_rate_active_electrodes
+    print('\n Burst time constant is: ' + str(th))
+  else:
+    th = 0
+    
+    
+  # find whenever spikes are closer together than burst time constant.  
+  
+  if sorting_flag == 'spiketimes_all':
+    spikes = data_dict['spiketimes']
+  else:
+    spikes = sorted_dict[sorting_flag]
+
+  burst_dict = {}
+  for i in range(len(data_dict['elenumber'])):
+    burstflag = np.diff(spikes[i][:]) < th
+    if len(burstflag) > 0:
+      # a spike counts as in-burst if the interval to either neighbour is below th
+      burstflag_ri = np.append(burstflag[0], burstflag)
+      burstflag_le = np.append(burstflag, burstflag[-1])
+      burstflag = np.logical_or(burstflag_ri, burstflag_le)
+    burstinfo = {'sp_allspikes':spikes[i][:],\
+    'sp_inbursts': [spikes[i][s] for s in np.where(burstflag)[0] ],\
+    'sp_notinbursts': [spikes[i][s] for s in np.where(~burstflag)[0] ]}
+
+    burst_dict[i] = burstinfo
+  burst_dict['th'] = th
+  burst_dict['good_channels'] = good_channels
+  
+  scounter = 0
+  for i in range(len(data_dict['elenumber'])):
+    scounter = scounter + len(burst_dict[i]['sp_notinbursts'])
+  burst_dict['spikes_not_in_bursts'] = scounter
+  
+  return burst_dict
+
+
+
+def extract_bursttimes_from_bursts( burst_dict, data_dict ):
+  # Turn the per-electrode in-burst spike times into bursts with start time, length and spike count.
+  th = burst_dict['th']
+
+  burst_inf = [[] for j in range(len(data_dict['elenumber']))]
+
+  for i in range(len(data_dict['elenumber'])):
+    spikes_on_this_channel_in_bursts = burst_dict[i]['sp_inbursts']
+    if len(spikes_on_this_channel_in_bursts)>0:
+      s0 = 0
+      last_spike_of_burst = 0
+
+      for s in range(len(spikes_on_this_channel_in_bursts)-1):
+        if (spikes_on_this_channel_in_bursts[s+1] - spikes_on_this_channel_in_bursts[s]) < th:
+          #if the next flagged spike is less than th away, this spike still belongs to the burst
+          last_spike_of_burst = s + 1
+        else:
+          #burst ended
+          starttime = spikes_on_this_channel_in_bursts[s0]
+          if last_spike_of_burst > s0:
+            burst_inf[i].append({'time':starttime , 'length':spikes_on_this_channel_in_bursts[last_spike_of_burst] - starttime, 'spikes': 1 + last_spike_of_burst - s0})
+          #and a new burst begins with this spike
+          last_spike_of_burst = s + 1
+          s0 = s + 1
+
+      #close the burst (if any) that runs to the end of the spike train
+      starttime = spikes_on_this_channel_in_bursts[s0]
+      if last_spike_of_burst > s0:
+        burst_inf[i].append({'time':starttime , 'length':spikes_on_this_channel_in_bursts[last_spike_of_burst] - starttime, 'spikes': 1 + last_spike_of_burst - s0})
+
+  return burst_inf
+
+
+
+def find_global_bursts( burst_inf, data_dict, acc_time = 1.25 ):
+  #merge overlapping per-electrode bursts into network-wide ("global") bursts:
+
+  #step 1: collect all burst events into one list:
+  all_bursts = []
+  for i in range(len(data_dict['elenumber'])):
+    if len(burst_inf[i])>0:
+      for s in burst_inf[i]:
+        all_bursts.append([s['time'], s['length']])
+
+  #step 2: sort them by start time
+  all_bursts.sort(key=lambda x: x[0])
+
+  #step 3: start with the first burst and extend the window by each burst's length times acc_time;
+  #every burst that starts inside the current window is merged into the same global burst.
+  global_bursts = []
+  if len(all_bursts) > 0:
+    t0 = all_bursts[0][0]
+    t_end = all_bursts[0][0] + all_bursts[0][1]*acc_time
+
+    for s in range(len(all_bursts)):
+      t_new = all_bursts[s][0]
+      deltat_new = all_bursts[s][1]
+
+      if t_new < t_end:
+        #this burst overlaps, so the current global burst extends until here:
+        t_end = t_new + deltat_new*acc_time
+      else:
+        #the global burst has ended and the next one starts
+        global_bursts.append({'time':t0 , 'length':t_end - t0})
+
+        t0 = all_bursts[s][0]
+        t_end = all_bursts[s][0] + all_bursts[s][1]*acc_time
+
+    #close the last global burst
+    global_bursts.append({'time':t0 , 'length':t_end - t0})
+
+  return global_bursts
+

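For orientation, a minimal usage sketch of the burst pipeline defined above. The dictionary keys match the functions in Helpers/burst_tools.py, but the electrode numbers, spike times and recording length are invented for the demo, and sorted_dict is built by hand here (in the real pipeline it comes from Helpers.sorter.sorting):

from Helpers.burst_tools import find_bursts, extract_bursttimes_from_bursts, find_global_bursts

# two electrodes, 10 s of recording; electrode 12 fires two short bursts around 1 s and 5 s
spikes_ele12 = [0.2, 1.00, 1.02, 1.04, 3.0, 5.00, 5.03, 5.06, 8.0]
spikes_ele13 = [0.5, 4.0]
data_dict = {'elenumber': [12, 13],
             'length_recording': 10.0,
             'spiketimes': [spikes_ele12, spikes_ele13]}
sorted_dict = {'spiketimes_good': [spikes_ele12, spikes_ele13]}

burst_dict = find_bursts(sorted_dict, data_dict)                   # per-electrode burst / non-burst split
burst_inf = extract_bursttimes_from_bursts(burst_dict, data_dict)  # start time, length and spike count of each burst
global_bursts = find_global_bursts(burst_inf, data_dict)           # merge overlapping bursts across electrodes

for b in global_bursts:
    print('global burst at t = %.2f s, length %.3f s' % (b['time'], b['length']))

The burst time constant adapts to the recording: it is a quarter of the mean inter-spike interval on the active electrodes, so spikes only count as a burst when they are much closer together than is typical for that recording.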
+ 191 - 0
Helpers/file_helpers.py

@@ -0,0 +1,191 @@
+from os import listdir
+import string
+import os
+import os.path
+
+
+def data_dictionary_di(ex, verbose = True, infofile = '../data.txt'):
+    f = open(infofile, 'r')
+    di = {}
+    for line in f:
+        s = line.split()
+        if (len(s) > 7) and (s[0] != 'number'):
+            #print s
+            if s[7] != '0':
+                #print [electrode for electrode in s[9].split(',')]
+                channels = [int(electrode) for electrode in s[7].split(',')]
+            else:
+                channels = []
+            di[int(s[0])] = {'folder':s[1], 'culture':s[2], 'DIV': int(s[3]), 'duration': s[4], 'disco': s[5], 'imaging': s[6], 'channels':channels}
+    f.close()
+    return di
+
+def data_directory(ex, HD, verbose = True, infofile = '../data.txt'):
+    path0 = HD
+    di = data_dictionary_di(ex, verbose = False, infofile = infofile)
+    path = di[ex]['folder']
+    pathtofiles = path0 + path + '/'
+    ls = listdir(pathtofiles)
+    if ls and ls[-1][-4:] == '.txt':
+        ls = ls[:-1]
+    if ls and ls[-1] == 'analysis':
+        ls = ls[:-1]
+    con = []
+    if (len(ls) > 0):
+        ls.sort()
+        ###a = [ls[-1]] + ls[0:-1]
+        con = list( "%s" % item for item in ls )
+    if verbose:
+        print('For experiment ' + str(ex) + ' we have the files:')
+        for f in con:
+            print(f)
+        print('With parameters:')
+        for f in di[ex].keys():
+            print(str(f) + ' is ' + str(di[ex][f]))
+    return con, pathtofiles, di
+
+def plot_analog_channel_of_single_file(fil, shrinkage = 20):
+  import neuroshare as ns
+  import pylab as pl
+  import scipy as sp
+  import numpy as np
+  import scipy.ndimage
+  from Helpers.pull_all_data import remove_artifacts, estimate_from_derivative
+  
+  fd = ns.File (fil)
+
+  samplingrate = 25000.
+  framerate = 10.
+
+  frame_on = []
+  frame_off = []
+  new_sequence = []
+
+  for entity in fd.list_entities():
+    if entity.label[0:4]=='anlg':
+      rawdata_stim = fd.entities[entity.id]
+      data1, times, count = rawdata_stim.get_data()
+  data = remove_artifacts(data1)
+  ls = sp.ndimage.filters.gaussian_filter1d(data, sigma=30)
+  ls_diff = np.diff(ls)
+  ls_diff = np.insert(ls_diff,0,0) # first have a zero so that total number of elements works out
+
+  #b = (ls<0.10) - 0.5;
+  b = (ls<0.07) - 0.5;
+  temp = np.sort(np.argwhere((b*np.roll(b,1)<0) & (ls_diff<0)))
+  [frame_off.append(s[0]) for s in times[temp]]
+  temp = np.sort(np.argwhere((b*np.roll(b,1)<0) & (ls_diff>0)))
+  [frame_on.append(s[0]) for s in times[temp]]
+
+  #b = (ls>1) - 0.5;
+  b = (ls>0.5) - 0.5;
+  temp = np.sort(np.argwhere((b*np.roll(b,1)<0) & (ls_diff<0)))
+  [new_sequence.append(s[0]) for s in times[temp]]
+
+
+  snippet_shrinked = [ np.mean( data[shrinkage*i:shrinkage*(i+1)] ) for i in np.arange(0, len(data)//shrinkage) ]
+  snippet_shrinked_smoothed = sp.ndimage.filters.gaussian_filter1d(snippet_shrinked, sigma=(samplingrate/(framerate*shrinkage))/30)
+  snippet_shrinked_smoothed_derivative = np.diff(snippet_shrinked_smoothed)**2
+
+  ## convert the detected times (in seconds) to sample indices of the shrinked trace
+
+  frame_on = np.array(frame_on)*samplingrate/shrinkage
+  frame_off = np.array(frame_off)*samplingrate/shrinkage
+  new_sequence = np.array(new_sequence)*samplingrate/shrinkage
+
+  peaks4 = estimate_from_derivative(snippet_shrinked_smoothed_derivative)
+
+  if len(new_sequence) > 0:
+    peaks5 = np.array([t for t in peaks4 if ((np.min(abs(t - new_sequence)) > 0.2 /framerate) and (snippet_shrinked[t]<1))] )
+    #frame transitions can only occur (at least 200ms away from a new sequence) and (if the stimulus is on)
+  else:
+    #if there is no transition, everything is ok as long as stimulus is on
+    peaks5 = np.array([t for t in peaks4 if snippet_shrinked[t]<1] )
+
+
+  # and back from shrinked-trace indices to seconds for plotting
+  peaks5 = peaks5*shrinkage/samplingrate
+  frame_on = frame_on*shrinkage/samplingrate
+  frame_off = frame_off*shrinkage/samplingrate
+  new_sequence = new_sequence*shrinkage/samplingrate
+
+  #plot(peaks,zeros(len(peaks))+0.10,'r|')
+  #plot(peaks3,zeros(len(peaks3))+0.11,'g|')
+  pl.plot(peaks5, np.zeros(len(peaks5))+0.12,'b|')
+  pl.plot(frame_on, np.zeros(len(frame_on)) + 0.05, 'g|')
+  pl.plot(frame_off, np.zeros(len(frame_off)) + 0.04, 'k|')
+  pl.plot(new_sequence, np.zeros(len(new_sequence)) + 0.10, 'r|')
+  pl.plot(np.arange(len(snippet_shrinked))*shrinkage/samplingrate,snippet_shrinked,'-')
+  
+
+def plot_data_of_single_experiment(ex, full_random=False, storage_location = '/home/manuel/OptogeneticsData/'):
+    import numpy as np
+    from pylab import figure, plot, zeros, ones, text, gca
+    from matplotlib.patches import Rectangle
+
+    di = data_dictionary_di(ex) #contains the good channels
+    a = np.load(storage_location + 'analysis_' + str(ex) + '/results_' + str(ex) +'.npy') #on newer numpy this needs allow_pickle=True
+    ex_string = storage_location + 'analysis_'+ str(ex) + '/' + 'results_' + str(ex)
+    
+    print("Data loaded, now plotting...")
+    
+    final_results = a.item()
+    data_dict = final_results['data_dict']
+    sorted_dict = final_results['sorted_dict']
+    burst_dict = final_results['burst_dict']
+    burst_inf = final_results['burst_inf']
+    global_bursts = final_results['global_bursts']
+        
+    good_spikes = sorted_dict['spiketimes_good']
+    bad_spikes = sorted_dict['spiketimes_bad']
+    averaged_waveform_good = sorted_dict['mean_waveforms_good']
+    averaged_waveform_good_sc = sorted_dict['std_waveforms_good']
+    averaged_waveform_bad = sorted_dict['mean_waveforms_bad']
+    averaged_waveform_bad_sc = sorted_dict['std_waveforms_bad']
+    [gblist, principal_components, labels] = sorted_dict['principal_components']
+
+    frame_transitions_all = np.concatenate( data_dict['frame_transitions_all'])
+    frame_on_all = np.concatenate( data_dict['frame_on_all'])
+    frame_off_all  = np.concatenate( data_dict['frame_off_all'])
+    new_sequence_all = np.concatenate( data_dict['new_sequence_all'])
+    analoge_data_shrinked = np.concatenate(final_results['data_dict']['analogdata_shrinked'])
+    analogdata_times = np.concatenate(data_dict['analogdata_times'])
+    tuningcurves_data = final_results['tuningcurves_data']
+    framelength = tuningcurves_data['framelength']
+    average_framelength = np.mean(framelength)
+    
+    t0 = data_dict['t0']
+    elenumber = data_dict['elenumber']
+    
+    good_channels_file = di[ex]['channels']
+    good_channels = burst_dict['good_channels']  
+
+    waveformlength = 150
+    shrinkage = 20
+    samplingrate = 25000
+    
+    N_ele = len(data_dict['elenumber'])
+
+    figure(1, figsize = (14,14))
+    if len(frame_transitions_all)>0:
+        plot(frame_transitions_all,zeros(len(frame_transitions_all))+0.12,'b|')
+        plot(frame_on_all, zeros(len(frame_on_all)) + 0.12, 'go')
+        plot(frame_off_all, zeros(len(frame_off_all)) + 0.12, 'ko')
+        plot(new_sequence_all, zeros(len(new_sequence_all)) + 0.12, 'ro')
+        plot(analogdata_times,analoge_data_shrinked,'-')
+        if full_random:
+            label_dict = tuningcurves_data['label_dict']
+            for j in label_dict.keys():
+                text(j,0.10,label_dict[j])
+
+    for s in global_bursts:
+        gca().add_patch(Rectangle(( s['time'],0.2), s['length'],.5, facecolor='black', edgecolor='none',alpha = 0.3))
+
+    for ele in range(N_ele):
+        print("printing... " + str(ele))
+        if not good_spikes[ele] == []:
+            figure(1, figsize = (14,14))
+            plot(good_spikes[ele], ones(len(good_spikes[ele]))*elenumber[ele]/100,'b|') 
+            plot(bad_spikes[ele], ones(len(bad_spikes[ele]))*elenumber[ele]/100,'r|') 
+            burstspikes = burst_dict[ele]['sp_inbursts']
+            plot(burstspikes, ones(len(burstspikes))*elenumber[ele]/100,'g|',linewidth=2.0)
+            if elenumber[ele] in good_channels_file:
+                text(t0,elenumber[ele]/100.,str(elenumber[ele]))
+        for s in burst_inf[ele]:
+            plot([ s['time'],s['time'] + s['length']], [elenumber[ele]/100., elenumber[ele]/100.],'y-')

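data_dictionary_di above parses experiment metadata from a whitespace-separated text file. A minimal sketch of the layout it expects follows; the column names and values are invented for the demo, and only the column positions used by the parser (index, folder, culture, DIV, duration, disco, imaging, channels) and the convention that channels is either a comma-separated electrode list or 0 are taken from the code:

# hypothetical data.txt; a header row starting with 'number' is skipped
example = """number folder     culture DIV duration disco imaging channels
1      2016-01-12 c1      14  600      no    yes     12,13,47
2      2016-01-19 c1      21  600      no    no      0
"""

import os
import tempfile
from Helpers.file_helpers import data_dictionary_di

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
    tmp.write(example)

di = data_dictionary_di(1, infofile=tmp.name)
print(di[1]['channels'])   # [12, 13, 47]
print(di[2]['channels'])   # []
os.remove(tmp.name)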
+ 93 - 0
Helpers/sorter.py

@@ -0,0 +1,93 @@
+from matplotlib.mlab import PCA  # removed in matplotlib >= 3.1; requires an older matplotlib
+from sklearn.cluster import MeanShift, DBSCAN, estimate_bandwidth
+import numpy as np
+
+
+np.random.seed(0)
+
+def sorting(data_dict, waveformlength = 150, minimum_number = 30):
+  #minimum_number is the minimum number of spikes per cluster; it could also scale with the recording time, e.g. minimum_number = 0.2 * fd.metadata_raw['TimeSpan']
+  #waveformlength is the number of samples per waveform, typically 150
+    waveforms = data_dict['waveforms']
+    spiketimes = data_dict['spiketimes']
+    N_ele = len(data_dict['elenumber'])
+    averaged_waveform_bad = np.zeros((N_ele,waveformlength))
+    averaged_waveform_bad_sc = np.zeros((N_ele,waveformlength))
+    averaged_waveform_good = np.zeros((N_ele,waveformlength))
+    averaged_waveform_good_sc = np.zeros((N_ele,waveformlength))
+    good_spikes = [[] for j in range(N_ele)]
+    bad_spikes = [[] for j in range(N_ele)]
+    principal_components = [[] for j in range(N_ele)]
+    event_labels = [[] for j in range(N_ele)]
+    gblist = [[] for j in range(N_ele)]
+    
+    for ele in range(N_ele):
+        print(ele)
+        tmp = np.zeros((len(waveforms[ele][:]),waveformlength))
+        for j in range(0,len(waveforms[ele][:])):
+            tmp[j,:] =  waveforms[ele][j]
+        a,b = tmp.shape
+        if a>b:
+            results = PCA(tmp)
+            data = np.array([results.Y[:,0], results.Y[:,1]])
+            data = data.transpose()
+            principal_components[ele].append(data)
+
+            ########meanShift
+            bandwidth = estimate_bandwidth(data, quantile = 0.3)
+            ms = MeanShift(cluster_all=True, bandwidth = bandwidth)
+            ms.fit(data)
+            labels = ms.labels_
+            labels_unique = np.unique(labels)
+            n_clusters_ = len(labels_unique)
+            number_of_points = [ np.sum(np.array(labels == i, dtype = int)) for i in np.sort(labels_unique)]
+            labels_sorted = np.argsort(number_of_points) # cluster labels ordered from smallest to largest cluster
+            counter = 0
+            for i in labels_sorted:
+                labels[labels == i] = 100 + counter
+                counter = counter + 1
+            labels = labels - 100
+            for i in np.sort(labels_unique):
+                number_of_points = np.sum(np.array(labels == i, dtype = int))
+                if number_of_points < minimum_number:
+                    labels[labels == i] = -1
+            w = [[] for i in labels_unique]
+            averaged_waveform = np.zeros((len(labels_unique),waveformlength))
+            averaged_waveform_sc = np.zeros((len(labels_unique),waveformlength))
+            for i in np.sort(labels_unique):
+                number_of_events = np.sum(np.array(labels == i, dtype = int))
+                for j in np.arange(0,len(results.Y)):
+                    if labels[j] == i:
+                        w[i].append(tmp[j,:])
+                averaged_waveform[i][:] = np.mean(w[i][:], axis = 0)
+            event_labels[ele].append(labels)
+            ####### End spikesorting
+
+            # Select good clusters
+            w_good = []
+            w_bad = []
+            for i in np.arange(0,len(results.Y)):
+                #~ print ele, np.std(averaged_waveform[labels[i]]), np.max(averaged_waveform[labels[i]])
+                if (np.std(averaged_waveform[labels[i]]) < 5*1e-6) and (np.max(averaged_waveform[labels[i]]) < 4*1e-6):
+                    bad_spikes[ele].append(spiketimes[ele][i])
+                    w_bad.append(tmp[i,:])
+                    gblist[ele].append(0)
+                else:
+                    good_spikes[ele].append(spiketimes[ele][i])
+                    w_good.append(tmp[i,:])
+                    gblist[ele].append(1)
+
+            averaged_waveform_good[ele][:] = np.mean(w_good, axis = 0)
+            averaged_waveform_good_sc[ele][:] = np.std(w_good, axis = 0) 
+            averaged_waveform_bad[ele][:] = np.mean(w_bad, axis = 0) 
+            averaged_waveform_bad_sc[ele][:] = np.std(w_bad, axis = 0) 
+
+
+    sorted_dict= {\
+    'spiketimes_good': good_spikes,                'spiketimes_bad': bad_spikes,\
+    'mean_waveforms_good': averaged_waveform_good, 'std_waveforms_good': averaged_waveform_good_sc,\
+    'mean_waveforms_bad': averaged_waveform_bad,   'std_waveforms_bad': averaged_waveform_bad_sc,\
+    'principal_components': [gblist, principal_components, event_labels]} # per-electrode good/bad flags, PCA scores and cluster labels
+
+
+    return sorted_dict
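To make the clustering step in sorting() concrete, here is a small self-contained sketch of the same idea on synthetic data: project each spike waveform onto its first two principal components, then cluster the scores with MeanShift. It uses sklearn.decomposition.PCA as a stand-in for matplotlib.mlab.PCA, and the waveform shapes, amplitudes and noise level are made up for the demo:

import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import MeanShift, estimate_bandwidth

rng = np.random.RandomState(0)
waveformlength = 150
t = np.linspace(0, 1, waveformlength)

# two fake units: a large and a small negative spike shape, plus noise
unit_a = -40e-6 * np.exp(-((t - 0.3) ** 2) / 0.001)
unit_b = -10e-6 * np.exp(-((t - 0.3) ** 2) / 0.002)
waveforms = np.vstack([unit_a + 2e-6 * rng.randn(waveformlength) for _ in range(60)] +
                      [unit_b + 2e-6 * rng.randn(waveformlength) for _ in range(60)])

# project onto the first two principal components, then cluster with MeanShift
scores = PCA(n_components=2).fit_transform(waveforms)
bandwidth = estimate_bandwidth(scores, quantile=0.3)
labels = MeanShift(cluster_all=True, bandwidth=bandwidth).fit(scores).labels_
print('found %d clusters' % len(np.unique(labels)))

In the full function, clusters with fewer than minimum_number spikes are relabelled to -1, and the mean waveform of each cluster decides whether its spikes end up in spiketimes_good or spiketimes_bad.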