# circuit_functions.py (14 KB)

#================================================================================
#= Import
#================================================================================
import os
import time
tic = time.perf_counter()  # start timestamp; not read in this file (presumably used by the calling script)
from os.path import join
import sys
import zipfile
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.mplot3d import Axes3D  # historically needed to register the '3d' projection used in plot_network_somas
from matplotlib.collections import LineCollection
from matplotlib.collections import PolyCollection
import numpy as np
np.seterr(divide='ignore', invalid='ignore')  # silence numpy divide-by-zero / invalid-value floating-point warnings globally
import scipy
from scipy import signal as ss
from scipy import stats as st
from mpi4py import MPI
import math
import neuron
from neuron import h, gui
import LFPy
from LFPy import NetworkCell, Network, Synapse, RecExtElectrode, StimIntElectrode
# from net_params import *
import warnings
warnings.filterwarnings('ignore')  # suppress every Python warning for the rest of the run
import pandas as pd
import itertools
#================================================================================
#= Controls
#================================================================================
# Frank ===============
# Alex ===============
# Plot toggles read by the "Run Plot Functions" section at the end of this file.
plotnetworksomas = False
plotrasterandrates = True
plotephistimseriesandPSD = False
plotsomavs = False # Specify cell indices to plot in 'cell_indices_to_plot' (not defined in this file) - Note: plotting too many cells can randomly cause TCP connection errors
# Kant ===============
#================================================================================
#= Analysis
#================================================================================
#===============================
#= Analysis Parameters
#===============================
transient = 2000 #used for plotting and analysis: start-up period (ms) excluded from rates, PSDs and time axes
radii = [79000., 80000., 85000., 90000.] #4sphere model shell radii (units as LFPy expects -- presumably um, confirm)
sigmas = [0.3, 1.5, 0.015, 0.3] #conductivity of each shell, same order as radii (presumably S/m -- confirm)
L23_pos = np.array([0., 0., 78200.]) #single dipole reference for EEG/ECoG
EEG_sensor = np.array([[0., 0., 90000]])   # placed on the outermost radius, radii[3]
ECoG_sensor = np.array([[0., 0., 79000]])  # placed on the innermost radius, radii[0]
EEG_args = LFPy.FourSphereVolumeConductor(radii, sigmas, EEG_sensor)
ECoG_args = LFPy.FourSphereVolumeConductor(radii, sigmas, ECoG_sensor)
sampling_rate = (1/0.025)*1000  # Hz; assumes the signals are sampled every 0.025 ms
nperseg = int(sampling_rate/2)  # Welch segment length: half a second of samples
t1 = int(transient/0.025)       # sample index at which the transient ends
#===============================
  60. def bandPassFilter(signal,low=.1, high=100.):
  61. order = 2
  62. # z, p, k = ss.butter(order, [low,high],btype='bandpass',fs=sampling_rate,output='zpk')
  63. # sos = ss.zpk2sos(z, p, k)
  64. # y = ss.sosfiltfilt(sos, signal)
  65. b, a = ss.butter(order, [low,high],btype='bandpass',fs=sampling_rate)
  66. y = ss.filtfilt(b, a, signal)
  67. return y
