123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388 |
- #================================================================================
- #= Import
- #================================================================================
- import os
- import time
- tic = time.perf_counter()
- from os.path import join
- import sys
- import zipfile
- import matplotlib
- import matplotlib.pyplot as plt
- from matplotlib.gridspec import GridSpec
- from mpl_toolkits.mplot3d import Axes3D
- from matplotlib.collections import LineCollection
- from matplotlib.collections import PolyCollection
- import numpy as np
- np.seterr(divide='ignore', invalid='ignore')
- import scipy
- from scipy import signal as ss
- from scipy import stats as st
- from mpi4py import MPI
- import math
- import neuron
- from neuron import h, gui
- import LFPy
- from LFPy import NetworkCell, Network, Synapse, RecExtElectrode, StimIntElectrode
- # from net_params import *
- import warnings
- warnings.filterwarnings('ignore')
- import pandas as pd
- import itertools
#================================================================================
#= Controls
#================================================================================
# Frank ===============
# Alex ===============
# Each switch enables one group of figures in the "Run Plot Functions" section.
plotnetworksomas = False          # 3D scatter of soma positions
plotrasterandrates = True         # raster, firing-rate bars and spike-time histograms
plotephistimseriesandPSD = False  # EEG/ECoG and LFP traces with their PSDs
# Somatic voltage traces of the cells listed in 'cell_indices_to_plot'.
# Plotting many cells has been seen to cause sporadic TCP connection errors.
plotsomavs = False
# Kant ===============
#================================================================================
#= Analysis
#================================================================================
#===============================
#= Analysis Parameters
#===============================
# Recorded signals are assumed to be sampled every 0.025 ms (40 kHz).
sampling_rate = 1000 * (1 / 0.025)   # samples per second
nperseg = int(sampling_rate / 2)     # Welch segment length: half a second of samples
transient = 2000                     # initial period (ms) excluded from plots and analyses
t1 = int(transient / 0.025)          # first sample index after the transient

# Four-sphere head model used for EEG/ECoG.
radii = [79000., 80000., 85000., 90000.]   # shell radii (presumably um, LFPy convention)
sigmas = [0.3, 1.5, 0.015, 0.3]            # shell conductivities (presumably S/m)
L23_pos = np.array([0., 0., 78200.])       # single-dipole reference position for EEG/ECoG
EEG_sensor = np.array([[0., 0., 90000]])   # electrode on the outermost radius (radii[3])
ECoG_sensor = np.array([[0., 0., 79000]])  # electrode on the innermost radius (radii[0])
EEG_args = LFPy.FourSphereVolumeConductor(radii, sigmas, EEG_sensor)
ECoG_args = LFPy.FourSphereVolumeConductor(radii, sigmas, ECoG_sensor)
#===============================
def bandPassFilter(signal, low=.1, high=100., order=2, fs=None):
    """Apply a zero-phase Butterworth band-pass filter to a 1-D signal.

    Parameters
    ----------
    signal : array_like
        Time series to filter.
    low, high : float
        Pass-band edges in Hz.
    order : int
        Order handed to ``scipy.signal.butter``; the filter is applied forward
        and backward, so the result has no phase shift.
    fs : float, optional
        Sampling rate in Hz. Defaults to the module-level ``sampling_rate``.

    Returns
    -------
    numpy.ndarray
        The filtered signal, same length as ``signal``.
    """
    if fs is None:
        fs = sampling_rate
    # Second-order sections rather than (b, a) coefficients: at fs = 40 kHz the
    # 0.1 Hz edge is ~5e-6 of Nyquist, where transfer-function coefficients are
    # numerically ill-conditioned. SciPy recommends SOS for general filtering.
    sos = ss.butter(order, [low, high], btype='bandpass', fs=fs, output='sos')
    return ss.sosfiltfilt(sos, signal)
#================================================================================
#= Plotting
#================================================================================
#===============================
#= Frank
#===============================
# Population order and colour scheme. Both names are rebound by the Alex
# section that follows before any figure is produced.
popnames = ['HL23PYR', 'HL23PV', 'HL23SST', 'HL23VIP']
pop_colors = dict(zip(popnames, ('k', 'red', 'green', 'yellow')))
#===============================
#===============================
#= Alex
#===============================
# Population order, colours and short labels used by the plotting functions.
pop_colors = {'HL23PYR': 'k', 'HL23PV': 'crimson', 'HL23SST': 'green', 'HL23VIP': 'darkorange'}
popnames = ['HL23PYR', 'HL23PV', 'HL23SST', 'HL23VIP']
poplabels = ['PN', 'MN', 'BN', 'VN']  # tick labels of the firing-rate bar chart, same order as popnames
# 'normal' is not a font family matplotlib recognises (findfont logs a warning
# and falls back to the default); the generic 'sans-serif' family is what was
# actually being rendered.
font = {'family': 'sans-serif',
        'weight': 'normal',
        'size': 14}
