# extract_data.py — MEA spike extraction, sorting, burst detection and plotting.
  1. import neuroshare as ns
  2. from pylab import *
  3. import numpy as np
  4. import scipy as sp
  5. import scipy.ndimage
  6. import scipy.signal
  7. import datetime
  8. import time
  9. from os import listdir, makedirs
  10. from os.path import isfile, join, isdir
  11. import math
  12. import random
  13. from scipy.io import savemat
  14. from Helpers.burst_tools import find_bursts, extract_bursttimes_from_bursts, find_global_bursts
  15. from Helpers.sorter import sorting
  16. from Helpers.file_helpers import data_directory, data_dictionary_di
  17. experiments = [4]
  18. extract_from_hd = True
  19. do_savefig = True
  20. sorting_flag = True
  21. plotanalog = False
  22. HD = '/run/media/manuel/TOSHIBA EXT/'
  23. savepath = '/home/manuel/bla/OptoGeneticsData/Lightdisco_raw/'
  24. min_rate = 0.1
  25. for ex in experiments:
  26. if not extract_from_hd:
  27. di = data_dictionary_di(ex)
  28. saveresults = savepath + 'rawdata_' + str(ex)
  29. a = np.load(saveresults + '.npy')
  30. final_results = a.item()
  31. sorted_dict = final_results['sorted_dict']
  32. data_dict = final_results['data_dict']
  33. burst_dict = final_results['burst_dict']
  34. savepath_results_main = savepath + 'analysis_'+ str(ex) + '/'
  35. if plotanalog:
  36. #Just in case, the analog channel should be plottet
  37. onlyfiles, pathtofiles, di = data_directory(ex, HD)
  38. analog_data_file = np.array([])
  39. rawdata_stim = []
  40. shrinkage = 20
  41. for fil in onlyfiles:
  42. fd = ns.File (pathtofiles + fil)
  43. print "file: " + fil + " opened."
  44. for entity in fd.list_entities():
  45. if entity.label[0:4]=='anlg':
  46. rawdata_stim = fd.entities[entity.id]
  47. if not rawdata_stim == []:
  48. data_section, times, count = rawdata_stim.get_data()
  49. snippet_shrinked = np.array([ mean( data_section[shrinkage*i:shrinkage*(i+1)] ) for i in np.arange(0,len(data_section)/shrinkage) ])
  50. analog_data_file = np.append(analog_data_file,snippet_shrinked)
  51. fd.close()
  52. figure(3)
  53. plot(analog_data_file[::10])
  54. if extract_from_hd:
  55. onlyfiles, pathtofiles, di = data_directory(ex, HD)
  56. savepath_results_main = savepath + 'analysis_'+ str(ex) + '/'
  57. if not (isdir(savepath_results_main)):
  58. print 'Folder ' + savepath_results_main + ' created!'
  59. makedirs(savepath_results_main)
  60. ######################################################
  61. # extract data from the mcd files
  62. ######################################################
  63. fd = ns.File (pathtofiles + onlyfiles[0])
  64. numberofspikeentities = len( [int(entity.label[-2:]) for entity in fd.list_entities() if entity.entity_type == 3] )
  65. waveforms = [[] for j in range(0,numberofspikeentities)]
  66. spiketimes = [[] for j in range(0,numberofspikeentities)]
  67. fd.close()
  68. for fil in onlyfiles:
  69. fd = ns.File (pathtofiles + fil)
  70. print ".........................................................................."
  71. print "File: " + fil + " opened."
  72. dt = datetime.datetime(fd.metadata_raw['Time_Year'], fd.metadata_raw['Time_Month'], fd.metadata_raw['Time_Day'], fd.metadata_raw['Time_Hour'], fd.metadata_raw['Time_Min'], fd.metadata_raw['Time_Sec'], fd.metadata_raw['Time_MilliSec']*1000, tzinfo = None)
  73. epoch = datetime.datetime.utcfromtimestamp(0);
  74. delta = dt - epoch
  75. t0 = delta.total_seconds()
  76. print 'Current time: ', t0, '\n\n'
  77. samplingrate = 25000.
  78. elenumber = []
  79. electrodes = \
  80. [11, 12, 13, 14, 15, 16, 17, 18,\
  81. 21, 22, 23, 24, 25, 26, 27, 28,\
  82. 31, 32, 33, 34, 35, 36, 37, 38,\
  83. 41, 42, 43, 44, 45, 46, 47, 48,\
  84. 51, 52, 53, 54, 55, 56, 57, 58,\
  85. 61, 62, 63, 64, 65, 66, 67, 68,\
  86. 71, 72, 73, 74, 75, 76, 77, 78,\
  87. 81, 82, 83, 84, 85, 86, 87, 88 ]
  88. spike_entities = [int(entity.label[-2:]) for entity in fd.list_entities() if entity.entity_type == 3]
  89. numberofspikeentities = size(spike_entities)
  90. elenumber = []
  91. counter = 0 ##Counts how many datasets are read in
  92. for entity in fd.list_entities():
  93. if entity.entity_type == 3:
  94. #print entity.label, entity.entity_type
  95. spikes1 = fd.entities[entity.id]
  96. if int(spikes1.label[-2:]) in electrodes:
  97. elenumber.append(int(spikes1.label[-2:]))
  98. for i in range(1, spikes1.item_count):
  99. waveform, time, a, b = spikes1.get_data(i)
  100. spiketimes[counter].append(time + t0)
  101. waveforms[counter].append(waveform[0])
  102. counter = counter + 1
  103. fd.close()
  104. allspikes = np.concatenate( [ np.array(spiketimes[c]) for c in range(counter - 1 )] )
  105. first_spike_in_data = np.min(allspikes)
  106. last_spike_in_data = np.max(allspikes)
  107. data_dict= {'waveforms': waveforms, 'spiketimes': spiketimes, 'elenumber': elenumber, 'length_recording': last_spike_in_data - first_spike_in_data}
  108. ######################################################
  109. # Process the data, False positive Sorting
  110. ######################################################
  111. if sorting_flag:
  112. sorting_flag = 'spiketimes_good'
  113. sorted_dict = sorting(data_dict)
  114. else:
  115. sorting_flag = 'spiketimes_good'
  116. sorted_dict= {'spiketimes_good': spiketimes}
  117. burst_dict = find_bursts(sorted_dict, data_dict, sorting_flag)
  118. burst_inf = extract_bursttimes_from_bursts( burst_dict, data_dict )
  119. global_bursts = find_global_bursts( burst_inf, data_dict, acc_time = 1.25 )
  120. good_spikes = sorted_dict['spiketimes_good']
  121. bad_spikes = sorted_dict['spiketimes_bad']
  122. good_channels = electrodes
  123. counter = 0
  124. cvs1 = []
  125. rates = []
  126. data = {}
  127. for ele in range(len(elenumber)):
  128. if (not spiketimes[ele] == []) and (elenumber[ele] in good_channels):
  129. data['Electrode_' + str(elenumber[ele])] = good_spikes[ele] - first_spike_in_data
  130. counter = counter + 1
  131. if len(good_spikes[ele]) > (last_spike_in_data - first_spike_in_data)*min_rate:
  132. mu = np.mean(np.diff(good_spikes[ele]))
  133. sigma = np.std(np.diff(good_spikes[ele]))
  134. cv = sigma/mu
  135. cvs1.append(cv)
  136. rates.append( len(good_spikes[ele])/(last_spike_in_data - first_spike_in_data) )
  137. ######################################################
  138. # Finally, save data as numpy file
  139. ######################################################
  140. final_results = {'data_dict': data_dict, 'sorted_dict': sorted_dict, 'burst_dict': burst_dict, 'burst_inf': burst_inf, 'global_bursts':global_bursts, 'cvs': cvs1, 'rates': rates, 'electrode_data': data}
  141. saveresults = savepath + 'rawdata_' + str(ex)
  142. np.save(saveresults + '.npy', final_results)
  143. ######################################################
  144. # And plot the spike trains of active channels
  145. ######################################################
  146. good_spikes = sorted_dict['spiketimes_good']
  147. bad_spikes = sorted_dict['spiketimes_bad']
  148. averaged_waveform_good = sorted_dict['mean_waveforms_good']
  149. averaged_waveform_good_sc = sorted_dict['std_waveforms_good']
  150. averaged_waveform_bad = sorted_dict['mean_waveforms_bad']
  151. averaged_waveform_bad_sc = sorted_dict['std_waveforms_bad']
  152. [gblist, principal_components, labels] = sorted_dict['principal_components']
  153. good_channels = burst_dict['good_channels']
  154. good_channels_file = di[ex]['channels']
  155. N_ele = len(data_dict['elenumber'])
  156. elenumber = data_dict['elenumber']
  157. averaged_waveform_all = [np.mean(data_dict['waveforms'][ele],axis=0) for ele in range(N_ele)] #average across time, not across channels!
  158. averaged_waveform_sc = [np.std(data_dict['waveforms'][ele],axis=0) for ele in range(N_ele)]
  159. waveformlength = 150
  160. shrinkage = 200
  161. samplingrate = 25000
  162. if sorting_flag:
  163. spiketimes = good_spikes
  164. else:
  165. spiketimes = spiketimes
  166. allspikes = np.concatenate( [ np.array(spiketimes[c]) for c in range(N_ele)] )
  167. first_spike_in_data = np.min(allspikes)
  168. last_spike_in_data = np.max(allspikes)
  169. for ele in range(N_ele):
  170. print "printing... " + str(ele)
  171. if (not spiketimes[ele] == []):
  172. figure(1, figsize = (14,14))
  173. plot(good_spikes[ele] - first_spike_in_data, ones(len(good_spikes[ele]))*elenumber[ele]/100,'b|')
  174. plot(bad_spikes[ele] - first_spike_in_data, ones(len(bad_spikes[ele]))*elenumber[ele]/100,'r|')
  175. if elenumber[ele] in good_channels_file:
  176. text(0,elenumber[ele]/100.,str(elenumber[ele]))
  177. figure(2)
  178. if len(good_spikes[ele]) > (last_spike_in_data - first_spike_in_data)*min_rate:
  179. hist(np.diff(good_spikes[ele]),np.linspace(0,1,100), alpha=0.5,normed = True, label = str(elenumber[ele]))
  180. xlim([0,1])
  181. figure(111, figsize = (14,14))
  182. xx = int(elenumber[ele]/10);
  183. yy = int(elenumber[ele] - 10*xx)-1
  184. if elenumber[ele] in good_channels_file:
  185. ax = subplot(8,8,xx + 8*yy, axisbg='lightyellow')
  186. else:
  187. ax = subplot(8,8,xx + 8*yy)
  188. setp( ax.get_xticklabels(), visible=False)
  189. setp( ax.get_yticklabels(), visible=False)
  190. setp( ax.get_xticklines(), visible=False)
  191. setp( ax.get_yticklines(), visible=False)
  192. plot(averaged_waveform_good[ele],color = 'b')
  193. fill_between(np.arange(waveformlength), averaged_waveform_good[ele]-averaged_waveform_good_sc[ele], averaged_waveform_good[ele]+averaged_waveform_good_sc[ele],color='b', alpha = 0.3)
  194. plot(averaged_waveform_bad[ele],color = 'r')
  195. fill_between(np.arange(waveformlength), averaged_waveform_bad[ele]-averaged_waveform_bad_sc[ele], averaged_waveform_bad[ele]+averaged_waveform_bad_sc[ele],color='r', alpha = 0.3)
  196. plot([0,waveformlength],[-20*1e-6,-20*1e-6],'k--')
  197. plot([0,waveformlength],[-10*1e-6,-10*1e-6],'k:')
  198. plot([0,waveformlength],[0,0],'k-')
  199. plot([0,waveformlength],[10*1e-6,10*1e-6],'k:')
  200. figure(112, figsize = (14,14))
  201. xx = int(elenumber[ele]/10);
  202. yy = int(elenumber[ele] - 10*xx)-1
  203. if elenumber[ele] in good_channels_file:
  204. ax = subplot(8,8,xx + 8*yy, axisbg='lightyellow')
  205. else:
  206. ax = subplot(8,8,xx + 8*yy)
  207. setp( ax.get_xticklabels(), visible=False)
  208. setp( ax.get_yticklabels(), visible=False)
  209. setp( ax.get_xticklines(), visible=False)
  210. setp( ax.get_yticklines(), visible=False)
  211. plot(averaged_waveform_all[ele],color = 'k')
  212. fill_between(np.arange(waveformlength), averaged_waveform_all[ele]-averaged_waveform_sc[ele], averaged_waveform_all[ele]+averaged_waveform_sc[ele],color='b', alpha = 0.3)
  213. plot([0,waveformlength],[-20*1e-6,-20*1e-6],'k--')
  214. plot([0,waveformlength],[-10*1e-6,-10*1e-6],'k:')
  215. plot([0,waveformlength],[0,0],'k-')
  216. plot([0,waveformlength],[10*1e-6,10*1e-6],'k:')
  217. if not principal_components[ele] == []:
  218. pcs = principal_components[ele][0]
  219. figure(11, figsize = (14,14))
  220. subplot(8,8,xx + 8*yy)
  221. for i in [ randint(0,len(principal_components[ele][0])-1) for i in np.arange(0,100) ]:
  222. if gblist[ele][i] == 0:
  223. co = 'r'
  224. else:
  225. co = 'b'
  226. xs, ys = pcs[i]
  227. plot(xs, ys, marker='.', color=co, markersize=1)
  228. axis('off')
  229. if elenumber[ele] in good_channels_file:
  230. plot([np.nanmin(pcs[:,0]), np.nanmax(pcs[:,0]), np.nanmax(pcs[:,0]), np.nanmin(pcs[:,0]), np.nanmin(pcs[:,0])],[np.nanmin(pcs[:,1]), np.nanmin(pcs[:,1]), np.nanmax(pcs[:,1]), np.nanmax(pcs[:,1]), np.nanmin(pcs[:,1])], 'y-', linewidth = 5)
  231. if do_savefig:
  232. figure(1)
  233. savefig(savepath_results_main + 'spiketrain.png', bbox_inches='tight')
  234. xlim([0,60])
  235. savefig(savepath_results_main + 'spiketrain_zoom.png', bbox_inches='tight')
  236. clf()
  237. close()
  238. figure(111)
  239. savefig(savepath_results_main + 'waveforms_sorted.pdf', bbox_inches='tight')
  240. clf()
  241. close()
  242. figure(112)
  243. savefig(savepath_results_main + 'waveforms_all.pdf', bbox_inches='tight')
  244. clf()
  245. close()
  246. figure(11)
  247. savefig(savepath_results_main + 'principal_components.pdf', bbox_inches='tight')
  248. clf()
  249. close()
  250. figure(2)
  251. legend()
  252. savefig(savepath_results_main + 'isi.pdf', bbox_inches='tight')
  253. clf()
  254. close()
  255. if plotanalog:
  256. figure(3)
  257. legend()
  258. savefig(savepath_results_main + 'analogcheck.png', bbox_inches='tight')
  259. clf()
  260. close()