#================================================================================
#= Plotting
#================================================================================
#===============================
#= Frank
#===============================
# NOTE: both names are rebound by the "Alex" section that directly follows
# (pop_colors with a different PV/VIP palette), so these values are not the
# ones in effect when that section has run.
pop_colors = {'HL23PYR':'k', 'HL23PV':'red', 'HL23SST':'green', 'HL23VIP':'yellow'}
popnames = ['HL23PYR', 'HL23PV', 'HL23SST', 'HL23VIP']
#===============================
  77. #===============================
  78. #= Alex
  79. #===============================
  80. pop_colors = {'HL23PYR':'k', 'HL23PV':'crimson', 'HL23SST':'green', 'HL23VIP':'darkorange'}
  81. popnames = ['HL23PYR', 'HL23PV', 'HL23SST', 'HL23VIP']
  82. poplabels = ['PN', 'MN', 'BN', 'VN']
  83. font = {'family' : 'normal',
  84. 'weight' : 'normal',
  85. 'size' : 14}
  86. matplotlib.rc('font', **font)
  87. matplotlib.rc('legend',**{'fontsize':16})
  88. #===============================
  89. # Plot soma positions
  90. def plot_network_somas(OUTPUTPATH):
  91. filename = os.path.join(OUTPUTPATH,'cell_positions_and_rotations.h5')
  92. popDataArray = {}
  93. popDataArray[popnames[0]] = pd.read_hdf(filename,popnames[0])
  94. popDataArray[popnames[0]] = popDataArray[popnames[0]].sort_values('gid')
  95. popDataArray[popnames[1]] = pd.read_hdf(filename,popnames[1])
  96. popDataArray[popnames[1]] = popDataArray[popnames[1]].sort_values('gid')
  97. popDataArray[popnames[2]] = pd.read_hdf(filename,popnames[2])
  98. popDataArray[popnames[2]] = popDataArray[popnames[2]].sort_values('gid')
  99. popDataArray[popnames[3]] = pd.read_hdf(filename,popnames[3])
  100. popDataArray[popnames[3]] = popDataArray[popnames[3]].sort_values('gid')
  101. fig = plt.figure(figsize=(5, 5))
  102. ax = fig.add_subplot(111, projection='3d')
  103. ax.view_init(elev=5)
  104. for pop in popnames:
  105. for i in range(0,len(popDataArray[pop]['gid'])):
  106. ax.scatter(popDataArray[pop]['x'][i],popDataArray[pop]['y'][i],popDataArray[pop]['z'][i], c=pop_colors[pop], s=5)
  107. ax.set_xlim([-300, 300])
  108. ax.set_ylim([-300, 300])
  109. ax.set_zlim([-1200, -400])
  110. return fig
  111. # Plot spike raster plots & spike rates
  112. def plot_raster_and_rates(SPIKES,tstart_plot,tstop_plot,popnames,N_cells,network,OUTPUTPATH,GLOBALSEED):
  113. colors = ['dimgray', 'crimson', 'green', 'darkorange']
  114. fig = plt.figure(figsize=(10, 8))
  115. ax1 =fig.add_subplot(111)
  116. for name, spts, gids in zip(popnames, SPIKES['times'], SPIKES['gids']):
  117. t = []
  118. g = []
  119. for spt, gid in zip(spts, gids):
  120. t = np.r_[t, spt]
  121. g = np.r_[g, np.zeros(spt.size)+gid]
  122. ax1.plot(t[t >= tstart_plot], g[t >= tstart_plot], '|', color=pop_colors[name])
  123. ax1.set_ylim(0,N_cells)
  124. halftime = 750
  125. plt1 = int(tstart_plot+((tstop_plot-tstart_plot)/2)-halftime)
  126. plt2 = int(tstart_plot+((tstop_plot-tstart_plot)/2)+halftime)
  127. ax1.set_xlim(plt1,plt2)
  128. ax1.set_xlabel('Time (ms)')
  129. ax1.set_ylabel('Cell Number')
  130. PN = []
  131. MN = []
  132. BN = []
  133. VN = []
  134. SPIKE_list = [PN ,MN, BN, VN]
  135. SPIKE_MEANS = []
  136. SPIKE_STDEV = []
  137. for i in range(4):
  138. for j in range(len(SPIKES['times'][i])):
  139. scount = SPIKES['times'][i][j][SPIKES['times'][i][j]>transient]
  140. Hz = np.array([(scount.size)/((int(network.tstop)-transient)/1000)])
  141. SPIKE_list[i].append(Hz)
  142. SPIKE_MEANS.append(np.mean(SPIKE_list[i]))
  143. SPIKE_STDEV.append(np.std(SPIKE_list[i]))
  144. meanstdevstr1 = '\n' + str(np.around(SPIKE_MEANS[0], decimals=2)) + r' $\pm$ '+ str(np.around(SPIKE_STDEV[0], decimals=2))
  145. meanstdevstr2 = '\n' + str(np.around(SPIKE_MEANS[1], decimals=2)) + r' $\pm$ '+ str(np.around(SPIKE_STDEV[1], decimals=2))
  146. meanstdevstr3 = '\n' + str(np.around(SPIKE_MEANS[2], decimals=2)) + r' $\pm$ '+ str(np.around(SPIKE_STDEV[2], decimals=2))
  147. meanstdevstr4 = '\n' + str(np.around(SPIKE_MEANS[3], decimals=2)) + r' $\pm$ '+ str(np.around(SPIKE_STDEV[3], decimals=2))
  148. names = [poplabels[0]+meanstdevstr1,poplabels[1]+meanstdevstr2,poplabels[2]+meanstdevstr3,poplabels[3]+meanstdevstr4]
  149. Hzs_mean = np.array(SPIKE_MEANS)
  150. np.savetxt(os.path.join(OUTPUTPATH,'spikerates_Seed' + str(int(GLOBALSEED)) + '.txt'),Hzs_mean)
  151. w = 0.8
  152. fig2 = plt.figure(figsize=(10, 8))
  153. ax2 = fig2.add_subplot(111)
  154. ax2.bar(x = [0, 1, 2, 3],
  155. height=[pop for pop in SPIKE_MEANS],
  156. yerr=[np.std(pop) for pop in SPIKE_list],
  157. capsize=12,
  158. width=w,
  159. tick_label=names,
  160. color=[clr for clr in colors],
  161. edgecolor='k',
  162. ecolor='black',
  163. linewidth=4,
  164. error_kw={'elinewidth':3,'markeredgewidth':3})
  165. ax2.set_ylabel('Spike Frequency (Hz)')
  166. ax2.grid(False)
  167. return fig, fig2
  168. # Plot spike time histograms
  169. def plot_spiketimehists(SPIKES,network):
  170. colors = ['dimgray', 'crimson', 'green', 'darkorange']
  171. binsize = 10 # ms
  172. numbins = int((network.tstop - transient)/binsize)
  173. fig, axarr = plt.subplots(len(colors),1)
  174. for i in range(len(colors)):
  175. popspikes = list(itertools.chain.from_iterable(SPIKES['times'][i]))
  176. popspikes = [i for i in popspikes if i > transient]
  177. axarr[i].hist(popspikes,bins=numbins,color=colors[i],linewidth=0,edgecolor='none',range=(2000,7000))
  178. axarr[i].set_xlim(transient,network.tstop)
  179. if i < len(colors)-1:
  180. axarr[i].set_xticks([])
  181. axarr[-1:][0].set_xlabel('Time (ms)')
  182. return fig
  183. # Plot EEG & ECoG voltages & PSDs
  184. def plot_eeg(network,DIPOLEMOMENT):
  185. low = .1
  186. high = 100
  187. DP = DIPOLEMOMENT['HL23PYR']
  188. for pop in popnames[1:]:
  189. DP = np.add(DP,DIPOLEMOMENT[pop])
  190. EEG = EEG_args.calc_potential(DP, L23_pos)
  191. ECoG = ECoG_args.calc_potential(DP, L23_pos)
  192. EEG = EEG[0]
  193. ECoG = ECoG[0]
  194. EEG_filt = bandPassFilter(EEG[t1:])
  195. ECoG_filt = bandPassFilter(ECoG[t1:])
  196. EEG_freq, EEG_ps = ss.welch(EEG_filt[t1:], fs=sampling_rate, nperseg=nperseg)
  197. ECoG_freq, ECoG_ps = ss.welch(ECoG_filt[t1:], fs=sampling_rate, nperseg=nperseg)
  198. EEGraw_freq, EEGraw_ps = ss.welch(EEG[t1:], fs=sampling_rate, nperseg=nperseg)
  199. ECoGraw_freq, ECoGraw_ps = ss.welch(ECoG[t1:], fs=sampling_rate, nperseg=nperseg)
  200. tvec = np.arange((network.tstop)/(1000/sampling_rate)+1)*(1000/sampling_rate)
  201. fig = plt.figure(figsize=(10,10))
  202. ax1 = fig.add_subplot(221)
  203. ax1.plot(tvec[t1:], EEG_filt, c='k')
  204. ax1.set_xlim(transient,network.tstop)
  205. ax1.set_ylabel('EEG (mV)')
  206. ax2 = fig.add_subplot(222)
  207. ax2.plot(EEG_freq, EEG_ps, c='k')
  208. ax2.set_xlim(0,100)
  209. ax3 = fig.add_subplot(223)
  210. ax3.plot(tvec[t1:], ECoG_filt, c='k')
  211. ax3.set_xlim(transient,network.tstop)
  212. ax3.set_ylabel('ECoG (mV)')
  213. ax3.set_xlabel('Time (ms)')
  214. ax4 = fig.add_subplot(224)
  215. ax4.plot(ECoG_freq, ECoG_ps, c='k')
  216. ax4.set_xlim(0,100)
  217. ax4.set_xlabel('Frequency (Hz)')
  218. fig2 = plt.figure(figsize=(10,10))
  219. ax21 = fig2.add_subplot(221)
  220. ax21.plot(tvec[t1:], EEG[t1:], c='k')
  221. ax21.set_xlim(transient,network.tstop)
  222. ax21.set_ylabel('EEG (mV)')
  223. ax22 = fig2.add_subplot(222)
  224. ax22.plot(EEGraw_freq, EEGraw_ps, c='k')
  225. ax22.set_xlim(0,100)
  226. ax23 = fig2.add_subplot(223)
  227. ax23.plot(tvec[t1:], ECoG[t1:], c='k')
  228. ax23.set_xlim(transient,network.tstop)
  229. ax23.set_ylabel('ECoG (mV)')
  230. ax23.set_xlabel('Time (ms)')
  231. ax24 = fig2.add_subplot(224)
  232. ax24.plot(ECoGraw_freq, ECoGraw_ps, c='k')
  233. ax24.set_xlim(0,100)
  234. ax24.set_xlabel('Frequency (Hz)')
  235. return fig, fig2
  236. # Plot LFP voltages & PSDs
  237. def plot_lfp(network,OUTPUT):
  238. LFP1_freq, LFP1_ps = ss.welch(OUTPUT[0]['imem'][0][t1:], fs=sampling_rate, nperseg=nperseg)
  239. LFP2_freq, LFP2_ps = ss.welch(OUTPUT[0]['imem'][1][t1:], fs=sampling_rate, nperseg=nperseg)
  240. LFP3_freq, LFP3_ps = ss.welch(OUTPUT[0]['imem'][2][t1:], fs=sampling_rate, nperseg=nperseg)
  241. tvec = np.arange((network.tstop)/(1000/sampling_rate)+1)*(1000/sampling_rate)
  242. fig = plt.figure(figsize=(10,10))
  243. ax1 = fig.add_subplot(311)
  244. ax1.plot(tvec[t1:],OUTPUT[0]['imem'][0][t1:],'k')
  245. ax1.set_xlim(transient,network.tstop)
  246. ax2 = fig.add_subplot(312)
  247. ax2.plot(tvec[t1:],OUTPUT[0]['imem'][1][t1:],'k')
  248. ax2.set_ylabel('LFP (mV)')
  249. ax2.set_xlim(transient,network.tstop)
  250. ax3 = fig.add_subplot(313)
  251. ax3.plot(tvec[t1:],OUTPUT[0]['imem'][2][t1:],'k')
  252. ax3.set_xlabel('Time (ms)')
  253. ax3.set_xlim(transient,network.tstop)
  254. fig2 = plt.figure(figsize=(10,10))
  255. ax21 = fig2.add_subplot(311)
  256. ax21.plot(LFP1_freq,LFP1_ps,'k')
  257. ax21.set_xlim(0,100)
  258. ax22 = fig2.add_subplot(312)
  259. ax22.plot(LFP2_freq,LFP2_ps,'k')
  260. ax22.set_ylabel('PSD')
  261. ax22.set_xlim(0,100)
  262. ax23 = fig2.add_subplot(313)
  263. ax23.plot(LFP3_freq,LFP3_ps,'k')
  264. ax23.set_xlabel('Frequency (Hz)')
  265. ax23.set_xlim(0,100)
  266. return fig, fig2
  267. # Collect Somatic Voltages Across Ranks
  268. def somavCollect(network,cellindices,RANK,SIZE,COMM):
  269. if RANK == 0:
  270. volts = []
  271. gids2 = []
  272. for i, pop in enumerate(network.populations):
  273. svolts = []
  274. sgids = []
  275. for gid, cell in zip(network.populations[pop].gids, network.populations[pop].cells):
  276. if gid in cellindices:
  277. svolts.append(cell.somav)
  278. sgids.append(gid)
  279. volts.append([])
  280. gids2.append([])
  281. volts[i] += svolts
  282. gids2[i] += sgids
  283. for j in range(1, SIZE):
  284. volts[i] += COMM.recv(source=j, tag=15)
  285. gids2[i] += COMM.recv(source=j, tag=16)
  286. else:
  287. volts = None
  288. gids2 = None
  289. for i, pop in enumerate(network.populations):
  290. svolts = []
  291. sgids = []
  292. for gid, cell in zip(network.populations[pop].gids, network.populations[pop].cells):
  293. if gid in cellindices:
  294. svolts.append(cell.somav)
  295. sgids.append(gid)
  296. COMM.send(svolts, dest=0, tag=15)
  297. COMM.send(sgids, dest=0, tag=16)
  298. return dict(volts=volts, gids2=gids2)
  299. # Plot somatic voltages for each population
  300. def plot_somavs(network,VOLTAGES):
  301. tvec = np.arange(network.tstop/network.dt+1)*network.dt
  302. fig = plt.figure(figsize=(10,5))
  303. cls = ['black','crimson','green','darkorange']
  304. for i, pop in enumerate(network.populations):
  305. for v in range(0,len(VOLTAGES['volts'][i])):
  306. ax = plt.subplot2grid((len(VOLTAGES['volts']), len(VOLTAGES['volts'][i])), (i, v), rowspan=1, colspan=1, frameon=False)
  307. ax.plot(tvec,VOLTAGES['volts'][i][v], c=cls[i])
  308. ax.set_xlim(transient,network.tstop)
  309. ax.set_ylim(-85,40)
  310. if i < len(VOLTAGES['volts'])-1:
  311. ax.set_xticks([])
  312. if v > 0:
  313. ax.set_yticks([])
  314. return fig