matplotlib.rc('font', **font)
matplotlib.rc('legend', fontsize=16)
#===============================
# Plot soma positions
def plot_network_somas(OUTPUTPATH):
    """Return a 3D scatter figure of the soma positions of every population.

    Reads one table per name in ``popnames`` from
    ``cell_positions_and_rotations.h5`` inside ``OUTPUTPATH``. Each table must
    provide 'gid', 'x', 'y' and 'z' columns. Points are coloured with
    ``pop_colors``.
    """
    filename = os.path.join(OUTPUTPATH, 'cell_positions_and_rotations.h5')
    popDataArray = {pop: pd.read_hdf(filename, pop).sort_values('gid')
                    for pop in popnames}
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=5)
    for pop in popnames:
        cells = popDataArray[pop]
        # One scatter per population instead of one per cell. Plotting whole
        # columns also avoids label-based Series[i] lookups, which fail when the
        # stored index is not 0..n-1. depthshade=False keeps the uniform opacity
        # that the previous single-point scatters had.
        ax.scatter(cells['x'].values, cells['y'].values, cells['z'].values,
                   c=pop_colors[pop], s=5, depthshade=False)
    ax.set_xlim([-300, 300])
    ax.set_ylim([-300, 300])
    ax.set_zlim([-1200, -400])
    return fig
# Plot spike raster plots & spike rates
def plot_raster_and_rates(SPIKES,tstart_plot,tstop_plot,popnames,N_cells,network,OUTPUTPATH,GLOBALSEED):
    """Draw a spike raster and a bar chart of per-population firing rates.

    Parameters
    ----------
    SPIKES : dict
        ``SPIKES['times'][p]`` holds one spike-time array (ms) per cell of
        population ``p``; ``SPIKES['gids'][p]`` holds the matching gids.
    tstart_plot, tstop_plot : float
        Spikes earlier than ``tstart_plot`` are not drawn. The raster shows a
        1500 ms window centred between the two times.
    popnames : sequence of str
        Population names in the same order as ``SPIKES``; selects raster colours.
    N_cells : int
        Upper y-limit of the raster.
    network : object
        Only ``network.tstop`` is read (end of the rate-counting window).
    OUTPUTPATH : str
        Directory that receives ``spikerates_Seed<GLOBALSEED>.txt`` (mean rates).
    GLOBALSEED : int-like
        Seed embedded in the output file name.

    Returns
    -------
    tuple
        ``(raster_figure, rate_bar_figure)``.

    Rates count spikes after ``transient`` and divide by
    ``(int(network.tstop) - transient) / 1000`` seconds. Bars show the mean
    over cells, with the standard deviation as error bars.
    """
    colors = ['dimgray', 'crimson', 'green', 'darkorange']

    # --- Raster ---------------------------------------------------------
    fig = plt.figure(figsize=(10, 8))
    ax1 = fig.add_subplot(111)
    for name, spts, gids in zip(popnames, SPIKES['times'], SPIKES['gids']):
        pairs = list(zip(spts, gids))
        # Concatenate once: growing arrays with np.r_ per cell copied all previous
        # spikes every iteration (quadratic). The leading empty array keeps a
        # population without cells from failing.
        t = np.concatenate([np.empty(0)] + [np.ravel(spt) for spt, _ in pairs])
        g = np.concatenate([np.empty(0)] +
                           [np.full(np.size(spt), gid, dtype=float) for spt, gid in pairs])
        shown = t >= tstart_plot
        ax1.plot(t[shown], g[shown], '|', color=pop_colors[name])
    ax1.set_ylim(0, N_cells)
    halftime = 750  # ms shown on each side of the window centre
    centre = tstart_plot + ((tstop_plot - tstart_plot) / 2)
    ax1.set_xlim(int(centre - halftime), int(centre + halftime))
    ax1.set_xlabel('Time (ms)')
    ax1.set_ylabel('Cell Number')

    # --- Firing rates ---------------------------------------------------
    window_s = (int(network.tstop) - transient) / 1000
    # zip with poplabels bounds the count to the labelled populations
    # (previously a hard-coded range(4)).
    SPIKE_list = [[np.count_nonzero(np.asarray(spt) > transient) / window_s
                   for spt in pop_times]
                  for _, pop_times in zip(poplabels, SPIKES['times'])]
    SPIKE_MEANS = [np.mean(rates) for rates in SPIKE_list]
    SPIKE_STDEV = [np.std(rates) for rates in SPIKE_list]
    names = [label + '\n' + str(np.around(mean, decimals=2)) + r' $\pm$ '
             + str(np.around(std, decimals=2))
             for label, mean, std in zip(poplabels, SPIKE_MEANS, SPIKE_STDEV)]
    np.savetxt(os.path.join(OUTPUTPATH, 'spikerates_Seed' + str(int(GLOBALSEED)) + '.txt'),
               np.array(SPIKE_MEANS))

    fig2 = plt.figure(figsize=(10, 8))
    ax2 = fig2.add_subplot(111)
    ax2.bar(x=list(range(len(SPIKE_MEANS))),
            height=SPIKE_MEANS,
            yerr=SPIKE_STDEV,
            capsize=12,
            width=0.8,
            tick_label=names,
            color=colors[:len(SPIKE_MEANS)],
            edgecolor='k',
            ecolor='black',
            linewidth=4,
            error_kw={'elinewidth': 3, 'markeredgewidth': 3})
    ax2.set_ylabel('Spike Frequency (Hz)')
    ax2.grid(False)
    return fig, fig2
