
Remove dependence on djd and expio

* WIP, only main.py tested so far
Martin Spacek, 2 years ago · commit bd6d7cec12
12 changed files with 649 additions and 1979 deletions

  fig1.py        + 5    - 5
  fig1S1.py      + 0    - 127
  fig1S33S1.py   + 28   - 28
  fig1S6.py      + 5    - 5
  fig2.py        + 1    - 87
  fig3.py        + 0    - 2
  fig4.py        + 4    - 4
  fig5.py        + 5    - 5
  fig6.py        + 16   - 16
  ipos.py        + 2    - 2
  main.py        + 42   - 1642
  util.py        + 541  - 56

+ 5 - 5
fig1.py

@@ -837,15 +837,15 @@ for kind in ['nat']:#MVIKINDS:
                           marker=marker, c=c, s=sz, lw=lw)
             # get fname of appropriate LMM .csv file:
             if measure == 'meanburstratio': # use Steffen's LMM linregress fit, for fig1S2
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S2c_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S2c_coefs.csv')
             elif measure == 'spars': # use Steffen's LMM linregress fit, for fig1S2
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S2d_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S2d_coefs.csv')
             elif measure == 'rel': # use Steffen's LMM linregress fit, for fig1S2
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S2e_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S2e_coefs.csv')
             elif measure == 'snr': # use Steffen's LMM linregress fit, for fig1S2
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S2f_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S2f_coefs.csv')
             elif measure == 'meanpkw': # use Steffen's LMM linregress fit, for fig1S2
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S2g_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S2g_coefs.csv')
             else:
                 print("WARNING: No LMM stats for %s measure" % measure)
                 fname = None
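
Note: with the PAPERPATH prefix dropped, the 'stats' CSVs now resolve relative to the current working directory, so the figure scripts must be run from the repo root. A minimal sketch of an alternative, anchoring the paths to the script's own location instead (hypothetical, not part of this commit):

import os

# hypothetical: anchor the stats folder to this script's own directory,
# so the lookup works regardless of the current working directory:
STATSDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'stats')
fname = os.path.join(STATSDIR, 'figure_1S2c_coefs.csv')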

+ 0 - 127
fig1S1.py

@@ -1,127 +0,0 @@
-"""Fig 1S1 plots"""
-
-# import
-from djd.signal import rfs_for_cdata, optosupp_for_cdata
-from djd.plot import simple_plot_rfs, simple_plot_optosupp
-import matplotlib.pyplot as plt
-
-
-# specify and update matplotlib parameters
-fontsize = 6.4
-params = {
-            'font.family': 'FreeSans',
-            'font.weight': 'normal',
-            'font.size': fontsize,
-            'pdf.fonttype': 42,
-            'ps.fonttype': 42,
-            'xtick.labelsize': fontsize,
-            'ytick.labelsize': fontsize,
-            'lines.linewidth' :0.734,
-            'xtick.major.size' : 1.734,
-            'ytick.major.size' : 1.734,
-            'axes.linewidth' : 0.4,
-            'axes.titlesize': fontsize,
-            'ytick.major.width' :0.4,
-            'xtick.major.width' :0.4,
-            'savefig.dpi': 300,
-            }
-
-plt.rcParams.update(params)
-
-# optosuppression panel
-
-# define key and target channels
-v1key = {'m': 'PVCre_2018_0003', 's': 1, 'e': 1}
-exchanis = [24, 25, 26]
-
-# get opto suppression data
-ondata, offdata, t, stim_perc_red, light_diff_onset, light_dur, stim_dur = optosupp_for_cdata(
-                                                                            key=v1key,
-                                                                            pad=(-0.2, 1))
-
-# init parameters
-cmpin = 2.54
-figwidth = 11.4 / cmpin
-figheight = 15 / cmpin
-bardistance = 1.2 # distance bar to axes
-barwidth = 3
-chanh = 0.15 # subplot height
-interchanh = 0.038 # vertical space between subplots
-figsize = (figwidth, figheight)
-f = plt.figure(figsize=figsize)
-
-# define first axis
-l1 = 0.07
-b1 = 0.61
-w1 = 0.6
-h = chanh
-
-# create axis for each channel
-axlist = []
-for _ in exchanis:
-    b1 = b1 - chanh
-    ax = f.add_axes([l1, b1, w1, h])
-    b1 = b1 - interchanh
-    axlist.append(ax)
-
-# fill axes with data
-axs = simple_plot_optosupp(key=v1key, ondata=ondata, offdata=offdata, t=t, perc_red=stim_perc_red,
-                           light_diff_onset=light_diff_onset, light_dur=light_dur,
-                           stim_dur=stim_dur, barwidth=barwidth, bardistance=bardistance,
-                           axs=axlist, chanis=exchanis)
-
-# format figure
-axs[1].set_ylabel('Normalized multi unit activity')
-axs[-1].set_ylabel('')
-axs[-1].set_xticklabels((0,1,2,3))
-
-# rf panel
-
-# define keys for the two mua rfs
-ekeys = [{'m': 'PVCre_2017_0015', 's': 3, 'e':1},
-         {'m': 'PVCre_2017_0015', 's': 7, 'e':1}]
-
-# define analysis window
-rf_offset = (0.05, 0.1)
-
-# init parameters
-chanh = 0.04 # subplot height
-interchanh = 0.025 # vertical space between subplots
-w = 0.18
-h = chanh
-start_chans = [7,14] # indices of channels to start from (one for each key)
-nchans = 8 # number of channels to plot
-sep = 2 # step: plot every second channel
-
-# create one axis per channel
-axlists=[]
-for k in range(2):
-    axlist=[]
-    l = l1 + w1 + 0.1 + (0.1 * k)
-    b2 = 0.61
-    for i in range(nchans):
-        b2 = b2 - chanh
-        ax = f.add_axes([l, b2, w, h])
-        b2 = b2 - interchanh
-        axlist.append(ax)
-    axlists.append(axlist)
-
-# fill axes with data
-for ekey, axlist, start_chan in zip(ekeys, axlists, start_chans):
-    rfs, axes, chorder = rfs_for_cdata(ekey, offset=rf_offset)
-    simple_plot_rfs(ekey, rfs[start_chan:start_chan+(nchans*sep):sep,:,:,:],
-                    axes, interpolation='spline16', nrows=nchans,
-                    chorder=chorder[start_chan:start_chan+(nchans*sep)],
-                    axs=axlist, contrasts=2)
-
-    for ax in axlist:
-        ax.set_title((''))
-
-# format figure
-axlists[0][-1].yaxis.tick_left()
-axlists[0][-1].set_ylabel('Elevation')
-axlists[0][-1].set_xlabel('Azimuth')
-axlists[1][-1].set_yticklabels((''))
-axlists[0][0].set_title('Exp1')
-axlists[1][0].set_title('Exp2')
-

+ 28 - 28
fig1S33S1.py

@@ -55,14 +55,14 @@ for stimtype in STIMTYPES:
         fname = None # sanity check: clear from previous loop
         if stimtype == 'mvi':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3a_pred_means.csv')
+                fname = os.path.join('stats', 'figure_1S3a_pred_means.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3f_pred_means.csv')
+                fname = os.path.join('stats', 'figure_1S3f_pred_means.csv')
         elif stimtype == 'grt':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1a_pred_means.csv')
+                fname = os.path.join('stats', 'figure_3S1a_pred_means.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1f_pred_means.csv')
+                fname = os.path.join('stats', 'figure_3S1f_pred_means.csv')
         # fetch LMM means from .csv:
         df = pd.read_csv(fname)
         meannonsbcfmi = df['non_sbc'][0]
@@ -117,14 +117,14 @@ for stimtype in STIMTYPES:
         fname = None # sanity check: clear from previous loop
         if stimtype == 'mvi':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3b_pred_means.csv')
+                fname = os.path.join('stats', 'figure_1S3b_pred_means.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3g_pred_means.csv')
+                fname = os.path.join('stats', 'figure_1S3g_pred_means.csv')
         elif stimtype == 'grt':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1b_pred_means.csv')
+                fname = os.path.join('stats', 'figure_3S1b_pred_means.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1g_pred_means.csv')
+                fname = os.path.join('stats', 'figure_3S1g_pred_means.csv')
         # fetch LMM means from .csv:
         df = pd.read_csv(fname)
         meancorefmi = df['core'][0]
@@ -174,14 +174,14 @@ for stimtype in STIMTYPES:
         fname = None # sanity check: clear from previous loop
         if stimtype == 'mvi':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3b_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3b_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3g_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3g_coefs.csv')
         elif stimtype == 'grt':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1b_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1b_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1g_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1g_coefs.csv')
         # fetch LMM linregress fit params from .csv:
         df = pd.read_csv(fname)
         mm = df['slope'][0]
@@ -232,14 +232,14 @@ for stimtype in STIMTYPES:
         fname = None # sanity check: clear from previous loop
         if stimtype == 'mvi':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3b_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3b_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3g_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3g_coefs.csv')
         elif stimtype == 'grt':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1b_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1b_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1g_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1g_coefs.csv')
         # fetch LMM linregress fit params from .csv:
         df = pd.read_csv(fname)
         mm = df['slope'][0]
@@ -290,14 +290,14 @@ for stimtype in STIMTYPES:
         fname = None # sanity check: clear from previous loop
         if stimtype == 'mvi':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3c_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3c_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3h_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3h_coefs.csv')
         elif stimtype == 'grt':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1c_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1c_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1h_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1h_coefs.csv')
         # fetch LMM linregress fit params from .csv:
         df = pd.read_csv(fname)
         mm = df['slope'][0]
@@ -349,14 +349,14 @@ for stimtype in STIMTYPES:
         fname = None # sanity check: clear from previous loop
         if stimtype == 'mvi':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3d_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3d_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3i_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3i_coefs.csv')
         elif stimtype == 'grt':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1d_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1d_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1i_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1i_coefs.csv')
         # fetch LMM linregress fit params from .csv:
         df = pd.read_csv(fname)
         mm = df['slope'][0]
@@ -411,14 +411,14 @@ for stimtype in STIMTYPES:
         fname = None # sanity check: clear from previous loop
         if stimtype == 'mvi':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3e_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3e_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_1S3j_coefs.csv')
+                fname = os.path.join('stats', 'figure_1S3j_coefs.csv')
         elif stimtype == 'grt':
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1e_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1e_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_3S1j_coefs.csv')
+                fname = os.path.join('stats', 'figure_3S1j_coefs.csv')
         # fetch LMM linregress fit params from .csv:
         df = pd.read_csv(fname)
         mm = df['slope'][0]

+ 5 - 5
fig1S6.py

@@ -3,7 +3,7 @@
 ## Redo scatter plots from fig1.py, but only for movie experiments with insignificant
 ## effect of opto on pupil area, as determined by LMM from trial-wise pupil data:
 assert EXPTYPE == 'pvmvis'
-insigmsefname = os.path.join(PAPERPATH, 'stats', 'insig_pupil_diam_exps.csv')
+insigmsefname = os.path.join('stats', 'insig_pupil_diam_exps.csv')
 df = pd.read_csv(insigmsefname)
 insigmsestrs = list(df['mseustr'].values)
 # filter out units from movie experiments with significant opto effects on pupil area:
@@ -126,9 +126,9 @@ for stimtype, stimtypei in zip(STIMTYPESPLUSALL, STIMTYPESPLUSALLI):
     fname = None
     if stimtype == 'mvi+grt':
         if EXPTYPE == 'pvmvis':
-            fname = os.path.join(PAPERPATH, 'stats', 'figure_1_S6g_coefs.csv')
+            fname = os.path.join('stats', 'figure_1_S6g_coefs.csv')
         elif EXPTYPE == 'ntsrmvis':
-            fname = os.path.join(PAPERPATH, 'stats', 'figure_1_S6h_coefs.csv')
+            fname = os.path.join('stats', 'figure_1_S6h_coefs.csv')
     if fname:
         try:
             df = pd.read_csv(fname)
@@ -166,9 +166,9 @@ for stimtype, stimtypei in zip(STIMTYPESPLUSALL, STIMTYPESPLUSALLI):
     # get fname of appropriate LMM .csv file:
     if stimtype == 'mvi+grt':
         if EXPTYPE == 'pvmvis':
-            fname = os.path.join(PAPERPATH, 'stats', 'figure_1_S6i_coefs.csv')
+            fname = os.path.join('stats', 'figure_1_S6i_coefs.csv')
         elif EXPTYPE == 'ntsrmvis':
-            fname = os.path.join(PAPERPATH, 'stats', 'figure_1_S6j_coefs.csv')
+            fname = os.path.join('stats', 'figure_1_S6j_coefs.csv')
     if fname:
         try:
             df = pd.read_csv(fname)

+ 1 - 87
fig2.py

@@ -164,92 +164,6 @@ for mseustr, exmpli in mvimseu2exmpli.items(): #mvimseustrs:
             a.spines['left'].set_position(('outward', 4))
             a.spines['bottom'].set_position(('outward', 4))
 
-'''
-## TODO: plot intercept normalized by ctrl peak, or somehow, for unitless intercepts distrib
-## TODO: highlight 3rd example neuron in all scatter plots
-## TODO: add burst distrib as well
-# strip plot distributions of linear fit params, keep only those fits with a decent rsq:
-RSQTHRESH = 0.2#0.4
-keepis = (fits['rsq'] >= RSQTHRESH).values
-keepfits = fits.loc[keepis]
-resetkeepfits = keepfits.reset_index() # convert mi to columns for sns
-figsize = DEFAULTFIGURESIZE
-for kind in MVIKINDS:
-    for par in ['slope', 'intercept', 'rsq']:
-        rowis = (resetkeepfits['kind'] == kind) # boolean pandas Series
-        if not rowis.any():
-            print('No data to plot for par, kind = %s, %s, skipping' % (par, kind))
-            continue
-        data = resetkeepfits[rowis]
-        # do paired t-test for par:
-        runvals, sitvals, nonevals = [], [], []
-        submseustrs = data['mseu'].unique()
-        for mseustr in submseustrs:
-            rows = keepfits.loc[mseustr, kind]
-            if not np.all([ st8 in rows.index for st8 in ALLST8S ]): # not all st8s in rows
-                continue
-            if rows.isna().any().any(): # two any()'s: one for st8, one for fit param
-                continue
-            # rows has non-NaN entries for run, sit and none
-            runvals.append(rows.loc['run'][par])
-            sitvals.append(rows.loc['sit'][par])
-            nonevals.append(rows.loc['none'][par])
-        nunits = len(runvals) # number of units that survived
-        assert nunits == len(sitvals) == len(nonevals)
-        if nunits == 0:
-            print('No data to plot for par, kind = %s, %s, skipping' % (par, kind))
-            continue
-        t, p = ttest_rel(runvals, sitvals) # paired t-test
-        # make a strip plot:
-        f, a = plt.subplots(figsize=figsize)
-        titlestr = 'PSTH %s fit %s %s rsqthresh=%.1f' % (model, kind, par, RSQTHRESH)
-        wintitle(titlestr, f)
-        sns.stripplot(x="st8", y=par, data=data, palette=st82clr, jitter=True, size=3)
-        a.set_xlabel('')
-        ylabel = par
-        if ylabel == 'intercept':
-            ylabel += ' (spk/s)'
-        ylabel = ylabel.capitalize()
-        a.set_ylabel(ylabel)
-        a.add_artist(AnchoredText('p=%.1e' % p, loc='upper left', frameon=False))
-        # connect the dots:
-        x = np.array([[0]*nunits, [1]*nunits, [2]*nunits])
-        y = np.array([runvals, sitvals, nonevals])
-        a.plot(x, y, '-', c='k', alpha=0.2, lw=1)
-        # due to jitter, dots don't perfectly connect. Can get actual data using:
-        #runxy, sitxy = a.collections[0].get_offsets(), a.collections[1].get_offsets()
-        # but some of those points aren't paired, which makes it complicated
-        # plot mean with short horizontal lines:
-        meanrun, meansit, meannone = y.mean(axis=1)
-        a.plot([-0.25, 0.25], [meanrun, meanrun], '-', lw=2, c=st82clr['run'], zorder=np.inf)
-        a.plot([0.75, 1.25], [meansit, meansit], '-', lw=2, c=st82clr['sit'], zorder=np.inf)
-        a.plot([1.75, 2.25], [meannone, meannone], '-', lw=2, c=st82clr['none'], zorder=np.inf)
-        # test if slope and intercept are significantly different from 1 and 0, respectively:
-        if par == 'slope':
-            a.axhline(y=1, color='gray', ls='--', marker='')
-            runt1samp, runp1samp = ttest_1samp(runvals, 1)
-            sitt1samp, sitp1samp = ttest_1samp(sitvals, 1)
-            nonet1samp, nonep1samp = ttest_1samp(nonevals, 1)
-            print('%s slope != 1 test: runp1samp=%.2e, sitp1samp=%.2e, nonep1samp=%.2e'
-                  % (kind, runp1samp, sitp1samp, nonep1samp))
-            a.set_ylim(top=1.6)
-        elif par == 'intercept':
-            a.axhline(y=0, color='gray', ls='--', marker='')
-            runt1samp, runp1samp = ttest_1samp(runvals, 0)
-            sitt1samp, sitp1samp = ttest_1samp(sitvals, 0)
-            nonet1samp, nonep1samp = ttest_1samp(nonevals, 0)
-            print('%s intercept != 0 test: runp1samp=%.2e, sitp1samp=%.2e, nonep1samp=%.2e'
-                  % (kind, runp1samp, sitp1samp, nonep1samp))
-            a.set_ylim(top=7, bottom=-7)
-        # plot PDFs for nat movie none st8 only:
-        if kind == 'nat':
-            f, pdfa = plt.subplots(figsize=figsize)
-            titlestr = 'PSTH %s fit %s none %s PDF rsqthresh=%.1f' % (model, kind, par, RSQTHRESH)
-            wintitle(titlestr, f)
-            pdfa.hist(nonevals, bins=20, density=False, color=st82clr['none'])
-            pdfa.set_xlabel(ylabel)
-            pdfa.set_ylabel('Unit count')
-'''
 
 # scatter plot resampled model slope vs threshold, for only those fits with decent SNR
 # in at least one opto state:
@@ -404,7 +318,7 @@ for kind in ['nat']:#MVIKINDS:
                           marker=marker, c=c, s=sz, lw=lw)
             # get fname of appropriate LMM .csv file:
             if fmode == '': # use Steffen's LMM linregress fit, for fig2f
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_2f_coefs.csv')
+                fname = os.path.join('stats', 'figure_2f_coefs.csv')
             # fetch LMM linregress fit params from .csv:
             df = pd.read_csv(fname)
             mm = df['slope'][0]
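
For context, the figure scripts all consume these coefficient CSVs the same way: read the fitted LMM slope and overlay the regression line on the scatter plot. A minimal self-contained sketch of the pattern ('slope' is the column name read in the diff; the 'intercept' column name, the chosen file and the plotting range are assumptions):

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv(os.path.join('stats', 'figure_2f_coefs.csv'))
mm = df['slope'][0]     # LMM regression slope, as read in the diff
b = df['intercept'][0]  # assumed column name for the intercept
f, a = plt.subplots()
x = np.array([0.0, 1.0])                  # assumed plotting range
a.plot(x, mm * x + b, '--', color='gray') # overlay the LMM regression line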

+ 0 - 2
fig3.py

@@ -338,8 +338,6 @@ for st8 in ['none']:#ALLST8S:
     a.set_aspect('equal')
     a.spines['left'].set_position(('outward', 4))
     a.spines['bottom'].set_position(('outward', 4))
-    #z, p = rayleigh_test(dphis) # test if dphis is uniform, not sure that's useful here though
-    #a.add_artist(AnchoredText('p$=$%.2g' % p, loc='upper left', frameon=False))
     ## plot maxori delta phi PDF:
     figsize = DEFAULTFIGURESIZE
     f, a = plt.subplots(figsize=figsize)

+ 4 - 4
fig4.py

@@ -144,13 +144,13 @@ for mvimeasure, grtmeasure in zip(mvimeasures, grtmeasures):
                       color='None', edgecolor='black', size=np.sqrt(50))
         # get fname of appropriate LMM .csv file:
         if mvimeasure == 'meanrate' and grtmeasure == 'meanrate':
-            fname = os.path.join(PAPERPATH, 'stats', 'figure_4a_pred_means.csv')
+            fname = os.path.join('stats', 'figure_4a_pred_means.csv')
         elif mvimeasure == 'blankmeanrate' and grtmeasure == 'blankmeanrate':
-            fname = os.path.join(PAPERPATH, 'stats', 'figure_4b_pred_means.csv')
+            fname = os.path.join('stats', 'figure_4b_pred_means.csv')
         elif mvimeasure == 'meanburstratio' and grtmeasure == 'meanburstratio':
-            fname = os.path.join(PAPERPATH, 'stats', 'figure_4_S1e_pred_means.csv')
+            fname = os.path.join('stats', 'figure_4_S1e_pred_means.csv')
         elif mvimeasure == 'blankmeanburstratio' and grtmeasure == 'blankmeanburstratio':
-            fname = os.path.join(PAPERPATH, 'stats', 'figure_4_S1f_pred_means.csv')
+            fname = os.path.join('stats', 'figure_4_S1f_pred_means.csv')
         else:
             print("WARNING: No LMM stats for (mvimeasure=%s, grtmeasure=%s)"
                   % (mvimeasure, grtmeasure))

+ 5 - 5
fig5.py

@@ -577,15 +577,15 @@ for kind in ['nat']:#MVIKINDS:
                           marker=marker, c=c, s=sz, lw=lw)
             # get fname of appropriate LMM .csv file:
             if measure == 'meanburstratio': # use Steffen's LMM linregress fit, for fig5S1
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_5S1c_coefs.csv')
+                fname = os.path.join('stats', 'figure_5S1c_coefs.csv')
             elif measure == 'spars': # use Steffen's LMM linregress fit, for fig5S1
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_5S1d_coefs.csv')
+                fname = os.path.join('stats', 'figure_5S1d_coefs.csv')
             elif measure == 'rel': # use Steffen's LMM linregress fit, for fig5S1
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_5S1e_coefs.csv')
+                fname = os.path.join('stats', 'figure_5S1e_coefs.csv')
             elif measure == 'snr': # use Steffen's LMM linregress fit, for fig5S1
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_5S1f_coefs.csv')
+                fname = os.path.join('stats', 'figure_5S1f_coefs.csv')
             elif measure == 'meanpkw': # use Steffen's LMM linregress fit, for fig5S1
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_5S1g_coefs.csv')
+                fname = os.path.join('stats', 'figure_5S1g_coefs.csv')
             else:
                 print("WARNING: No LMM stats for %s measure" % measure)
                 fname = None

+ 16 - 16
fig6.py

@@ -78,13 +78,13 @@ for measure in modmeasuresnoblank:
         if (measure in relev_measures) and stimtype == 'mvi':
             # if measure is relevant, get parameters for model from csv file:
             if measure == 'meanrate':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_6a1_coefs.csv')
+                fname = os.path.join('stats', 'figure_6a1_coefs.csv')
             elif measure == 'meanburstratio':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_6a2_coefs.csv')
+                fname = os.path.join('stats', 'figure_6a2_coefs.csv')
             elif measure == 'spars':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_6a3_coefs.csv')
+                fname = os.path.join('stats', 'figure_6a3_coefs.csv')
             elif measure == 'rel':
-                fname = os.path.join(PAPERPATH, 'stats', 'figure_6a4_coefs.csv')
+                fname = os.path.join('stats', 'figure_6a4_coefs.csv')
             try:
                 df = pd.read_csv(fname)
                 foundregression = True
@@ -219,18 +219,18 @@ for measure in modmeasuresnoblank:
             # if measure is relevant, get parameters for model from csv file:
             if stimtype == 'mvi':
                 if measure == 'meanrate':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6b1_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6b1_coefs.csv')
                 elif measure == 'meanburstratio':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6b2_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6b2_coefs.csv')
                 elif measure == 'spars':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6b3_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6b3_coefs.csv')
                 elif measure == 'rel':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6b4_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6b4_coefs.csv')
             elif stimtype == 'grt':
                 if measure == 'meanrate':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6_S1b1_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6_S1b1_coefs.csv')
                 elif measure == 'meanburstratio':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6_S1b2_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6_S1b2_coefs.csv')
             try:
                 df = pd.read_csv(fname)
                 foundregression = True
@@ -374,18 +374,18 @@ for measure in modmeasuresnoblank:
             # if measure is relevant, get parameters for model from csv file:
             if stimtype == 'mvi':
                 if measure == 'meanrate':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6c1_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6c1_coefs.csv')
                 elif measure == 'meanburstratio':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6c2_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6c2_coefs.csv')
                 elif measure == 'spars':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6c3_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6c3_coefs.csv')
                 elif measure == 'rel':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6c4_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6c4_coefs.csv')
             elif stimtype == 'grt':
                 if measure == 'meanrate':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6_S1c1_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6_S1c1_coefs.csv')
                 elif measure == 'meanburstratio':
-                    fname = os.path.join(PAPERPATH, 'stats', 'figure_6_S1c2_coefs.csv')
+                    fname = os.path.join('stats', 'figure_6_S1c2_coefs.csv')
             try:
                 df = pd.read_csv(fname)
                 foundregression = True

+ 2 - 2
ipos.py

@@ -76,7 +76,7 @@ wintitle('rel FMI vs ipos stdev FMI')
 a.scatter(iposfmis[valid], relfmis[valid], clip_on=False,
           marker='.', c='None', edgecolor='black', s=DEFSZ)
 # plot regression:
-fname = os.path.join(PAPERPATH, 'stats', 'figure_1S2i_coefs.csv')
+fname = os.path.join('stats', 'figure_1S2i_coefs.csv')
 try:
     df = pd.read_csv(fname)
     foundregression = True
@@ -194,7 +194,7 @@ wintitle('rel RMI vs ipos stdev RMI')
 a.scatter(iposrmis[valid], relrmis[valid], clip_on=False,
           marker='.', c='None', edgecolor='black', s=DEFSZ)
 # plot regression:
-fname = os.path.join(PAPERPATH, 'stats', 'figure_5S1i_coefs.csv')
+fname = os.path.join('stats', 'figure_5S1i_coefs.csv')
 try:
     df = pd.read_csv(fname)
     foundregression = True

+ 42 - 1642
main.py

File diff suppressed because it is too large


+ 541 - 56
util.py

@@ -4,14 +4,16 @@ import os
 import pickle
 
 import numpy as np
+from numpy import pi
 import pandas as pd
 from scipy.stats import linregress
 from scipy.optimize import curve_fit
+import scipy.ndimage.filters as filters
 
+import matplotlib.pyplot as plt
+from matplotlib import colors
 from matplotlib.ticker import FuncFormatter
-
-from djd.plot import mixalpha
-from djd.model import threshlin, rsquared
+from matplotlib.offsetbox import AnchoredText
 
 RESPDFNAMES = ['mviresp', 'grtresp', 'sponresp', 'bestmviresp', 'bestgrtresp', 'grttunresp']
 FITSDFNAMES = ['fits', 'sampfits']
@@ -26,59 +28,6 @@ EXPORTDFNAMES = MIDFNAMES + FIGDFNAMES # to .csv
 STRNAMES = ['mvimseustrs', 'grtmseustrs', 'mvigrtmsustrs']
 
 
-def get_exps(ns, name):
-    """Return a specific set of experiments, given a namespace ns (e.g. locals()) and `name`"""
-    m, s, e, u = ns['m'], ns['s'], ns['e'], ns['u'] # djd tables
-    if name == 'pvmvis':
-        # all movie opto experiments in LGN in PVCre mice that have been sorted,
-        # whether single movie or multiple interleaved movies, natural or pink or white:
-        ## TODO: LB suggests restricting by 'e_optoampl IS NOT NULL' instead of optowl
-        exps = ((e & 'e_name LIKE "MAS_%%"' & 'e_optowl IN (465, 470)')
-                #& (s & 's_depth > 2500' & 's_depth < 3500') # upper limit excludes TRN series
-                & (s & 's_region = "LGN"')
-                & (m & 'm_strain = "PV-Cre"')
-                & u)
-        exps = exps - {'m':'PVCre_2019_0001', 's':6} # poor RFs, upward pointing pupil
-    elif name == 'ntsrmvis':
-        # all movie opto experiments in LGN in Nstr1Cre mice that have been sorted,
-        # whether single movie or multiple interleaved movies, natural or pink or white:
-        exps = ((e & 'e_name LIKE "MAS_%%"' & 'e_optowl IN (465, 470)')
-                #& (s & 's_depth > 2500' & 's_depth < 3500') # upper limit excludes TRN series
-                & (s & 's_region = "LGN"')
-                & (m & 'm_strain = "Ntsr1-Cre"')
-                & u)
-        # restrict to only very good series:
-        goodseries = [{'m':'Ntsr1Cre_2019_0002', 's':3}, # recommended by YB
-                      {'m':'Ntsr1Cre_2019_0002', 's':5}, # recommended by YB
-                      #{'m':'Ntsr1Cre_2019_0003', 's':4}, # removed by YB due to bad sorting
-                      {'m':'Ntsr1Cre_2019_0007', 's':4}, # MAS sorted
-                      {'m':'Ntsr1Cre_2019_0007', 's':6}, # GB sorted
-                      {'m':'Ntsr1Cre_2019_0008', 's':3}, # recommended by YB
-                      {'m':'Ntsr1Cre_2019_0008', 's':5}, # MAS sorted
-                      {'m':'Ntsr1Cre_2019_0008', 's':6}, # SR sorted, after removing u53...
-                      {'m':'Ntsr1Cre_2019_0008', 's':7}, # GB sorted
-                     ]
-        exps = exps & goodseries
-    elif name == 'negntsrmvis':
-        # all movie opto experiments in LGN in Nstr1Cre negative mice that have been sorted,
-        # whether single movie or multiple interleaved movies, natural or pink or white:
-        exps = ((e & 'e_name LIKE "MAS_%%"' & 'e_optowl IN (465, 470)')
-                #& (s & 's_depth > 2500' & 's_depth < 3500') # upper limit excludes TRN series
-                & (s & 's_region = "LGN"')
-                & (m & {'m_strain':'Ntsr1-Cre'} & {'m_genotype':'-/-'})
-                & u)
-    elif name == 'pgnpvmvis':
-        # all movie opto experiments in PGN in PVCre mice that have been sorted,
-        # whether single movie or multiple interleaved movies, natural or pink or white:
-        exps = ((e & 'e_name LIKE "MAS_%%"' & 'e_optowl IN (465, 470)')
-                #& (s & 's_depth > 2500' & 's_depth < 3500') # upper limit excludes TRN series
-                & (s & 's_region = "PGN"')
-                & (m & 'm_strain = "PV-Cre"')
-                & u)
-    else:
-        raise ValueError("Unknown exps name %r" % name)
-    return exps
-
 def load(name, subfolder=''):
     """Return variable (e.g. a DataFrame) from a pickle"""
     path = os.path.dirname(__file__)
@@ -199,6 +148,14 @@ def desat(hexcolor, alpha):
     lower layer"""
     return mixalpha(hexcolor, alpha)
 
+def mixalpha(hexcolor, alpha=1, bg='#ffffff'):
+    """Mix alpha into hexcolor, assuming background color.
+    See https://stackoverflow.com/a/21576659/2020363"""
+    rgb = np.array(colors.hex2color(hexcolor)) # convert to float RGB array
+    bg = np.array(colors.hex2color(bg))
+    rgb = alpha*rgb + (1 - alpha)*bg # mix it
+    return colors.rgb2hex(rgb)
+
 def axes_disable_scientific(axes, axiss=None):
     """Disable scientific notation for both axes labels, useful for log-log plots.
     See https://stackoverflow.com/a/49306588/3904031"""
@@ -284,6 +241,23 @@ def fitmodel(ctrlfit, optofit, ctrltest, optotest, model=None):
         raise ValueError("Unknown model %r" % model)
     return mm, b, rsq
 
+def threshlin(x, m, b):
+    """Return threshold linear model"""
+    y = m * x + b
+    y[y < 0] = 0
+    return y
+
+def rsquared(targets, predictions):
+    """Return the r-squared value for the fit"""
+    residuals = targets - predictions
+    residual_variance = np.sum(residuals**2)
+    variance_of_targets = np.sum((targets - np.mean(targets))**2)
+    if variance_of_targets == 0:
+        rsq = np.nan
+    else:
+        rsq = 1 - (residual_variance / variance_of_targets)
+    return rsq
+
 def residual_rsquared(targets, residuals):
     """Return the r-squared value for the fit, given the target values and residuals.
     Minor variation of djd.model.rsquared()"""
@@ -312,3 +286,514 @@ def get_max_snr(mvirespr, mseustr, kind, st8):
     supsnr = mvirows[mvirows['opto'] == True]['snr'].iloc[0] # suppression
     maxsnr = max(fbsnr, supsnr) # take the max of the two conditions
     return maxsnr
+
+def intround(n):
+    """Round to the nearest integer, return an integer. Works on arrays.
+    Saves on parentheses, nothing more"""
+    if np.iterable(n): # it's a sequence, return as an int64 array
+        return np.int64(np.round(n))
+    else: # it's a scalar, return as normal Python int
+        return int(round(n))
+
+def split_tranges(tranges, width, tres):
+    """Split up tranges into lots of smaller (typically overlapping) tranges, with width and
+    tres. Usually, tres < width, but this also works for width < tres.
+    Test with:
+
+    print(split_tranges([(0,100)], 1, 10))
+    print(split_tranges([(0,100)], 10, 1))
+    print(split_tranges([(0,100)], 10, 10))
+    print(split_tranges([(0,100)], 10, 8))
+    print(split_tranges([(0,100)], 3, 10))
+    print(split_tranges([(0,100)], 10, 3))
+    print(split_tranges([(0,100)], 3, 8))
+    print(split_tranges([(0,100)], 8, 3))
+    """
+    newtranges = []
+    for trange in tranges:
+        t0, t1 = trange
+        assert width < (t1 - t0)
+        # calculate left and right edges of subtranges that fall within trange:
+        # This is tricky: find maximum left edge such that the corresponding maximum right
+        # edge goes as close as possible to t1 without exceeding it:
+        tend = (t1-width+tres) // tres*tres # there might be a nicer way, but this works
+        ledges = np.arange(t0, tend, tres)
+        redges = ledges + width
+        subtranges = [ (le, re) for le, re in zip(ledges, redges) ]
+        newtranges.append(subtranges)
+    return np.vstack(newtranges)
+
+def wrap_raster(raster, t0, t1, newdt, offsets=[0, 0]):
+    """Extract event times in raster (list or array of arrays) between t0 and t1 (s),
+    and wrap into extra rows such that event times never exceed newdt (s)"""
+    t1floor = t1 - t1 % newdt
+    t0s = np.arange(t0, t1floor, newdt) # end exclusive
+    t1s = t0s + newdt
+    tranges = np.column_stack([t0s, t1s])
+    wrappedraster = []
+    for row in raster:
+        dst = [] # init list to collect events for this row
+        for trange in tranges:
+            # search within trange, but take into account desired offsets:
+            si0, si1 = row.searchsorted(trange + offsets)
+            # get spike times relative to start of trange:
+            dst.append(row[si0:si1] - trange[0])
+        wrappedraster.extend(dst)
+    # convert from list to array to enable fancy indexing:
+    return np.asarray(wrappedraster)
+
+def cf():
+    """Close all figures"""
+    plt.close('all')
+
+def saveall(path=None, format='png'):
+    """Save all open figures to chosen path, pop up dialog box if path is None"""
+    if path is None: # query with dialog box for a path
+        from matplotlib import rcParams
+        startpath = os.path.expanduser(rcParams['savefig.directory']) # get default
+        path = choose_path(startpath, msg="Choose a folder to save to")
+        if not path: # dialog box was cancelled
+            return # don't do anything
+        rcParams['savefig.directory'] = path # update default
+    fs = [ plt.figure(i) for i in plt.get_fignums() ]
+    for f in fs:
+        fname = f.canvas.get_window_title() + '.' + format
+        fname = fname.replace(' ', '_')
+        fullfname = os.path.join(path, fname)
+        print(fullfname)
+        f.savefig(fullfname)
+
+def lastcmd():
+    """Return a string containing the last command entered by the user in the
+    Ipython shell. Useful for generating plot titles"""
+    ip = get_ipython()
+    return ip._last_input_line
+
+def wintitle(titlestr=None, f=None):
+    """Set title of current MPL window, defaults to last command entered"""
+    if titlestr is None:
+        titlestr = lastcmd()
+    if f is None:
+        f = plt.gcf()
+    f.canvas.set_window_title(titlestr)
+
+def simpletraster(raster, dt=5, offsets=[0, 0], s=1, clr='k',
+                  scatter=False, scattermarker='|', scattersize=10,
+                  burstis=None, burstclr='r',
+                  axisbg='w', alpha=1, inchespersec=1.5, inchespertrial=1/25,
+                  ax=None, figsize=None, title=False, xaxis=True, label=None):
+    """
+    Create a simple trial raster plot. Each entry in raster is a list of spike times
+    relative to the start of each trial.
+    dt : trial duration (s)
+    offsets : offsets relative to trial start and end (s)
+    s : tick linewidths
+    clr : tick color, either a single color or a sequence of colors, one per trial
+    scatter : whether to use original ax.scatter() command to plot much faster and use much
+              less memory, but with potentially vertically overlapping ticks. Otherwise,
+              default to slower ax.eventplot()
+    burstis : burst indices, as returned by FiringPattern().burst_ratio()
+    """
+    ntrials = len(raster)
+    spiketrialis, c = [], []
+    # get raster tick color of each trial:
+    if type(clr) == str: # all trials have the same color
+        clr = list(colors.to_rgba(clr))
+        clr[3] = alpha # apply alpha, so that we can control alpha per tick
+        trialclrs = [clr]*ntrials
+    else: # each trial has potentially a different color
+        assert type(clr) in [list, np.ndarray]
+        assert len(clr) == ntrials
+        trialclrs = []
+        for trialclr in clr:
+            trialclr = list(colors.to_rgba(trialclr))
+            trialclr[3] = alpha # apply alpha, so that we can control alpha per tick
+            trialclrs.append(trialclr)
+    burstclr = colors.to_rgba(burstclr) # keep full saturation for burst spikes
+    # collect 1-based trial info, one entry per spike:
+    for triali, rastertrial in enumerate(raster):
+        nspikes = len(rastertrial)
+        spiketrialis.append(np.tile(triali+1, nspikes)) # 1-based
+        trialclr = trialclrs[triali]
+        spikecolors = np.tile(trialclr, (nspikes, 1))
+        if burstis is not None:
+            bis = burstis[triali]
+            if len(bis) > 0:
+                spikecolors[bis] = burstclr
+        c.append(spikecolors)
+
+    # convert each list of arrays to a single flat array:
+    raster = np.hstack(raster)
+    spiketrialis = np.hstack(spiketrialis)
+    c = np.concatenate(c)
+    xmin, xmax = offsets[0], dt + offsets[1]
+    totaldt = xmax - xmin # total raster duration, including offsets
+
+    if ax == None:
+        if figsize is None:
+            figwidth = min(1 + totaldt*inchespersec, 12)
+            figheight = min(1 + ntrials*inchespertrial, 12)
+            figsize = figwidth, figheight
+        f, ax = plt.subplots(figsize=figsize)
+
+    if scatter:
+        # scatter doesn't carefully control vertical spacing, allows vertical overlap of ticks:
+        ax.scatter(raster, spiketrialis, marker=scattermarker, c=c, s=scattersize, label=label)
+    else:
+        # eventplot is slower, but does a better job:
+        raster = raster[:, np.newaxis] # somehow eventplot requires an extra unitary dimension
+        if len(raster) == 0:
+            print("No spikes for eventplot %r" % title) # prevent TypeError from eventplot()
+        else:
+            ax.eventplot(raster, lineoffsets=spiketrialis, colors=c, linewidth=s, label=label)
+    ax.set_xlim(xmin, xmax)
+    # -1 inverts the y axis, +1 ensures last trial is fully visible:
+    ax.set_ylim(ntrials+1, -1)
+    ax.set_facecolor(axisbg)
+    ax.set_xlabel('Time (s)')
+    ax.set_ylabel('Trial')
+    if label:
+        ax.legend(loc="best")
+
+    if title:
+        #a.set_title(title)
+        wintitle(title)
+
+    if xaxis != True:
+        f = ax.figure # works whether ax was passed in or created above
+        figheight = f.get_figheight()
+        if xaxis == False:
+            renderer = f.canvas.get_renderer()
+            bbox = ax.xaxis.get_tightbbox(renderer).transformed(f.dpi_scale_trans.inverted())
+            xaxis = bbox.height
+        figheight = figheight - xaxis
+        ax.get_xaxis().set_visible(False)
+        ax.spines['bottom'].set_visible(False)
+        f.set_figheight(figheight)
+
+    #f.tight_layout(pad=0.3) # crop figure to contents, doesn't seem to do anything any more
+    #f.show()
+    return ax
+
+def raster2psth(raster, bins, binw, tres, kernel='gauss'):
+    """Convert a spike trial raster to a peri-stimulus time histogram (PSTH).
+    To calculate the PSTH of a subset of trials, pass a raster containing only that subset.
+
+    Parameters
+    ----------
+    raster : spike trial raster as a sequence of arrays of spike times (s), one array per trial
+    bins : 2D array of start and stop PSTH bin edge times (s), one row per bin.
+           Bins may or may not be overlapping. Typically generated using util.split_tranges()
+    binw : PSTH bin width (s) that was used to generate bins
+    tres : temporal resolution (s) that was used to generate bins, only used if kernel=='gauss'
+    kernel : smoothing kernel : None or 'gauss'
+
+    Returns
+    -------
+    psth : peri-stimulus time histogram (Hz), normalized by bin width and number of trials
+    """
+    # make sure raster has nested iterables, i.e. list of arrays, or array of arrays, etc.,
+    # even if there's only one array inside raster representing only one trial:
+    if len(raster) > 0: # not an empty raster
+        trial0 = raster[0]
+        if type(trial0) not in (np.ndarray, list):
+            raise ValueError("Ensure that raster is a sequence of arrays of spike times,\n"
+                             "one per trial. If you're passing only a single extracted trial,\n"
+                             "make sure to pass it within e.g. a list of length 1")
+    # now it's safe to assume that len(raster) represents the number of included trials,
+    # and not erroneously the number of spikes in a single unnested array of spike times:
+    ntrials = len(raster)
+    if ntrials == 0: # empty raster
+        spikes = np.asarray(raster)
+    else:
+        spikes = np.hstack(raster) # flatten across trials
+    spikes.sort()
+    spikeis = spikes.searchsorted(bins) # where bin edges fall in spikes
+    # convert to rate: number of spikes in each bin, normalized by binw:
+    psth = (spikeis[:, 1] - spikeis[:, 0]) / binw
+    if kernel is None: # rectangular bins
+        pass
+    elif kernel == 'gauss': # apply Gaussian filtering
+        sigma = binw / 2 # set sigma to half the bin width (sec)
+        sigmansamples = sigma / tres # sigma as multiple of number of samples (unitless)
+        psth = filters.gaussian_filter1d(psth, sigma=sigmansamples)
+    else:
+        raise ValueError('Unknown kernel %r' % kernel)
+    # normalize by number of trials:
+    if ntrials != 0:
+        psth = psth / ntrials
+    return psth
+
+def raster2freqcomp(raster, dt, f, mean='scalar'):
+    """Extract a frequency component from spike raster (one row of spike times per trial).
+    Adapted from getHarmsResps.m and UnitGetHarm.m
+
+    Parameters
+    ----------
+    raster : spike raster as a sequence of arrays of spike times (s), one array per trial
+    dt : trial duration (s)
+    f : frequency to extract (Hz), f=0 extracts mean firing rate
+    mean : 'scalar': compute mean of amplitudes of each trial's vector (mean(abs)), i.e. find
+           frequency component f separately for each trial, then take average amplitude.
+           'vector': compute mean across all trials before calculating amplitude (abs(mean)),
+           equivalent to first calculating PSTH from all rows of raster
+
+    Returns
+    -------
+    r : peak-to-peak amplitude of frequency component f
+    theta : angle of frequency component f (rad)
+
+    Examples
+    --------
+    >>> inphase = np.array([0, 1, 2, 3, 4]) # spike times (s)
+    >>> outphase = np.array([0.5, 1.5, 2.5, 3.5, 4.5]) # spike times (s)
+    >>> raster2freqcomp([inphase], 5, 1) # single trial, 'mean' is irrelevant
+    (2.0, -4.898587196589412e-16)
+    >>> raster2freqcomp([outphase], 5, 1)
+    (2.0, 3.1415926535897927)
+    >>> raster2freqcomp([inphase, outphase], 5, 1, mean='scalar')
+    (2.0, 1.5707963267948961)
+    >>> raster2freqcomp([inphase, outphase], 5, 1, mean='vector')
+    (1.2246467991473544e-16, 1.5707963267948966)
+
+    Using f=0 returns mean firing rate:
+    >>> raster2freqcomp([inphase, outphase], 5, 0, mean='scalar')
+    (1.0, 0.0)
+    >>> raster2freqcomp([inphase, outphase], 5, 0, mean='vector')
+    (1.0, 0.0)
+    """
+    ntrials = len(raster)
+    res, ims = np.zeros(ntrials), np.zeros(ntrials) # init real and imaginary components
+    for triali, spikes in enumerate(raster): # iterate over trials
+        if len(spikes) == 0:
+            continue
+        spikes = np.asarray(spikes) # in case raster is a list of lists
+        if spikes.max() > dt:
+            print('spikes exceeding dt:', spikes[spikes > dt])
+        # discard rare spikes in raster that for some reason (screen vsyncs?) fall outside
+        # the expected trial duration:
+        spikes = spikes[spikes <= dt]
+        ## TODO: getHarmResps.m only ever used an integer number of cycles, but why? Can't we
+        ## instead use the exact fractional number of cycles that the spikes occupy?
+        '''
+        ## UNTESTED:
+        if f != 0:
+            period = 1 / f
+            nperiods = int(np.floor(dt/period))
+            if nperiods == 0:
+                dt = period
+            else:
+                dt = nperiods*period
+            n = dt * f # number of cycles in this trial's spike train
+            nint = int(np.floor(n)) # number of full cycles before the end
+            nfrac = n % 1 # fractional number of cycles at the end
+            # need to split off spikes from last fractional cycle:
+            dtint, dtfrac = nint*period, nfrac*period # (s)
+            fracspikei = spikes.searchsorted(dtint)
+            fullspikes, fracspikes = spikes[:fracspikei], spikes[fracspikei:]
+        '''
+        omega = 2 * np.pi * f # angular frequency (rad/s)
+        res[triali] = (np.cos(omega*spikes)).sum() / dt
+        ims[triali] = (np.sin(omega*spikes)).sum() / dt
+    Hs = (res + ims*1j) # array of complex numbers
+    if f != 0: # not just the degenerate mean firing rate case
+        Hs = 2 * Hs # convert to peak-to-peak
+    if mean == 'scalar':
+        Hamplitudes = np.abs(Hs) # ntrials long
+        r = np.nanmean(Hamplitudes) # mean of amplitudes
+        theta = np.nanmean(np.angle(Hs)) # mean of angles
+        #rstd = np.nanstd(Hamplitudes) # stdev of amplitudes
+    elif mean == 'vector':
+        Hmean = np.nanmean(Hs) # single complex number
+        r = np.abs(Hmean) # corresponds to PSTH amplitude
+        theta = np.angle(Hmean) # angle of mean vector
+        #rstd = np.nanstd(Hs) # single scalar, corresponds to PSTH stdev
+    else:
+        raise ValueError('Unknown `mean` method %r' % mean)
+    ## NOTE: another way to calculate theta might be:
+    #theta = np.arctan2(np.nanmean(np.imag(Hs)), np.nanmean(np.real(Hs)))
+    return r, theta
+
+def sparseness(x):
+    """Sparseness measure, from Vinje and Gallant, 2000. This is basically 1 minus the ratio
+    of the square of the sums over the sum of the squares of the values in signal x"""
+    if x.sum() == 0:
+        return 0
+    n = len(x)
+    return (1 - (x.sum()/n)**2 / np.sum((x**2)/n)) / (1 - 1/n)
+
+def reliability(signals, average='mean', ignore_nan=True):
+    """Calculate reliability across trials in signals, one row per trial, by finding the
+    average Pearson's rho between all pairwise combinations of trial signals
+
+    Returns
+    -------
+    reliability : float
+    rhos : ndarray
+    """
+    ntrials = len(signals)
+    if ntrials < 2:
+        return np.nan, None # can't calculate reliability with fewer than 2 trials
+    if ignore_nan:
+        rhos = pairwisecorr_nan(signals)
+    else:
+        rhos, _ = pairwisecorr(signals)
+    if average == 'mean':
+        rel = np.nanmean(rhos)
+    elif average == 'median':
+        rel = np.nanmedian(rhos)
+    else:
+        raise ValueError('Unknown average %r' % average)
+    return rel, rhos
+
+def snr_baden2016(signals):
+    """Return signal-to-noise ratio, aka Berens quality index (QI), from Baden2016, of a set of
+    signals. Take ratio of the temporal variance of the trial averaged signal (i.e. PSTH),
+    to the average across trials of the variance in time of each trial. Ranges from 0 to 1.
+    Ignores NaNs."""
+    assert signals.ndim == 2
+    signal = np.nanvar(np.nanmean(signals, axis=0)) # reduce row axis, calc var across time
+    noise = np.nanmean(np.nanvar(signals, axis=1)) # reduce time axis, calc mean across trials
+    if signal == 0:
+        return 0
+    else:
+        return signal / noise
+
+def get_psth_peaks_gac(ts, t, psth, thresh, sigma=0.02, alpha=1.0, minpoints=5,
+                       lowp=16, highp=84, checkthresh=True, verbose=True):
+    """Extract PSTH peaks from spike times ts collapsed across trials, by clustering them
+    using gradient ascent clustering (GAC, Swindale2014). Then, optionally check each peak
+    against its amplitude in the PSTH (and its time stamps t), to ensure it passes thresh.
+    Also extract the left and right edges of each peak, based on where each peak's mass falls
+    between lowp and highp percentiles.
+
+    sigma is the clustering bandwidth used by GAC, in this case in seconds.
+
+    Note that very narrow peaks will be missed if the resolution of the PSTH isn't high enough
+    (TRES=0.0001 is plenty)"""
+
+    from spyke.gac import gac # .pyx file
+
+    ts2d = np.float32(ts[:, None]) # convert to 2D (one row per spike), contig float32
+    # get cluster IDs and positions corresponding to spikets, cpos is indexed into using
+    # cids:
+    cids, cpos = gac(ts2d, sigma=sigma, alpha=alpha, minpoints=minpoints, returncpos=True,
+                     verbose=verbose)
+    ucids = np.unique(cids) # unique cluster IDs across all spikets
+    ucids = ucids[ucids >= 0] # exclude junk cluster -1
+    #npeaks = len(ucids) # but not all of them will necessarily cross the PSTH threshold
+    peakis, lis, ris = [], [], []
+    for ucid, pos in zip(ucids, cpos): # clusters are numbered in order of decreasing size
+        spikeis, = np.where(cids == ucid)
+        cts = ts[spikeis] # this cluster's spike times
+        # search all spikes for argmax, same as using lowp=0 and highp=100:
+        #li, ri = t.searchsorted([cts[0], cts[-1]])
+        # search only within the percentiles for argmax:
+        lt, rt = np.percentile(cts, [lowp, highp])
+        li, ri = t.searchsorted([lt, rt])
+        if li == ri:
+            # start and end indices are identical, cluster probably falls before first or
+            # after last spike time:
+            assert li == 0 or li == len(psth)
+            continue # no peak to be found in psth for this cluster
+        localpsth = psth[li:ri]
+        # indices of all local peaks within percentiles in psth:
+        #allpeakiis, = argrelextrema(localpsth, np.greater)
+        #if len(allpeakiis) == 0:
+        #    continue # no peaks found for this cluster
+        # find peakii closest to pos:
+        #peakii = allpeakiis[abs((t[li + allpeakiis] - pos)).argmin()]
+        # find biggest peak:
+        #peakii = allpeakiis[localpsth[allpeakiis].argmax()]
+        peakii = localpsth.argmax() # find max point
+        if peakii == 0 or peakii == len(localpsth)-1:
+            continue # skip "peak" that's really just a start or end point of localpsth
+        peaki = li + peakii
+        if checkthresh and psth[peaki] < thresh:
+            continue # skip peak that doesn't meet thresh
+        if peaki in peakis:
+            continue # this peak has already been detected by a preceding, larger, cluster
+        peakis.append(peaki)
+        lis.append(li)
+        ris.append(ri)
+        if verbose:
+            print('.', end='') # indicate a peak has been found
+    return np.asarray(peakis), np.asarray(lis), np.asarray(ris)
+
+def pairwisecorr(signals, weight=False, invalid='ignore'):
+    """Calculate Pearson correlations between all pairs of rows in 2D signals array.
+    See np.seterr() for possible values of `invalid`"""
+    assert signals.ndim == 2
+    assert len(signals) >= 2 # at least two rows, i.e. at least one pair
+    N = len(signals)
+    # potentially allow 0/0 (nan) rhomat entries by ignoring 'invalid' errors
+    # (not 'divide'):
+    oldsettings = np.seterr(invalid=invalid)
+    rhomat = np.corrcoef(signals) # full correlation matrix
+    np.seterr(**oldsettings) # restore previous numpy error settings
+    uti = np.triu_indices(N, k=1)
+    rhos = rhomat[uti] # pull out the upper triangle
+    if weight:
+        sums = signals.sum(axis=1)
+        # weight each pair by the one with the least signal:
+        weights = np.vstack((sums[uti[0]], sums[uti[1]])).min(axis=0) # all pairs
+        weights = weights / weights.sum() # normalize, ensure float division
+        return rhos, weights
+    else:
+        return rhos, None
+
+def pairwisecorr_nan(signals):
+    """Calculate Pearson correlations between all pairs of rows in 2D signals array,
+    while skipping NaNs. Relies on Pandas DataFrame method"""
+    assert signals.ndim == 2
+    assert len(signals) >= 2 # at least two rows, i.e. at least one pair
+    N = len(signals)
+    rhomat = np.array(pd.DataFrame(signals.T).corr()) # full correlation matrix
+    return rhomat[np.triu_indices(N, k=1)] # pull out the unique upper triangle entries
+
+def vector_OSI(oris, rates):
+    """Vector averaging method for calculating orientation selectivity index (OSI).
+    See Bonhoeffer1995, Swindale1998 and neuropy.neuron.Tune.pref().
+    Reasonable use case is to take model tuning curve, calculate its values at a
+    fine ori resolution (say 1 deg), and use that as input rates here.
+
+    Parameters
+    ----------
+    oris : orientations (degrees, potentially ranging 0 to 360)
+        Attention: for orientation data, ori should always range 0 to 180! Only for direction
+        data (e.g. from drifting gratings) can it go 0 to 360; it will then return the
+        orientation selectivity of the data. For the direction selectivity, ori has to again
+        range 0 to 180!
+    rates : corresponding firing rates
+
+    Returns
+    -------
+    r : length of net vector average as fraction of total firing
+
+    """
+    orisrad = 2 * oris * np.pi/180 # double the angle, convert from deg to rad
+    x = (rates*np.cos(orisrad)).sum()
+    y = (rates*np.sin(orisrad)).sum()
+    n = rates.sum()
+    r = np.sqrt(x**2+y**2) / n # fraction of total firing
+    return r
+
+def percentile_ci(data, alpha=0.05, func=np.percentile, **kwargs):
+    """Simple percentile method for confidence intervals. No assumptions
+    about shape of distribution"""
+    data = np.array(data) # accept lists & tuples
+    lower, med, upper = func(data, [100*alpha, 50, 100*(1-alpha)], **kwargs)
+    return med, lower, upper
+
+def sum_of_gaussians(x, dp, rp, rn, r0, sigma):
+    """ORITUNE  sum of two gaussians living on a circle, for orientation tuning
+
+        x are the orientations, bet 0 and 360.
+        dp is the preferred direction (bet 0 and 360)
+        rp is the response to the preferred direction;
+        rn is the response to the opposite direction;
+        r0 is the background response (useful only in some cases)
+        sigma is the tuning width;
+    """
+    angles_p = 180 / pi * np.angle(np.exp(1j*(x-dp) * pi / 180))
+    angles_n = 180 / pi * np.angle(np.exp(1j*(x-dp+180) * pi / 180))
+    y = (r0 + rp*np.exp(-angles_p**2 / (2*sigma**2)) + rn*np.exp(-angles_n**2 / (2*sigma**2)))
+    return y