# Run Plot Functions
# NOTE(review): nesting reconstructed from a listing whose indentation was lost --
# confirm against the original script. network, SPIKES, RANK, SIZE, COMM,
# OUTPUTPATH, GLOBALSEED, tstart_plot, tstop_plot, DIPOLEMOMENT, OUTPUT and
# cell_indices_to_plot are not defined in this file; presumably they come from
# the namespace this file is executed in.
if plotsomavs:
    # Called on every rank: non-root ranks send their traces to rank 0 here.
    VOLTAGES = somavCollect(network,cell_indices_to_plot,RANK,SIZE,COMM)
N_cells = 1000  # upper y-limit of the raster plot
if RANK ==0:
    # All plotting and saving happens on the root rank only.
    if plotnetworksomas:
        fig = plot_network_somas(OUTPUTPATH)
        fig.savefig(os.path.join(OUTPUTPATH,'network_somas_'+str(GLOBALSEED)),bbox_inches='tight', dpi=300)
    if plotrasterandrates:
        fig, fig2 = plot_raster_and_rates(SPIKES,tstart_plot,tstop_plot,popnames,N_cells,network,OUTPUTPATH,GLOBALSEED)
        fig.savefig(os.path.join(OUTPUTPATH,'raster_'+str(GLOBALSEED)),bbox_inches='tight', dpi=300)
        fig2.savefig(os.path.join(OUTPUTPATH,'rates_'+str(GLOBALSEED)),bbox_inches='tight', dpi=300)
        fig = plot_spiketimehists(SPIKES,network)
        fig.savefig(os.path.join(OUTPUTPATH,'spiketimes_'+str(GLOBALSEED)),bbox_inches='tight', dpi=300)
    if plotephistimseriesandPSD:
        fig, fig2 = plot_eeg(network,DIPOLEMOMENT)
        fig.savefig(os.path.join(OUTPUTPATH,'eeg_filt_'+str(GLOBALSEED)),bbox_inches='tight', dpi=300)
        fig2.savefig(os.path.join(OUTPUTPATH,'eeg_raw_'+str(GLOBALSEED)),bbox_inches='tight', dpi=300)
        fig, fig2 = plot_lfp(network,OUTPUT)
        fig.savefig(os.path.join(OUTPUTPATH,'lfps_traces_'+str(GLOBALSEED)),bbox_inches='tight', dpi=300)
        fig2.savefig(os.path.join(OUTPUTPATH,'lfps_PSDs_'+str(GLOBALSEED)),bbox_inches='tight', dpi=300)
    if plotsomavs:
        fig = plot_somavs(network,VOLTAGES)
        fig.savefig(os.path.join(OUTPUTPATH,'somav_'+str(GLOBALSEED)),bbox_inches='tight', dpi=300)
#===============================
# Kant
#===============================
# NOTE(review): this rebinds pop_colors (PV -> 'red', VIP -> 'yellow') after the
# plotting section above has run. plot_raster_and_rates and plot_network_somas
# read pop_colors at call time, so any later call uses this palette instead of
# the "Alex" one. Confirm that is intended.
pop_colors = {'HL23PYR':'k', 'HL23PV':'red', 'HL23SST':'green', 'HL23VIP':'yellow'}
popnames = ['HL23PYR', 'HL23PV', 'HL23SST', 'HL23VIP']
#===============================