# Plot spike time histograms
def plot_spiketimehists(SPIKES,network):
    """Plot one population-wide spike-time histogram per population.

    Spikes later than ``transient`` are binned into 10 ms bins covering
    ``transient``..``network.tstop``. There is one row per population, and only
    the bottom row keeps its x ticks.
    """
    colors = ['dimgray', 'crimson', 'green', 'darkorange']
    binsize = 10  # ms
    numbins = int((network.tstop - transient) / binsize)
    fig, axarr = plt.subplots(len(colors), 1)
    for row, (ax, color) in enumerate(zip(axarr, colors)):
        popspikes = np.concatenate([np.empty(0)] +
                                   [np.ravel(spt) for spt in SPIKES['times'][row]])
        popspikes = popspikes[popspikes > transient]
        # The histogram range must match the window used for numbins/xlim; it was
        # hard-coded to (2000, 7000), which mis-sized bins for any other tstop.
        ax.hist(popspikes, bins=numbins, color=color, linewidth=0,
                edgecolor='none', range=(transient, network.tstop))
        ax.set_xlim(transient, network.tstop)
        if row < len(colors) - 1:
            ax.set_xticks([])
    axarr[-1].set_xlabel('Time (ms)')
    return fig
# Plot EEG & ECoG voltages & PSDs
def plot_eeg(network,DIPOLEMOMENT):
    """Plot EEG/ECoG traces and Welch PSDs derived from the summed dipole.

    The current dipole moments of all populations in ``popnames`` are summed
    and projected through the four-sphere models ``EEG_args``/``ECoG_args`` at
    ``L23_pos``. Samples before ``t1`` (the transient) are discarded.

    Returns
    -------
    tuple
        ``(filtered_figure, raw_figure)``. Each is a 2x2 grid with EEG on top
        and ECoG below, the trace on the left and the PSD (0-100 Hz) on the
        right. The first figure is band-passed between ``low`` and ``high``.
    """
    low = .1
    high = 100
    DP = DIPOLEMOMENT['HL23PYR']
    for pop in popnames[1:]:
        DP = np.add(DP, DIPOLEMOMENT[pop])
    EEG = EEG_args.calc_potential(DP, L23_pos)[0]
    ECoG = ECoG_args.calc_potential(DP, L23_pos)[0]
    EEG_raw = EEG[t1:]
    ECoG_raw = ECoG[t1:]
    EEG_filt = bandPassFilter(EEG_raw, low, high)
    ECoG_filt = bandPassFilter(ECoG_raw, low, high)
    # The filtered traces already start at t1. They were previously sliced with
    # [t1:] a second time, which dropped another transient's worth of samples
    # and made the filtered PSD cover a different window than the raw one.
    EEG_freq, EEG_ps = ss.welch(EEG_filt, fs=sampling_rate, nperseg=nperseg)
    ECoG_freq, ECoG_ps = ss.welch(ECoG_filt, fs=sampling_rate, nperseg=nperseg)
    EEGraw_freq, EEGraw_ps = ss.welch(EEG_raw, fs=sampling_rate, nperseg=nperseg)
    ECoGraw_freq, ECoGraw_ps = ss.welch(ECoG_raw, fs=sampling_rate, nperseg=nperseg)
    tvec = np.arange((network.tstop)/(1000/sampling_rate)+1)*(1000/sampling_rate)

    def draw(eeg, ecog, eeg_psd, ecog_psd):
        """Return a 2x2 figure: EEG row above ECoG row, trace left, PSD right."""
        figure = plt.figure(figsize=(10, 10))
        rows = (('EEG (mV)', eeg, eeg_psd), ('ECoG (mV)', ecog, ecog_psd))
        for row, (label, trace, (freq, ps)) in enumerate(rows):
            ax_trace = figure.add_subplot(2, 2, 2 * row + 1)
            ax_trace.plot(tvec[t1:], trace, c='k')
            ax_trace.set_xlim(transient, network.tstop)
            ax_trace.set_ylabel(label)
            ax_psd = figure.add_subplot(2, 2, 2 * row + 2)
            ax_psd.plot(freq, ps, c='k')
            ax_psd.set_xlim(0, 100)
        # Axis labels only on the bottom (ECoG) row.
        ax_trace.set_xlabel('Time (ms)')
        ax_psd.set_xlabel('Frequency (Hz)')
        return figure

    fig = draw(EEG_filt, ECoG_filt, (EEG_freq, EEG_ps), (ECoG_freq, ECoG_ps))
    fig2 = draw(EEG_raw, ECoG_raw, (EEGraw_freq, EEGraw_ps), (ECoGraw_freq, ECoGraw_ps))
    return fig, fig2
