figure1_analysis.py

import os

import numpy as np
import nixio as nix
import scipy.signal as signal
import matplotlib.mlab as mlab
from scipy.stats import circmean

from ..util import gaussKernel, detect_eod_times, load_white_noise_stim


def get_firing_rate(spikes, duration, dt, sigma):
    # Estimate trial-by-trial firing rates by convolving binary spike trains
    # with a Gaussian kernel of width sigma.
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spikes), len(time)))
    k = gaussKernel(sigma, dt)
    for i, sp in enumerate(spikes):
        binary = np.zeros(len(time))
        binary[np.array(sp / dt, dtype=int)] = 1
        rates[i, :] = np.convolve(binary, k, mode="same")
    return time, rates
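
# Rough usage sketch for get_firing_rate (the spike trains below are made up
# for illustration):
#
#     time, rates = get_firing_rate([np.array([0.1, 0.25]), np.array([0.3])],
#                                   duration=1.0, dt=1. / 20000., sigma=0.001)
#     # time.shape == (20000,), rates.shape == (2, 20000); each row is one
#     # trial's smoothed firing rate.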


def analyse_white_noise_data(stimulus, spikes):
    # Compute stimulus-response coherence spectra and transfer functions
    # from the white-noise responses.
    f, coherence_spectra = get_coherence(stimulus, spikes, 0.001)
    f, transfer_functions = get_transfer_function(stimulus, spikes, 0.001)
    return f, transfer_functions, coherence_spectra


def create_eod_template(eod, eod_times, eod_indices, spike_times, avg_count=None):
    # For each spike, find the preceding EOD time and express the spike time
    # as a phase within that EOD cycle.
    phases = np.zeros((len(spike_times) - 1, 1))
    relative_times = np.zeros((len(spike_times) - 1, 1))
    for i, sp in enumerate(spike_times[:-1]):
        idx = np.where(eod_times < sp)[0][-1]
        if idx >= len(eod_times) - 1:  # no following EOD time, period undefined
            break
        period = eod_times[idx + 1] - eod_times[idx]
        dt = sp - eod_times[idx]
        phases[i] = dt / period * 2 * np.pi
        relative_times[i] = dt
    # Average a random subset of EOD cycles, interpolated onto a common phase
    # axis, to obtain a template of the EOD waveform.
    indices = np.arange(2, len(eod_times[:-1]))
    np.random.shuffle(indices)
    indices = indices[:avg_count]
    x = np.arange(0.0, 2 * np.pi + 0.01, 0.1)
    eod_snippets = np.zeros((len(indices), len(x)))
    for i, idx in enumerate(indices):
        fp = eod[eod_indices[idx]:eod_indices[idx + 1]]
        xp = np.linspace(0.0, 2 * np.pi, len(fp))
        eod_snippets[i, :] = np.interp(x, xp, fp)
    return np.mean(eod_snippets, axis=0), np.std(eod_snippets, axis=0), phases, relative_times


def get_transfer_function(stimulus, spikes, sigma):
    # Estimate the transfer function of each trial as the ratio of the
    # stimulus-response cross-spectrum and the stimulus power spectrum.
    _, responses = get_firing_rate(spikes, 10., 1. / 20000., sigma)
    # The stimulus power spectrum is the same for all trials.
    f, psd = signal.csd(stimulus[1], stimulus[1], fs=20000, nperseg=2**14,
                        noverlap=2**13, detrend="constant", window="hann")
    csds = np.zeros((len(responses), len(f)), dtype=np.complex128)
    transfer_functions = np.zeros((len(responses), len(f)), dtype=np.complex128)
    for i in range(len(responses)):
        f, csd = signal.csd(stimulus[1], responses[i, :], fs=20000, nperseg=2**14,
                            noverlap=2**13, detrend="constant", window="hann")
        csds[i, :] = csd
        transfer_functions[i, :] = csd / psd
    return f, transfer_functions
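
# The magnitudes |transfer_functions| are the gain curves that end up in the
# npz file as "gain_spectra"; the complex values additionally carry the
# stimulus-response phase.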


def get_coherence(stimulus, spikes, sigma=0.001):
    # Coherence between each trial's firing rate and the stimulus.
    _, responses = get_firing_rate(spikes, 10., 1. / 20000., sigma)
    f = None
    coherence_spectra = None
    for i, r in enumerate(responses):
        c, f = mlab.cohere(r, stimulus[1], NFFT=2**14, noverlap=2**13, Fs=20000,
                           detrend=mlab.detrend_mean, window=mlab.window_hanning)
        if coherence_spectra is None:
            coherence_spectra = np.zeros((len(responses), len(f)))
        coherence_spectra[i, :] = c
    return f, coherence_spectra
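
# Note on the estimate above: mlab.cohere returns the magnitude-squared
# coherence |S_rs(f)|**2 / (S_rr(f) * S_ss(f)), bounded between 0 and 1;
# values close to 1 indicate that the response at that frequency is well
# described by a linear transformation of the stimulus.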


def save_white_noise_data(time, stimulus, spikes, responses, frequency, coherence_spectra,
                          gain_functions):
    # Flatten the spike trains into (trial, spike time) pairs and store all
    # derived data in a compressed npz file.
    if not os.path.exists("derived_data"):
        os.mkdir("derived_data")
    trial_data = np.empty(0)
    spike_data = np.empty(0)
    for i, sp in enumerate(spikes):
        trial = np.ones_like(sp) * i
        trial_data = np.hstack((trial_data, trial))
        spike_data = np.hstack((spike_data, sp))
    np.savez_compressed(os.path.join("derived_data", "figure1_whitenoise_data"),
                        spikes=spike_data, trial_data=trial_data,
                        time=time, rates=responses,
                        stimulus=stimulus, frequency=frequency,
                        coherence_spectra=coherence_spectra, gain_spectra=gain_functions)


def load_white_noise_data_from_nix(block):
    # Collect all file-stimulus (white noise, "gwn") runs from the nix block.
    tags = []
    tag_times = []
    mtags = []
    mtag_positions = []
    stims = []
    stim_contrasts = []
    for t in block.tags:
        if "filestimulus" in t.name.lower():
            secs = t.metadata.find_sections(lambda sec: "file" in sec.props)
            if len(secs) == 0:
                continue
            filename = secs[0].props["file"].values[0]
            if "gwn" not in filename:
                continue
        else:
            continue
        tags.append(t.name)
        tag_times.append((t.position[0], t.position[0] + t.extent[0]))
        # Find the multi-tag segments that fall within this stimulation run.
        for mt in block.multi_tags:
            pos = mt.positions[:]
            ext = mt.extents[:]
            for i, (p, e) in enumerate(zip(pos, ext)):
                if e < 0.01:
                    continue
                if p > t.position[0] and (p + e) < t.position[0] + t.extent[0]:
                    mtags.append(mt.name)
                    mtag_positions.append(i)
                    stims.append(filename)
                    stim_contrasts.append(secs[0].props["contrast"].values[0] * 100)
    max_contr = np.max(stim_contrasts)
    unique_stims = np.unique(stims)
    stim_dict = {}
    stimuli = {}
    for s in unique_stims:
        stim_dict[s] = []
        for i, c in enumerate(stim_contrasts):
            if c == max_contr and stims[i] == s:
                stim_dict[s].append(i)
        time, stim = load_white_noise_stim(s, stim_duration=10, sampling_rate=20000,
                                           folder="stimuli")
        if time is None:
            return None, None
        stimuli[s] = (time, stim)
    # At this point we have the stimuli presented at maximal contrast, the
    # respective segment indices, and the stimulus waveforms themselves.
    # Work only on the first stimulus.
    stimulus = list(stimuli.keys())[0]
    spike_times = []
    for j in stim_dict[stimulus]:
        mt = block.multi_tags[mtags[j]]
        p = mtag_positions[j]
        st = mt.retrieve_data(p, "Spikes-1")[:] - mt.positions[p][0]
        if len(st) > 0 and st[0] < 0:
            st = st[1:]
        spike_times.append(st)
    return stimuli[stimulus], spike_times


def analyze_driven_activity(args):
    if not os.path.exists(args.dataset):
        raise ValueError("Dataset %s was not found!" % args.dataset)
    nf = nix.File.open(args.dataset, nix.FileMode.ReadOnly)
    block = nf.blocks[0]
    stim, spikes = load_white_noise_data_from_nix(block)
    nf.close()
    if stim is None:
        raise ValueError("No white noise data found in dataset %s!" % args.dataset)
    stimulus = (stim[0], stim[1])
    f, tfs, cspecs = analyse_white_noise_data(stimulus, spikes)
    time, rates = get_firing_rate(spikes, 10., 1. / 20000., 0.001)
    save_white_noise_data(time, stimulus, spikes, rates, f, cspecs, tfs)


def analyze_baseline_activity(args):
    if not os.path.exists(args.dataset):
        raise ValueError("Dataset %s was not found!" % args.dataset)
    nf = nix.File.open(args.dataset, nix.FileMode.ReadOnly)
    block = nf.blocks[0]
    # Pick the longest baseline recording in the file.
    baseline_tag = None
    for t in block.tags:
        if "baseline" in t.name.lower():
            if baseline_tag:
                baseline_tag = baseline_tag if t.extent[0] <= baseline_tag.extent[0] else t
            else:
                baseline_tag = t
    if not baseline_tag:
        nf.close()
        raise ValueError("No baseline data found in dataset %s!" % args.dataset)
    baseline_eod = baseline_tag.retrieve_data("EOD")[:]
    baseline_voltage = baseline_tag.retrieve_data("V-1")[:]
    baseline_spikes = baseline_tag.retrieve_data("Spikes-1")[:]
    baseline_time = np.asarray(baseline_tag.references["EOD"].dimensions[0].axis(len(baseline_eod)))
    nf.close()
    baseline_rate = len(baseline_spikes) / (baseline_time[-1] - baseline_time[0])
    # EOD times and frequency
    eod_times, eod_indices = detect_eod_times(baseline_time, baseline_eod)
    eodf = len(eod_times) / (baseline_time[-1] - baseline_time[0])
    # Interspike-interval statistics
    isis = np.diff(baseline_spikes)
    cv = np.std(isis) / np.mean(isis)
    burstiness = len(isis[isis < 1.5 * np.mean(np.diff(eod_times))]) / len(isis)
    # Phase locking to the EOD
    eod_template, template_std, spike_phases, spike_times_rel = create_eod_template(
        baseline_eod, eod_times, eod_indices, baseline_spikes, avg_count=1000)
    eod_ampl = np.max(eod_template) - np.min(eod_template)
    # Vector strength: modulus of the mean phase vector at the EOD frequency.
    phase_locking = np.abs(np.mean(np.exp(1j * 2 * np.pi * np.mean(1. / np.diff(eod_times))
                                          * spike_times_rel)))
    save_baseline_data(baseline_time, baseline_eod, baseline_voltage, baseline_spikes,
                       spike_phases, spike_times_rel, eod_times, eod_template, template_std,
                       isis, cv, burstiness, phase_locking, eodf, baseline_rate)


def save_baseline_data(time, eod, voltage, spikes, phases, spike_cycle_times,
                       eod_times, eod_template, template_std, isis, cv, bi, vs, eodf,
                       baseline_rate):
    if not os.path.exists(os.path.join(".", "derived_data")):
        os.mkdir(os.path.join(".", "derived_data"))
    np.savez_compressed(os.path.join("derived_data", "figure1_baseline_data"),
                        time=time, eod=eod, voltage=voltage, spike_times=spikes[:-1],
                        spike_phases=np.squeeze(phases),
                        spike_cycle_times=np.squeeze(spike_cycle_times),
                        eod_times=eod_times, eod_template=eod_template,
                        template_error=template_std, isis=isis, cv=cv, burstiness=bi,
                        vector_strength=vs, eodf=eodf, baseline_rate=baseline_rate,
                        preferred_phase=circmean(np.squeeze(phases)))
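

# A minimal sketch of how these entry points could be wired to a command
# line; the actual CLI is not part of this file, so the argument layout below
# (a positional dataset path plus a --baseline switch) is an assumption.
# Because of the relative import at the top, the module has to be run as part
# of its package, e.g. `python -m <package>.figure1_analysis`.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Analyses behind figure 1.")
    parser.add_argument("dataset", help="path to the nix data file")  # assumed name
    parser.add_argument("--baseline", action="store_true",
                        help="analyze baseline instead of driven activity")
    cli_args = parser.parse_args()
    if cli_args.baseline:
        analyze_baseline_activity(cli_args)
    else:
        analyze_driven_activity(cli_args)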