# Plot LFP voltages & PSDs
def plot_lfp(network,OUTPUT):
    """Plot three LFP channels and their Welch power spectra.

    ``OUTPUT[0]['imem']`` is indexed channel-first; channels 0-2 are used, and
    samples before ``t1`` (the transient) are dropped.

    Returns
    -------
    tuple
        ``(trace_figure, psd_figure)``, each with the three channels stacked
        in rows. PSDs are shown from 0 to 100 Hz.
    """
    channels = OUTPUT[0]['imem']
    n_rows = 3
    psds = [ss.welch(channels[row][t1:], fs=sampling_rate, nperseg=nperseg)
            for row in range(n_rows)]
    step_ms = 1000 / sampling_rate
    tvec = np.arange(network.tstop / step_ms + 1) * step_ms

    trace_fig = plt.figure(figsize=(10, 10))
    for row in range(n_rows):
        ax = trace_fig.add_subplot(n_rows, 1, row + 1)
        ax.plot(tvec[t1:], channels[row][t1:], 'k')
        ax.set_xlim(transient, network.tstop)
        if row == 1:
            ax.set_ylabel('LFP (mV)')
    ax.set_xlabel('Time (ms)')

    psd_fig = plt.figure(figsize=(10, 10))
    for row, (freqs, power) in enumerate(psds):
        ax = psd_fig.add_subplot(n_rows, 1, row + 1)
        ax.plot(freqs, power, 'k')
        ax.set_xlim(0, 100)
        if row == 1:
            ax.set_ylabel('PSD')
    ax.set_xlabel('Frequency (Hz)')
    return trace_fig, psd_fig
# Collect Somatic Voltages Across Ranks
def somavCollect(network,cellindices,RANK,SIZE,COMM):
    """Gather the somatic voltage traces of selected cells onto rank 0.

    Must be called on every rank: it performs one collective ``COMM.gather``
    (previously a hand-written send/recv loop duplicated on both sides).

    Parameters
    ----------
    network : object
        ``network.populations`` maps population names to objects exposing
        parallel ``gids`` and ``cells`` sequences; cells provide ``somav``.
    cellindices : iterable of int
        Gids whose traces are collected.
    RANK : int
        Rank of this process; rank 0 receives the result.
    SIZE : int
        Number of ranks. Kept for interface compatibility; gather does not need it.
    COMM : mpi4py communicator

    Returns
    -------
    dict
        ``{'volts': volts, 'gids2': gids2}``. On rank 0, ``volts[i]`` and
        ``gids2[i]`` hold the traces and gids of the i-th population,
        concatenated in rank order. On other ranks both values are None.
    """
    wanted = set(cellindices)  # O(1) membership instead of scanning per cell
    local = []
    for pop in network.populations:
        population = network.populations[pop]
        svolts, sgids = [], []
        for gid, cell in zip(population.gids, population.cells):
            if gid in wanted:
                svolts.append(cell.somav)
                sgids.append(gid)
        local.append((svolts, sgids))
    gathered = COMM.gather(local, root=0)
    if RANK != 0:
        return dict(volts=None, gids2=None)
    volts = [[] for _ in local]
    gids2 = [[] for _ in local]
    for rank_data in gathered:  # ordered by rank, rank 0 first
        for i, (svolts, sgids) in enumerate(rank_data):
            volts[i] += svolts
            gids2[i] += sgids
    return dict(volts=volts, gids2=gids2)
# Plot somatic voltages for each population
def plot_somavs(network,VOLTAGES):
    """Plot collected somatic voltage traces, one row per population.

    ``VOLTAGES`` is the dict produced by ``somavCollect`` on rank 0. Each row
    is split into as many panels as that population has traces. Only the
    bottom row keeps x ticks and only the first column keeps y ticks.
    """
    traces_by_pop = VOLTAGES['volts']
    n_rows = len(traces_by_pop)
    tvec = np.arange(network.tstop/network.dt+1)*network.dt
    fig = plt.figure(figsize=(10,5))
    colours = ['black', 'crimson', 'green', 'darkorange']
    for row, _ in enumerate(network.populations):
        traces = traces_by_pop[row]
        n_cols = len(traces)
        for col, trace in enumerate(traces):
            ax = plt.subplot2grid((n_rows, n_cols), (row, col),
                                  rowspan=1, colspan=1, frameon=False)
            ax.plot(tvec, trace, c=colours[row])
            ax.set_xlim(transient, network.tstop)
            ax.set_ylim(-85, 40)
            if row != n_rows - 1:
                ax.set_xticks([])
            if col != 0:
                ax.set_yticks([])
    return fig
# Run Plot Functions
# Every rank must take part in somavCollect's exchange of traces; all
# plotting and saving is then done by rank 0 alone.
if plotsomavs:
    VOLTAGES = somavCollect(network, cell_indices_to_plot, RANK, SIZE, COMM)
N_cells = 1000  # upper y-limit of the raster plot
if RANK == 0:
    if plotnetworksomas:
        fig = plot_network_somas(OUTPUTPATH)
        fig.savefig(os.path.join(OUTPUTPATH, 'network_somas_' + str(GLOBALSEED)),
                    dpi=300, bbox_inches='tight')
    if plotrasterandrates:
        fig, fig2 = plot_raster_and_rates(SPIKES, tstart_plot, tstop_plot, popnames,
                                          N_cells, network, OUTPUTPATH, GLOBALSEED)
        fig.savefig(os.path.join(OUTPUTPATH, 'raster_' + str(GLOBALSEED)),
                    dpi=300, bbox_inches='tight')
        fig2.savefig(os.path.join(OUTPUTPATH, 'rates_' + str(GLOBALSEED)),
                     dpi=300, bbox_inches='tight')
        fig = plot_spiketimehists(SPIKES, network)
        fig.savefig(os.path.join(OUTPUTPATH, 'spiketimes_' + str(GLOBALSEED)),
                    dpi=300, bbox_inches='tight')
    if plotephistimseriesandPSD:
        fig, fig2 = plot_eeg(network, DIPOLEMOMENT)
        fig.savefig(os.path.join(OUTPUTPATH, 'eeg_filt_' + str(GLOBALSEED)),
                    dpi=300, bbox_inches='tight')
        fig2.savefig(os.path.join(OUTPUTPATH, 'eeg_raw_' + str(GLOBALSEED)),
                     dpi=300, bbox_inches='tight')
        fig, fig2 = plot_lfp(network, OUTPUT)
        fig.savefig(os.path.join(OUTPUTPATH, 'lfps_traces_' + str(GLOBALSEED)),
                    dpi=300, bbox_inches='tight')
        fig2.savefig(os.path.join(OUTPUTPATH, 'lfps_PSDs_' + str(GLOBALSEED)),
                     dpi=300, bbox_inches='tight')
    if plotsomavs:
        fig = plot_somavs(network, VOLTAGES)
        fig.savefig(os.path.join(OUTPUTPATH, 'somav_' + str(GLOBALSEED)),
                    dpi=300, bbox_inches='tight')
#===============================
# Kant
#===============================
# Kant's colour scheme. These rebindings take effect only after the figures
# above have already been produced.
popnames = ['HL23PYR', 'HL23PV', 'HL23SST', 'HL23VIP']
pop_colors = dict(zip(popnames, ('k', 'red', 'green', 'yellow')))
#===============================
|