main.py 90 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786
"""Collect rasters of responses for all LGN opto experiments as a function of stimulus type
(movies, gratings), run state, and opto state. Then, calculate various measures and save
results to DataFrames. Some variables (DataFrames and mseu string lists) can optionally be
saved to and loaded from .pickle files, to avoid long and unnecessary recalculation times.

After connecting to the server with `djd`, run this script at the IPython/Jupyter command line;
the `-i` flag runs the script in IPython's namespace instead of an empty one:

    run -i main.py

`pvmvis` (default), `ntsrmvis` or `negntsrmvis` can be provided as the `--exptype=` argument
to this script, to load and later plot PV-Cre (most of the paper), Ntsr1-Cre (Fig1S4),
or Ntsr-negative (Fig1S5) data, respectively. `pgnpvmvis` is also accepted by `--exptype`.

This script can also be run without a server connection, if you choose to load from pickles.
"""
  13. import sys
  14. import argparse
  15. from numpy import pi
  16. from numpy.random import choice
  17. import scipy.stats
  18. from scipy.stats import (linregress, ttest_rel, ttest_ind, ttest_1samp, wilcoxon,
  19. ks_2samp)
  20. from scipy.stats.mstats import gmean
  21. from scipy.optimize import least_squares
  22. #import statsmodels.formula.api as smf
  23. import matplotlib.pyplot as plt
  24. from matplotlib.offsetbox import AnchoredText
  25. from matplotlib.ticker import ScalarFormatter, FormatStrFormatter
  26. import seaborn as sns
  27. import pandas as pd
  28. #from pycircstat.tests import kuiper
  29. from collections import OrderedDict as odict
  30. import expio
  31. from djd.util import (seq2sql, intround, split_tranges, wrap_raster, issubset, key2mseustr,
  32. mseustr2key, key2msestr)
  33. from djd.plot import wintitle, simpletraster, cf, saveall, simple_plot_rfs
  34. from djd.signal import (raster2psth, raster2freqcomp, sparseness, reliability, snr_baden2016,
  35. get_psth_peaks_gac, get_trialmat, pairwisecorr, rfs_for_cdata)
  36. from djd.model import fit_MUA_RFs
  37. from djd.stats import vector_OSI, percentile_ci, rayleigh_test, niell_DSI
  38. from djd.model import sum_of_gaussians
  39. from djd.eye import i_clean
  40. from util import (get_exps, fitmodel, desat, get_max_snr, ms2msstr,
  41. msu2msustr, mseustr2msustr, mseustrs2msustrs, mseustr2mstr, mseustrs2mstrs,
  42. mseustr2msestr, mseustrs2msestrs, findmse,
  43. load, load_all, save, export2csv, axes_disable_scientific, linear_loss,
  44. residual_rsquared)
  45. #import rf # deprecated
  46. ## Define constants:
  47. EXPTYPE = 'pvmvis' # pvmvis, ntsrmvis, negntsrmvis, pgnpvmvis
  48. if __name__ == '__main__':
  49. # parse command-line arguments:
  50. parser = argparse.ArgumentParser(description='Calculate script for natfeedback paper')
  51. parser.add_argument("--exptype", help="pvmvis, ntsrmvis, negntsrmvis, pgnpvmvis")
  52. args = parser.parse_args()
  53. EXPTYPE = args.exptype or EXPTYPE
  54. print('*** EXPTYPE: %r' % EXPTYPE)
  55. PAPERPATH = '/home/mspacek/blab/natfb/paper'
  56. VERBOSE = False # only used for gac at the moment
  57. RATETHRESH = 0.01 # mean firing rate threshold for unit inclusion, across all trial types, Hz
  58. SNRTHRESH = 0.015
  59. #RSQTHRESH = 0.2
  60. # p-value thresh for significance of each point's deviance from diagonal in meanrate, burst
  61. # ratio, and reliability scatter plots:
  62. SCATTERPTHRESH = 0.05
  63. SGNF2ALPHA = {True: 1, False: 0.4}
  64. # visual drive unit inclusion criteria for gratings, ignores st8 and opto:
  65. VDNCRIT = 1 # min num points in tuning curve that are at least VDZCRIT sems away from spont rate
  66. VDZCRIT = 2.5 # 1.95, 2.5, 3.2
  67. ## TODO: these need to be optimized for mouse instead of cat:
  68. PSTHBASELINEMEDIANX = 2 # PSTH median multiplier to estimate baseline
  69. PSTHPEAKMINTHRESH = 3 # PSTH peak detection threshold relative to estimated baseline, Hz
  70. STIMTYPES = 'mvi', 'grt'
  71. stimtype2axislabel = {'grt':'grating', 'mvi':'movie'}
  72. # exclude 'clr' kind, since 'nat' is a superset of 'clr' anyway
  73. #MVIKINDS = 'nat', 'pnk', 'shf'
  74. MVIKINDS = 'nat',
  75. ALLSTIMTYPES = MVIKINDS + ('grt',)
  76. STIMTYPESPLUSALL = list(STIMTYPES) + ['mvi+grt'] # strings
  77. STIMTYPESPLUSALLI = list(STIMTYPES) + [slice(None)] # for use as indices
  78. # iterate over run states in this order:
  79. ST8S = ['run', 'sit'] # excludes 'none'
  80. ALLST8S = ['none', 'run', 'sit']
  81. ST8COMBOS = [['none'], ['run', 'sit'], ['run', 'sit']]
  82. OPTOCOMBOS = [[False, True], [False], [True]]
  83. st82crit = odict((('run', '(run_speed > 1).mean() > 0.5'),
  84. ('sit', '(run_speed < 0.25).mean() > 0.5'),
  85. ('none', 'none')))
  86. rungreen = '#' + ''.join([ format(val, '02x') for val in [66, 168, 117] ])
  87. sitorange = '#' + ''.join([ format(val, '02x') for val in [235, 115, 9] ])
  88. st82clr = {'run':rungreen, 'sit':sitorange, 'none':'black'}
  89. st82alpha = {None:1, 'run':1, 'sit':0.5} # sitting is like suppression, so desaturate colour
  90. st82clrmap = {'run':plt.cm.Greens, 'sit':plt.cm.Oranges, 'none':plt.cm.Greys}
  91. # define example mseus and their colors:
  92. green = '#007f00' # deep green
  93. violet = '#9f3fff' # pure violet (7f00ff) is a little too dark
  94. darkblue = '#000064' # control condition
  95. optoblue = '#2a7fff' # light shade of blue for opto condition
  96. black = '#000000'
  97. asterisk = 5, 2, 0 # 5 sided, asterisk style, 0 deg rotation, has rounded corners though
  98. DEFSZ = 50 # default scatter plot point size (open points)
  99. exmpli2clr = {1:violet, 2:green, 3:optoblue} # example neurons 1, 2, 3
  100. # use different marker shapes to help distinguish example points:
  101. exmpli2mrk = {1:'x', 2:'+', 3:asterisk}
  102. exmpli2sz = {1:60, 2:70, 3:60} # marker size
  103. exmpli2lw = {1:2, 2:2, 3:2} # (marker) line width
  104. # use large closed circles for all example points:
  105. #exmpli2mrk = {1:'.', 2:'.', 3:'.'}
  106. #EXMPLSZ = 70 # example scatter plot point size (closed points)
  107. #exmpli2sz = {1:EXMPLSZ, 2:EXMPLSZ, 3:EXMPLSZ}
  108. fig1exmpli = 1 # fig1 uses example neuron 1
  109. fig5exmpli = 1 # fig5 uses example neuron 1
  110. mvimseu2exmpli = {'PVCre_2018_0003_s03_e03_u51':1, # example neuron 1
  111. #'PVCre_2018_0003_s03_e03_u52',
  112. 'PVCre_2017_0008_s12_e06_u56':2, # example neuron 2
  113. #'PVCre_2017_0008_s09_e07_u22',
  114. }
  115. grtmseu2exmpli = {'PVCre_2018_0003_s03_e02_u51':1, # example neuron 1
  116. #'PVCre_2018_0003_s03_e02_u52',
  117. #'PVCre_2018_0003_s02_e02_u11',
  118. #'PVCre_2018_0003_s02_e02_u38':
  119. #'PVCre_2018_0001_s02_e06_u64':
  120. #'PVCre_2018_0001_s02_e06_u61':
  121. #'PVCre_2018_0001_s02_e06_u64':
  122. #'PVCre_2018_0001_s05_e02_u16':3, # old example neuron 3
  123. 'PVCre_2018_0003_s05_e02_u164':3, # new example neuron 3 (2019-05-31)
  124. }
  125. mvimsu2exmpli = { mseustr2msustr(k):v for k, v in mvimseu2exmpli.items() }
  126. grtmsu2exmpli = { mseustr2msustr(k):v for k, v in grtmseu2exmpli.items() }
  127. msu2exmpli = {**mvimsu2exmpli, **grtmsu2exmpli} # merge the two
  128. burstclr = 'red'
  129. OPTOS = False, True
  130. opto2tuni = {False:0, True:1} # for indexing into tun_params field in Tuning table
  131. opto2fb = {None:'', False:'feedback', True:'suppression'}
  132. opto2clr = {None:darkblue, False:darkblue, True:optoblue}
  133. opto2alpha = {None:1, False:1, True:0.5} # with opto, feedback is removed, so desaturate colour
  134. modmeasures = ['meanrate', 'meanrate02', 'meanrate35', 'blankmeanrate', 'blankcondmeanrate',
  135. 'meanburstratio', 'blankmeanburstratio', 'blankcondmeanburstratio',
  136. 'spars', 'rel', 'meanpkw', 'snr']
  137. modmeasuresnoblankcond = modmeasures.copy()
  138. modmeasuresnoblankcond.remove('blankcondmeanrate')
  139. modmeasuresnoblankcond.remove('blankcondmeanburstratio')
  140. modmeasuresnoblank = modmeasuresnoblankcond.copy()
  141. modmeasuresnoblank.remove('blankmeanrate')
  142. modmeasuresnoblank.remove('blankmeanburstratio')
  143. measure2axislabel = {'meanrate':'FR', 'meanrate02':'FR', 'meanrate35':'FR',
  144. 'blankmeanrate':'FR', 'blankcondmeanrate':'FR',
  145. 'meanburstratio':'burst ratio', 'blankmeanburstratio':'burst ratio',
  146. 'blankcondmeanburstratio':'burst ratio',
  147. 'spars':'sparseness', 'rel':'reliability',
  148. 'meanpkw':'mean peak width', 'snr':'SNR'}
  149. measure2axisunits = {'meanrate':' (spk/s)', 'meanrate02':' (spk/s)', 'meanrate35':' (spk/s)',
  150. 'meanpkw':' (s)'}
  151. short2longaxislabel = {'FR':'Firing rate'}
  152. # restrict FiringPattern to our usual burst definition from Lu et al, 1992:
  153. BURSTCRITERION = ('(fp_dtsilent BETWEEN 0.099 and 0.101) AND '
  154. '(fp_dtburst BETWEEN 0.0039 AND 0.0041)')
  155. fmodes = ['', 'nb', 'b', 'nr'] # (narrow) firing modes
  156. fmode2txt = {'':'all', 'nb':'nonburst', 'b':'burst', 'nr':'nonrand'}
  157. # Tuning table keys used for gratings:
  158. ivs_order = 'grat_orientation, opto1'
  159. tun_model = 'sum_of_gaussians'
  160. model = 'threshlin' # linear or threshlin
  161. relaverage = 'mean' # reliability averaging method: mean or median
  162. DEFAULTFIGURESIZE = 2.8, 2.8 # normal
  163. DUMBBELLFIGSIZE = 3.3, 4.4 # dumbbell plots
  164. RASTERSCALEX = 0.75
  165. OFFSETS = -0.5, 0.75 # wide raster time offsets, sec
  166. OFFSETS02 = 0, -3 # limits to spikes from t=0 to 2 s, assuming 5 sec movies
  167. OFFSETS35 = 3, 0 # limits to spikes from t=3 to 5 s
  168. NTRIALSMVIGRT = 120 # limit movies to first 120 trials for more direct comparison with gratings
  169. PSTHHEIGHT = 3
  170. binw, tres, kernel = 0.02, 0.001, 'gauss' # for calculating PSTHs, sec
  171. psthplottres = 0.010 # sec
  172. ssx = intround(psthplottres / tres) # subsample factor to get the desired plot tres
  173. MINNTRIALSAMPLETHRESH = 10
  174. # ON/OFF/transient classification:
  175. PDT = 0.050 # propagation delay time for movie onset (s)
  176. WINDT = 0.20 # duration separating start of pre and PDT, and PDT and end of post (s)
  177. EXCLTR = 0, 2*PDT # time range around PDT to treat as transient response
  178. assert EXCLTR[0] <= PDT <= EXCLTR[1]
  179. ## Query the user:
  180. msg = ('Constants defined. What next?\n'
  181. '0: calculate from scratch, including fits\n'
  182. '1: calculate from scratch, but load fits from pickle\n'
  183. '2: load from pickles and exit\n'
  184. '3: load from pickles and assume connected\n'
  185. '>> ')
  186. response = int(input(msg))
  187. # load fits from pickle? save heaps of CPU time:
  188. LOADFITSFROMPICKLE = False if response == 0 else True
  189. if response in [2, 3]:
  190. subfolder = EXPTYPE if EXPTYPE != 'pvmvis' else ''
  191. name2val = load_all(subfolder=subfolder)
  192. locals().update(name2val)
  193. if response == 2:
  194. sys.exit()
## Collect series, experiments, units from database:
'''
# all single movie opto experiments in LGN that have been sorted:
nats = ((e & 'e_name LIKE "MAS_%%"' & 'e_name NOT LIKE "%%+%%"' & 'e_optowl = 470')
        & (s & 's_depth > 2500' & 's_depth < 3500') # upper limit excludes TRN series
        & u)
# all pink noise opto experiments in LGN that have been sorted:
pnks = ((e & 'e_name LIKE "MAS_%%PNK%%"' & 'e_optowl = 470')
        & (s & 's_depth > 2500' & 's_depth < 3500')
        & u)
'''
"""
PVCre mice:
* PVCre_2018_0002_s04 - good RFs, OK tuning, excellent opto
** PVCre_2019_0001_s05 - excellent RFs, OK tuning, good opto
    - make sure to exclude e01 and e02 due to a seizure between e02 and e03
** PVCre_2019_0001_s06 - excellent RFs and tuning, OK opto, started by Davide
    - sorted: single unit RFs are quite bad though, excluding
    - this mouse's pupil pointed very high under visible light illumination; in the dark,
      with a dilated pupil, it pointed forward normally. Another good reason to exclude
- PVCre_2019_0001_s07 - excellent RFs and tuning, very weak opto
* PVCre_2019_0001_s08 - excellent RFs and tuning, OK opto. stained with DiI
* PVCre_2019_0002_s05 - excellent RFs and tuning, good opto
* PVCre_2019_0002_s06 - excellent RFs, OK tuning, excellent opto
* PVCre_2019_0002_s07 - excellent RFs, good tuning and opto
** PVCre_2019_0002_s08 - excellent RFs, tuning and opto! stained with DiI
    - sorted, single unit RFs are good
As of 2018-12-13, both PVCre_2018_0010_s02 and s05 are also excluded via updated
's_region' values. Both were intended to be in dLGN, but based on their weird RFs, they may
have somehow missed it.
Ntsr1cre mice:
* Ntsr1Cre_2019_0002_s03: sorted, mov rasters look wrong (bad sorting? asleep mouse?),
  100 trials/cond, no running, but opto effects seem consistent
    - check itracking movies, resort this one
- Ntsr1Cre_2019_0002_s04: no mov
* Ntsr1Cre_2019_0002_s05: sorted but only 3 units, only 2 have mov PSTH peaks, ~ consistent
- Ntsr1Cre_2019_0003: no mov
- Ntsr1Cre_2019_0004 has big cortical hole:
    - Ntsr1Cre_2019_0004_s02: nice mov PSTHs, but no opto effect at all for mov or grt
    - Ntsr1Cre_2019_0004_s03: ??
** Ntsr1Cre_2019_0007_s04 - good
- Ntsr1Cre_2019_0007_s05 - suspicious about whether it's really in LGN, SU results look poor
** Ntsr1Cre_2019_0008_s05 - good
- do older mice just for their grating data, do shell/core and other separation
"""
  240. mvis = get_exps(locals(), EXPTYPE)
  241. mviseries = s & mvis
  242. mviunits = up & mvis
  243. mvispikes = spk & mvis
  244. if EXPTYPE == 'negntsrmvis': # force inclusion of a V1 movie recording
  245. mvieye = EyeIxtract() & [{'m':'Ntsr1Cre_2020_0001', 's':2, 'e':11}, # V1
  246. {'m':'Ntsr1Cre_2020_0001', 's':4, 'e':10}] # dLGN
  247. else:
  248. mvieye = EyeIxtract() & mvis
  249. mvirun = rt & mvis
  250. # fetch unique animals included in data set:
  251. #np.unique(mviseries.fetch('m'))
  252. ## TODO: need to check which units are plausibly in dLGN, not all units on the shank will be!
  253. ## - this is mostly done ahead of time during spike sorting, MUA envelope RF analysis suggests
  254. ## some single units have maxchans that are outside the chan range of significant RFs of
  255. ## the population, though there are some caveats to that - maybe maxchan is too stringent,
  256. ## or maybe it's possible for one or two single units to have visual responses even when
  257. ## the population doesn't show much response to sparse noise
  258. # all grating oritun opto experiments in LGN that have been sorted that occurred during the
  259. # same series as movie experiments:
  260. grts = ((e & 'e_name LIKE "%%oritun%%"' & 'e_optowl IN (465, 470)')
  261. & mviseries
  262. & u)
  263. grtseries = mviseries & grts
  264. # use the same units for gratings as in movies:
  265. #grtunits = mviunits & grts
  266. # use units active during a grating experiment, not necessarily the same units as in movies:
  267. grtunits = up & grts
  268. grtspikes = spk & grts
  269. grttun = tun & grts
  270. spons = ((e & 'e_name LIKE "%%spont%%"' & 'e_optowl IN (465, 470)')
  271. & mviseries
  272. & u)
  273. # get all possible movie mseus, regardless of rate:
  274. mvimseus = []
  275. for mse in mvis.fetch(dj.key):
  276. uids = (up & mse).fetch('u')
  277. for uid in uids:
  278. mseu = mse.copy()
  279. mseu['u'] = uid
  280. mvimseus.append(mseu)
  281. # get all possible grating mseus, regardless of rate:
  282. grtmseus = []
  283. for mse in grts.fetch(dj.key):
  284. uids = (up & mse).fetch('u')
  285. for uid in uids:
  286. mseu = mse.copy()
  287. mseu['u'] = uid
  288. grtmseus.append(mseu)
  289. ## TODO: grtmseu filtering by visual responsiveness should be done here, but would currently
  290. ## result in index errors in the big grtresp loop below, so it's done piecemeal in the loop:
  291. #grtmseus = (grts * up).fetch(dj.key)
  292. #grtmseus = (grts * up) & (vd.Units() & 'n_crit=1' & 'z_crit > 2.5')
  293. # get all possible spontaneous mseus, regardless of rate:
  294. sponmseus = []
  295. for mse in spons.fetch(dj.key):
  296. uids = (up & mse).fetch('u')
  297. for uid in uids:
  298. mseu = mse.copy()
  299. mseu['u'] = uid
  300. sponmseus.append(mseu)
  301. msefmt = '{m}_s{s:02}_e{e:02}'
  302. msufmt = '{m}_s{s:02}_u{u:02}'
  303. mseufmt = '{m}_s{s:02}_e{e:02}_u{u:02}'
  304. mvimseustrs = [ mseufmt.format(**mseu) for mseu in mvimseus ]
  305. grtmseustrs = [ mseufmt.format(**mseu) for mseu in grtmseus ]
  306. sponmseustrs = [ mseufmt.format(**mseu) for mseu in sponmseus ]
  307. mvimsustrs = [ msufmt.format(**unit) for unit in mviunits ]
  308. grtmsustrs = [ msufmt.format(**unit) for unit in grtunits ]
  309. mvigrtmsustrs = list(np.union1d(mvimsustrs, grtmsustrs)) # superset of the two lists
  310. if response == 3:
  311. sys.exit()
  312. ## Calculate:
  313. ## collect movie rasters, meanrates, burst info, PSTHs, sparseness, reliability and PSTH
  314. ## peak info by mseu, movie kind, run state, and opto condition:
  315. levels = [mvimseustrs, MVIKINDS, ALLST8S, OPTOS]
  316. names = ['mseu', 'kind', 'st8', 'opto']
  317. mi = pd.MultiIndex.from_product(levels, names=names)
  318. columns = ['dt', 'optotrange', 'trialis',
  319. 'raster', 'braster', 'nbraster', 'wraster', 'wbraster', 'wnbraster', # rasters
  320. 'rates', 'rate02s', 'rate35s', 'blankrates', # single trial FRs
  321. 'meanrate', 'meanrate02', 'meanrate35', 'blankmeanrate', # trial-averaged FRs
  322. 'burstis', 'wburstis', 'burstratios', 'blankburstratios', # single trial BRs
  323. 'meanburstratio', 'blankmeanburstratio', # trial-averaged BRs
  324. 'psth', 'bpsth', 'nbpsth', 'wpsth', 'wbpsth', 'wnbpsth', 't', 'wt', 'bins', 'wbins',
  325. 'spars', 'rhos', 'rel', 'snr', 'pkts', 'pkws', 'meanpkw']
  326. mviresp = pd.DataFrame(index=mi, columns=columns)
  327. ntotpeaks = 0 # total number of peaks detected
  328. for mvi in mvis:
  329. msestr = msefmt.format(**mvi)
  330. print(msestr)
  331. exp = mvi['e']
  332. moviecond = mov & mvi
  333. kind2stimis = moviecond.moviekind2stimis() # movie kind to stimi mapping
  334. trials = trial & mvi
  335. alltrialis = np.sort(trials.fetch('trial_id')) # trialis corresponding to rows in tranges
  336. #rates = ra & mvi
  337. uids = (up & mvi).fetch('u')
  338. # all trials, in trial ID order:
  339. rasters, tranges, _, opto, _ = (spk & mvi).get_rasters()
  340. wrasters, *_ = (spk & mvi).get_rasters(offsets=OFFSETS) # w = wide, same returned tranges
  341. rasters02, *_ = (spk & mvi).get_rasters(offsets=OFFSETS02) # from 0 to 2 s, same tranges
  342. rasters35, *_ = (spk & mvi).get_rasters(offsets=OFFSETS35) # from 3 to 5 s, same tranges
  343. wtranges = tranges.copy() + OFFSETS
  344. #tranges02 = tranges.copy() + OFFSETS02
  345. #tranges35 = tranges.copy() + OFFSETS35
  346. # make sure trial duration is very consistent:
  347. dts = tranges.ptp(axis=1)
  348. dt = dts.mean() # mean trial duration
  349. dtstd = dts.std()
  350. assert dtstd / dt < 0.005
  351. # make sure opto tranges are consistent relative to trial tranges:
  352. optotrialstimt0s = tranges[opto == True][:, 0]
  353. evtranges = (evtimes & mvi).fetch1('ev_tranges')
  354. alloptotranges = evtranges - np.reshape(optotrialstimt0s, (-1, 1))
  355. optotrange = alloptotranges.mean(axis=0)
  356. optotrangestd = alloptotranges.std(axis=0)
  357. assert (optotrangestd / optotrange < 0.005).all()
  358. # get tranges and rasters for just the period before stimulus onset during which opto
  359. # can be on, i.e. from optotrange[0] (-0.272 s) to 0 s before each stimulus onset:
  360. blankdt = abs(optotrange[0]) # averages to 0.283 s across all experiments
  361. print('mvi blankdt:', blankdt)
  362. blanktranges = np.zeros_like(tranges)
  363. blanktranges[:, 0] = tranges[:, 0] - blankdt
  364. blanktranges[:, 1] = tranges[:, 0]
  365. blankrasters, _, _, _, _ = (spk & mvi).get_rasters(tranges=blanktranges)
  366. mviopto = evcond & mvi & 'ev_chan="opto1"'
  367. mvioptostimis, mvioptovals = mviopto.fetch('stim_id', 'ec_val')
  368. # calculate burst ratios for all units:
  369. firingpatterns = fp & mvi & BURSTCRITERION
  370. brus = firingpatterns.burst_ratio(tranges=tranges)
  371. wbrus = firingpatterns.burst_ratio(tranges=wtranges)
  372. blankbrus = firingpatterns.burst_ratio(tranges=blanktranges)
  373. for kind, kindstimis in kind2stimis.items():
  374. print(kind)
  375. if not kindstimis or kind not in MVIKINDS: # exclude empty stimis and 'clr' kind
  376. continue
  377. kindtrials = (trials & 'stim_id IN %s' % seq2sql(kindstimis))
  378. kindtrialis = kindtrials.fetch('trial_id')
  379. trialiis = np.isin(alltrialis, kindtrialis) # bool with same shape as alltrialis
  380. # calculate narrow and wide PSTH bins for all units:
  381. tmin, tmax = 0, dt
  382. bins = split_tranges([(tmin, tmax)], binw, tres) # narrow
  383. midbins = bins.mean(axis=1) # narrow
  384. wbins = split_tranges([(tmin+OFFSETS[0], tmax+OFFSETS[1])], binw, tres) # wide
  385. wmidbins = wbins.mean(axis=1) # wide
  386. for st8, st8crit in st82crit.items():
  387. if st8crit == 'none': # 'none' is not in the State.Trial table
  388. st8trialis = alltrialis
  389. else:
  390. st8trialis = (st8t & mvi & {'st8_crit':st8crit}).fetch('trial_id')
  391. st8trialis = np.sort(st8trialis)
  392. for opto in OPTOS:
  393. optois = mvioptovals == opto
  394. # limit stimis by opto:
  395. optostimis = mvioptostimis[optois]
  396. # also limit stimis by movie kind:
  397. stimis = np.intersect1d(kindstimis, optostimis)
  398. if len(stimis) == 0:
  399. continue
  400. optotrials = trials & 'stim_id IN %s' % seq2sql(optostimis)
  401. optotrialis = optotrials.fetch('trial_id')
  402. # intersect kind, st8 and opto trialis:
  403. trialis = np.intersect1d(kindtrialis, st8trialis)
  404. trialis = np.intersect1d(trialis, optotrialis)
  405. trialiis = np.isin(alltrialis, trialis) # bool with same shape as alltrialis
  406. if len(trialis) == 0:
  407. print('%s %s %s has no trials, skipping' % (kind, st8, opto))
  408. continue
  409. for uid in uids:
  410. if uid not in rasters:
  411. continue
  412. mseu = {'m':mvi['m'], 's':mvi['s'], 'e':mvi['e'], 'u':uid}
  413. mseustr = mseufmt.format(**mseu)
  414. raster = rasters[uid][trialiis] # subset of raster trials
  415. wraster = wrasters[uid][trialiis] # subset of wraster trials
  416. blankraster = blankrasters[uid][trialiis] # subset of blankraster trials
  417. raster02 = rasters02[uid][trialiis][:NTRIALSMVIGRT]
  418. raster35 = rasters35[uid][trialiis][:NTRIALSMVIGRT]
  419. rates = np.array([ len(trialspikes)/dt for trialspikes in raster ])
  420. rate02s = np.full(rates.shape, np.nan) # init with full shape with NaN
  421. rate35s = np.full(rates.shape, np.nan)
  422. blankrates = np.full(rates.shape, np.nan)
  423. # fill only the first n trials/blanks:
  424. rate02s[:len(raster02)] = np.array([ len(trialspikes)/2 for trialspikes
  425. in raster02 ])
  426. rate35s[:len(raster35)] = np.array([ len(trialspikes)/2 for trialspikes
  427. in raster35 ])
  428. blankrates[:len(blankraster)] = np.array([ len(bspikes)/blankdt for bspikes
  429. in blankraster ])
  430. meanrate = len(np.hstack(raster)) / dt / len(raster)
  431. #assert np.isclose(meanrate, rates.mean()) # sanity check
  432. meanrate02 = len(np.hstack(raster02)) / 2 / len(raster02)
  433. meanrate35 = len(np.hstack(raster35)) / 2 / len(raster35)
  434. blankmeanrate = len(np.hstack(blankraster)) / blankdt / len(blankraster)
  435. if meanrate < RATETHRESH:
  436. print('%s %s %s %s does not meet RATETHRESH, skipping'
  437. % (mseustr, kind, st8, opto))
  438. continue
  439. mviresprow = mviresp.loc[mseustr, kind, st8, opto]
  440. # technically, this is the wrong approach, and could result in writing
  441. # to a copy (an independent Series), which isn't the intention (see https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy)
  442. # but it seems to work because the mviresp DataFrame is being filled one
  443. # row at a time, and wasn't initialized explicitly by setting everything
  444. # to e.g. np.nan in a single vector operation
  445. mviresprow['dt'] = dt
  446. mviresprow['optotrange'] = optotrange
  447. mviresprow['trialis'] = trialis
  448. mviresprow['raster'] = raster # narrow raster
  449. mviresprow['wraster'] = wraster # wide raster, including OFFSETS
  450. mviresprow['rates'] = rates # based on narrow raster
  451. mviresprow['rate02s'] = rate02s
  452. mviresprow['rate35s'] = rate35s
  453. mviresprow['blankrates'] = blankrates
  454. mviresprow['meanrate'] = meanrate # based on narrow raster
  455. mviresprow['meanrate02'] = meanrate02
  456. mviresprow['meanrate35'] = meanrate35
  457. mviresprow['blankmeanrate'] = blankmeanrate
  458. # fill in the corresponding trial burst ratios:
  459. for bru, wbru, blankbru in zip(brus, wbrus, blankbrus):
  460. if bru['u'] == uid:
  461. assert bru['m'] == mseu['m']
  462. assert bru['s'] == mseu['s']
  463. assert bru['e'] == mseu['e']
  464. burstis = np.array(bru['burst_index'], # narrow, ragged array
  465. dtype=object)[trialiis]
  466. wburstis = np.array(wbru['burst_index'], # wide
  467. dtype=object)[trialiis]
  468. #blankburstis = np.array(blankbru['burst_index'],
  469. # dtype=object)[trialiis]
  470. burstratios = np.array(bru['burst_ratio'])[trialiis] # narrow
  471. blankburstratios = np.array(blankbru['burst_ratio'])[trialiis]
  472. mviresprow['burstis'] = burstis # narrow
  473. mviresprow['wburstis'] = wburstis # wide, for plotting
  474. mviresprow['burstratios'] = burstratios # narrow
  475. mviresprow['blankburstratios'] = blankburstratios
  476. mviresprow['meanburstratio'] = np.nanmean(burstratios)
  477. mviresprow['blankmeanburstratio'] = np.nanmean(blankburstratios)
  478. break # out of bru loop
  479. # collect narrow burst and non-burst spike rasters:
  480. braster, nbraster = [], []
  481. for row, bis in zip(raster, burstis): # iterate over trials in raster
  482. if len(bis) == 0:
  483. bis = [] # none
  484. nbis = slice(None) # all
  485. else:
  486. allis = np.arange(len(row)) # all spikes in this trial
  487. nbis = np.setdiff1d(allis, bis)
  488. braster.append(row[bis])
  489. nbraster.append(row[nbis])
  490. # collect wide burst and non-burst spike rasters:
  491. wbraster, wnbraster = [], []
  492. for row, bis in zip(wraster, wburstis): # iterate over trials in wraster
  493. if len(bis) == 0:
  494. bis = [] # none
  495. nbis = slice(None) # all
  496. else:
  497. allis = np.arange(len(row)) # all spikes in this trial
  498. nbis = np.setdiff1d(allis, bis)
  499. wbraster.append(row[bis])
  500. wnbraster.append(row[nbis])
  501. braster = np.array(braster, dtype=object) # ragged array
  502. nbraster = np.array(nbraster, dtype=object)
  503. wbraster = np.array(wbraster, dtype=object)
  504. wnbraster = np.array(wnbraster, dtype=object)
  505. mviresprow['braster'] = braster # narrow burst raster
  506. mviresprow['nbraster'] = nbraster # narrow non-burst raster
  507. mviresprow['wbraster'] = wbraster # wide burst raster
  508. mviresprow['wnbraster'] = wnbraster # wide non-burst raster
  509. # calculate narrow PSTHs:
  510. psth = raster2psth(raster, bins, binw, tres, kernel)
  511. bpsth = raster2psth(braster, bins, binw, tres, kernel)
  512. nbpsth = raster2psth(nbraster, bins, binw, tres, kernel)
  513. # calculate wide PSTHs:
  514. wpsth = raster2psth(wraster, wbins, binw, tres, kernel)
  515. wbpsth = raster2psth(wbraster, wbins, binw, tres, kernel)
  516. wnbpsth = raster2psth(nbraster, wbins, binw, tres, kernel)
  517. mviresprow['psth'] = psth # narrow PSTH
  518. mviresprow['bpsth'] = bpsth # narrow burst PSTH
  519. mviresprow['nbpsth'] = nbpsth # narrow non-burst PSTH
  520. mviresprow['wpsth'] = wpsth # wide PSTH
  521. mviresprow['wbpsth'] = wbpsth # wide burst PSTH
  522. mviresprow['wnbpsth'] = wnbpsth # wide non-burst PSTH
  523. mviresprow['t'] = midbins # narrow PSTH timepoints
  524. mviresprow['wt'] = wmidbins # wide PSTH timepoints
  525. mviresprow['bins'] = bins # narrow PSTH bins
  526. mviresprow['wbins'] = wbins # wide PSTH bins
  527. # calculate PSTH sparseness on narrow PSTH:
  528. spars = sparseness(psth)
  529. mviresprow['spars'] = spars
  530. # collect single trial signals from narrow PSTH, calc response reliability:
  531. trialsignals = []
  532. for trialspiketimes in raster:
  533. trialsignal = raster2psth([trialspiketimes], bins, binw, tres, kernel)
  534. trialsignals.append(trialsignal)
  535. trialsignals = np.vstack(trialsignals)
  536. rel, rhos = reliability(trialsignals, average=relaverage)
  537. if not np.isnan(rel):
  538. assert rel > -0.05 # can come out weakly negative
  539. mviresprow['rhos'] = rhos
  540. mviresprow['rel'] = max(rel, 0) # clamp weakly negative values to 0
  541. mviresprow['snr'] = snr_baden2016(trialsignals)
  542. # detect peaks in narrow PSTH:
  543. baseline = PSTHBASELINEMEDIANX * np.median(psth)
  544. thresh = baseline + PSTHPEAKMINTHRESH # peak detection threshold
  545. spikets = np.sort(np.concatenate(raster)) # flatten raster spike times
  546. peakis, lis, ris = get_psth_peaks_gac(spikets, midbins, psth, thresh,
  547. verbose=VERBOSE)
  548. if VERBOSE:
  549. print() # new line
  550. npeaks = len(peakis)
  551. if npeaks > 0:
  552. mviresprow['pkts'] = peakis * tres # PSTH peak times (s)
  553. pkws = (ris - lis) * tres # PSTH peak widths (s)
  554. mviresprow['pkws'] = pkws
  555. mviresprow['meanpkw'] = pkws.mean() # mean PSTH peak width (s)
  556. ## TODO: use the peakis themselves for some other kind of measure,
  557. ## maybe plot distributions of peaks relative to stimulus start, or
  558. ## interpeak intervals, e.g.
  559. # use absence of PSTH peaks to decide which units/conditions to exclude?
  560. # this would obviously only work for movies, but that's what mviresp
  561. # is all about...
  562. #if npeaks == 0:
  563. # nrnids[state].append(nid) # save nonresponsive nids by state
  564. # continue # this PSTH has no peaks, skip all subsequent measures
  565. ntotpeaks += npeaks
  566. print('Detected %d movie PSTH peaks' % ntotpeaks)
  567. ## collect grating meanrates, rasters, tuning, and burst info by mseu,
  568. ## run state, and opto condition:
  569. levels = [grtmseustrs, ALLST8S, OPTOS]
  570. names = ['mseu', 'st8', 'opto']
  571. mi = pd.MultiIndex.from_product(levels, names=names)
  572. columns = ['dt', 'optotrange', 'trialis',
  573. 'raster', 'braster', 'nbraster', 'wraster', # rasters
  574. 'rates', 'blankrates', 'blankcondrates', # single trial FRs
  575. 'meanrate', 'blankmeanrate', 'blankcondmeanrate', # trial-averaged FRs
  576. 'burstis', 'wburstis', 'burstratios', # single trial BRs
  577. 'blankburstratios', 'blankcondburstratios', # single trial BRs
  578. 'meanburstratio', 'blankmeanburstratio', 'blankcondmeanburstratio', # trial-averaged BRs
  579. 'stimis', 'oris', 'tfreq', 'tun', 'tunrsq']
  580. grtresp = pd.DataFrame(index=mi, columns=columns)
  581. for grt in grts:
  582. msestr = msefmt.format(**grt)
  583. print(msestr)
  584. exp = grt['e']
  585. alltrials = trial & grt # all trials, including blanks
  586. # all stimis and trialis, in stimis order:
  587. allstimis, alltrialis = alltrials.fetch('stim_id', 'trial_id')
  588. # map triali to stimi for all trials:
  589. triali2stimi = { triali:stimi for stimi, triali in zip(allstimis, alltrialis) }
  590. alltrialis = np.sort(alltrialis) # trialis corresponding to rows in tranges
  591. trials = alltrials & (grat & 'grat_contrast != 0.0') # filter out blank trials
  592. blankcondtrials = alltrials & (grat & 'grat_contrast = 0.0') # blank trials only
  593. uids = (up & grt).fetch('u')
  594. # all trials, in trial ID (temporal) order:
  595. rasters, tranges, _, opto, _ = (spk & grt).get_rasters()
  596. wrasters, *_ = (spk & grt).get_rasters(offsets=OFFSETS) # w = wide, same returned tranges
  597. wtranges = tranges.copy() + OFFSETS
  598. # make sure opto tranges are consistent relative to trial tranges:
  599. optotrialstimt0s = tranges[opto == True][:, 0]
  600. evtranges = (evtimes & grt).fetch1('ev_tranges')
  601. alloptotranges = evtranges - np.reshape(optotrialstimt0s, (-1, 1))
  602. optotrange = alloptotranges.mean(axis=0)
  603. optotrangestd = alloptotranges.std(axis=0)
  604. assert (optotrangestd / optotrange < 0.005).all()
  605. # get tranges and rasters for just the period before stimulus onset during which opto
  606. # can be on, i.e. from optotrange[0] to 0 s before each stimulus onset:
  607. blankdt = abs(optotrange[0]) # averages to 0.285 s across all experiments, as in movies
  608. print('grt blankdt:', blankdt)
  609. blanktranges = np.zeros_like(tranges)
  610. blanktranges[:, 0] = tranges[:, 0] - blankdt
  611. blanktranges[:, 1] = tranges[:, 0]
  612. blankrasters, _, _, _, _ = (spk & grt).get_rasters(tranges=blanktranges)
  613. grtopto = evcond & grt & 'ev_chan="opto1"'
  614. grtoptostimis, grtoptovals = grtopto.fetch('stim_id', 'ec_val')
  615. # calculate burst ratios for all units:
  616. firingpatterns = fp & grt & BURSTCRITERION
  617. brus = firingpatterns.burst_ratio(tranges=tranges)
  618. wbrus = firingpatterns.burst_ratio(tranges=wtranges)
  619. blankbrus = firingpatterns.burst_ratio(tranges=blanktranges)
  620. # make sure trial duration is very consistent:
  621. dts = tranges.ptp(axis=1)
  622. dt = dts.mean() # mean trial duration, for all trials (non-blank and blank)
  623. dtstd = dts.std()
  624. assert dtstd / dt < 0.005
  625. # make sure opto tranges are consistent relative to trial tranges:
  626. optotrialstimt0s = tranges[opto == True][:, 0]
  627. evtranges = (evtimes & grt).fetch1('ev_tranges')
  628. alloptotranges = evtranges - np.reshape(optotrialstimt0s, (-1, 1))
  629. optotrange = alloptotranges.mean(axis=0)
  630. optotrangestd = alloptotranges.std(axis=0)
  631. assert (optotrangestd / optotrange < 0.005).all()
  632. # make sure temporal freq is very consistent:
  633. tfreqs = (grat & grt).fetch('grat_temp_freq')
  634. tfreq = tfreqs.mean()
  635. tfreqstd = tfreqs.std()
  636. assert tfreqstd / tfreq < 0.005
  637. for st8, st8crit in st82crit.items():
  638. if st8crit == 'none': # 'none' is not in the State.Trial table
  639. st8trialis = alltrialis
  640. else:
  641. st8trialis = (st8t & grt & {'st8_crit':st8crit}).fetch('trial_id')
  642. st8trialis = np.sort(st8trialis)
  643. for opto in OPTOS:
  644. optois = grtoptovals == opto
  645. # limit stimis by opto:
  646. optostimis = grtoptostimis[optois]
  647. optotrials = trials & 'stim_id IN %s' % seq2sql(optostimis)
  648. blankcondoptotrials = blankcondtrials & 'stim_id IN %s' % seq2sql(optostimis)
  649. optotrialis = optotrials.fetch('trial_id')
  650. blankcondoptotrialis = blankcondoptotrials.fetch('trial_id')
  651. # intersect st8 and opto trialis:
  652. trialis = np.intersect1d(st8trialis, optotrialis)
  653. blankcondtrialis = np.intersect1d(st8trialis, blankcondoptotrialis)
  654. if len(trialis) == 0:
  655. print('%s %s has no trials, skipping' % (st8, opto))
  656. continue
  657. stimis = np.array([ triali2stimi[triali] for triali in trialis ]) # triali order
  658. oristimis, oris = (grat & grt).fetch('stim_id', 'grat_orientation')
  659. oris = np.hstack([ oris[oristimis == stimi] for stimi in stimis ]) # triali order
  660. stimsortis = stimis.argsort(kind='stable')
  661. trialis = trialis[stimsortis] # sorted by stimi
  662. stimis = stimis[stimsortis] # sorted by stimi
  663. oris = oris[stimsortis] # sorted by stimi
  664. for uid in uids:
  665. if uid not in rasters:
  666. continue
  667. mseu = {'m':grt['m'], 's':grt['s'], 'e':grt['e'], 'u':uid}
  668. mseustr = mseufmt.format(**mseu)
  669. ## TODO: this should really be done once, where grtmseus are defined,
  670. ## but I don't want to break things right now (2019-07-15)
  671. if not (vd.Units() & mseu & ('n_crit=%d' % VDNCRIT)
  672. & ('z_crit > %f' % VDZCRIT)):
  673. # not a visually responsive mseu (this query ignores st8 and opto)
  674. print('%s is not visually responsive, skipping' % mseustr)
  675. continue
  676. raster = rasters[uid][trialis] # subset of trials, sorted by stimi
  677. wraster = wrasters[uid][trialis] # subset of trials, sorted by stimi
  678. blankraster = blankrasters[uid][trialis] # subset of trials, sorted by stimi
  679. blankcondraster = rasters[uid][blankcondtrialis] # subset of blank condition trials
  680. rates = np.array([ len(trialspikes)/dt for trialspikes in raster ])
  681. blankrates = np.array([ len(bspikes)/blankdt for bspikes in blankraster ])
  682. assert len(blankrates) == len(rates) # should be same shape as rates
  683. blankcondrates = np.array([ len(bcspikes)/dt for bcspikes in blankcondraster ])
  684. meanrate = len(np.hstack(raster)) / dt / len(raster)
  685. #assert np.isclose(meanrate, rates.mean()) # sanity check
  686. blankmeanrate = len(np.hstack(blankraster)) / blankdt / len(blankraster)
  687. if len(blankcondraster) == 0:
  688. blankcondmeanrate = np.nan
  689. else:
  690. blankcondmeanrate = len(np.hstack(blankcondraster)) / dt / len(blankcondraster)
  691. if meanrate < RATETHRESH:
  692. print('%s %s %s does not meet RATETHRESH, skipping'
  693. % (mseustr, st8, opto))
  694. continue
  695. grtresprow = grtresp.loc[mseustr, st8, opto]
  696. grtresprow['dt'] = dt
  697. grtresprow['optotrange'] = optotrange
  698. # overload 'trialis' name in df - includes non-blank trials and blank condition
  699. # trials, in that order:
  700. grtresprow['trialis'] = np.concatenate([trialis, blankcondtrialis])
  701. grtresprow['raster'] = raster # sorted by stimi
  702. grtresprow['wraster'] = wraster
  703. grtresprow['rates'] = rates
  704. grtresprow['blankrates'] = blankrates
  705. grtresprow['blankcondrates'] = blankcondrates
  706. grtresprow['meanrate'] = meanrate
  707. grtresprow['blankmeanrate'] = blankmeanrate
  708. grtresprow['blankcondmeanrate'] = blankcondmeanrate
  709. grtresprow['stimis'] = stimis # one entry per row in raster
  710. grtresprow['oris'] = oris # one entry per row in raster
  711. grtresprow['rates'] = rates
  712. grtresprow['tfreq'] = tfreq
  713. tunrow = (tun & mseu & {'ivs_order':ivs_order, 'tun_model':tun_model,
  714. 'st8_crit':st8crit})
  715. if tunrow: # not empty
  716. tunparams = tunrow.fetch1('tun_pars') # 2x5 array of tuning params
  717. tunrsq = tunrow.fetch1('tun_rsq') # len 2 array of R2 values
  718. tunspon = tunrow.fetch1('tun_spon_mean') # len 2 array of spon activity
  719. oti = opto2tuni[opto] # opto Tuning index
  720. params = list(tunparams[oti]) + [(tunspon[oti])]
  721. grtresprow['tun'] = params # len 6 array: 5 tuning params + spon
  722. """[dp, rp, rn, r0, sigma, spon]:
  723. oripref, gauss ampl @ oripref, gauss ampl @ antipref, gauss offset,
  724. gauss width, spon level
  725. so index 1 wrt 3 and 5 is a measure of ori tuning strength"""
  726. grtresprow['tunrsq'] = tunrsq[oti]
  727. else:
  728. print('WARNING: empty Tuning row for %s, %s' % (mseu, st8crit))
  729. # fill in the corresponding trial burst ratios:
  730. for bru, wbru, blankbru in zip(brus, wbrus, blankbrus):
  731. if bru['u'] == uid:
  732. assert bru['m'] == mseu['m']
  733. assert bru['s'] == mseu['s']
  734. assert bru['e'] == mseu['e']
  735. burstis = np.array(bru['burst_index'], # narrow, ragged array
  736. dtype=object)[trialis]
  737. wburstis = np.array(wbru['burst_index'], # wide
  738. dtype=object)[trialis]
  739. burstratios = np.array(bru['burst_ratio'])[trialis] # narrow
  740. blankburstratios = np.array(blankbru['burst_ratio'])[trialis]
  741. blankcondburstratios = np.array(bru['burst_ratio'])[blankcondtrialis]
  742. grtresprow['burstis'] = burstis # narrow
  743. grtresprow['wburstis'] = wburstis # wide, for plotting
  744. grtresprow['burstratios'] = burstratios # narrow
  745. grtresprow['blankburstratios'] = blankburstratios # narrow
  746. grtresprow['blankcondburstratios'] = blankcondburstratios # narrow
  747. # nanmean can raise "Mean of empty slice" warnings if argument
  748. # is empty or is all nan, but that's OK, returns a nan in both cases:
  749. grtresprow['meanburstratio'] = np.nanmean(burstratios)
  750. grtresprow['blankmeanburstratio'] = np.nanmean(blankburstratios)
  751. if len(blankcondburstratios) == 0:
  752. grtresprow['blankcondmeanburstratio'] = np.nan
  753. else:
  754. grtresprow['blankcondmeanburstratio'] = np.nanmean(blankcondburstratios)
  755. break # out of bru loop
  756. # collect narrow burst and non-burst spike rasters:
  757. braster, nbraster = [], []
  758. for row, bis in zip(raster, burstis): # iterate over trials in raster
  759. if len(bis) == 0:
  760. bis = [] # none
  761. nbis = slice(None) # all
  762. else:
  763. allis = np.arange(len(row)) # all spikes in this trial
  764. nbis = np.setdiff1d(allis, bis)
  765. braster.append(row[bis])
  766. nbraster.append(row[nbis])
  767. braster = np.array(braster, dtype=object) # ragged array
  768. nbraster = np.array(nbraster, dtype=object)
  769. grtresprow['braster'] = braster # narrow burst raster
  770. grtresprow['nbraster'] = nbraster # narrow non-burst raster
  771. ## Generate grating ori tuning dataframe for saving to a .pickle, and for convenient
  772. ## offline tuning plots:
  773. grttundf = grttun.fetch(format='frame')
  774. grttuninfo = (TuningInfo() & grttun).fetch(format='frame') # contains stimulus oris
  775. # convert separate m, s, e, u columns in grttundf to a single mseu column in a new temporary
  776. # df without a MultiIndex:
  777. grttunresprows = []
  778. for rowindex, rowvalues in grttundf.iterrows():
  779. mid, sid, eid, ivs_order, tun_model, st8_type, st8_crit, uid = rowindex
  780. mean, sem, spon_mean, pars, rsq = (rowvalues['tun_mean'],
  781. rowvalues['tun_sem'],
  782. rowvalues['tun_spon_mean'],
  783. rowvalues['tun_pars'],
  784. rowvalues['tun_rsq'])
  785. mseustr = key2mseustr({'m':mid, 's':sid, 'e':eid, 'u':uid})
  786. ori = grttuninfo.loc[mid, sid, eid, ivs_order]['ti_axes'][:, 0]
  787. newrow = {'mseu':mseustr, 'ivs_order':ivs_order, 'tun_model':tun_model,
  788. 'st8_type':st8_type, 'st8_crit':st8_crit,
  789. 'ori':ori, 'mean':mean, 'sem':sem, 'spon_mean':spon_mean,
  790. 'pars':pars, 'rsq':rsq}
  791. grttunresprows.append(newrow)
  792. newindex = ['mseu', 'ivs_order', 'tun_model', 'st8_type', 'st8_crit']
  793. grttunresp = pd.DataFrame(grttunresprows).set_index(newindex)
  794. ## collect spontaneous meanrates and meanburstratios by mseu and opto cond,
  795. ## ignore run st8 for now:
  796. levels = [sponmseustrs, OPTOS]
  797. names = ['mseu', 'opto']
  798. mi = pd.MultiIndex.from_product(levels, names=names)
  799. columns = ['meanrate', 'meanburstratio']
  800. sponresp = pd.DataFrame(index=mi, columns=columns)
  801. for spon in spons:
  802. msestr = msefmt.format(**spon)
  803. print(msestr)
  804. #exp = spon['e']
  805. optotranges = (evtimes & spon).fetch1('ev_tranges')
  806. # ensure that intervals between end and start of consecutive opto pulses are at least
  807. # as long as the pulse length:
  808. optodts = np.diff(optotranges, axis=1) # opto pulse durations (sec)
  809. maxoptodt = optodts.max() # sec
  810. isitranges = np.vstack([optotranges[:-1, 0], optotranges[1:, 1]]).T # shape: npulses-1, 2
  811. assert np.diff(isitranges, axis=1).min() > maxoptodt # sec
  812. # use one opto pulse duration before each pulse as the control trange for that pulse:
  813. ctrltranges = optotranges - optodts
  814. ctrldts = np.diff(ctrltranges, axis=1) # control trange durations (sec)
  815. spkspon = spk & spon
  816. fpspon = fp & spon & BURSTCRITERION
  817. optorasters, _, _, optovals, _ = spkspon.get_rasters(optotranges, offsets=[0, 0])
  818. ctrlrasters, _, _, ctrlvals, _ = spkspon.get_rasters(ctrltranges, offsets=[0, 0])
  819. optobrus = fpspon.burst_ratio(tranges=optotranges)
  820. ctrlbrus = fpspon.burst_ratio(tranges=ctrltranges)
  821. assert optovals.all() == True
  822. assert ctrlvals.all() == False
  823. assert list(optorasters) == list(ctrlrasters) # should have exact same set of units
  824. uids = list(optorasters)
  825. opto2rasters = {False:ctrlrasters, True:optorasters}
  826. opto2brus = {False:ctrlbrus, True:optobrus}
  827. opto2totaldt = {False:ctrldts.sum(), True:optodts.sum()}
  828. for opto in OPTOS:
  829. rasters = opto2rasters[opto]
  830. brus = opto2brus[opto]
  831. totaldt = opto2totaldt[opto] # sec
  832. for uid, bru in zip(uids, brus):
  833. trialspikets = np.concatenate(rasters[uid]) # concatenate spikes from all "trials"
  834. mseu = {'m':spon['m'], 's':spon['s'], 'e':spon['e'], 'u':uid}
  835. mseustr = mseufmt.format(**mseu)
  836. sponresprow = sponresp.loc[mseustr, opto]
  837. nspikes = len(trialspikets)
  838. meanrate = nspikes / totaldt
  839. sponresprow['meanrate'] = meanrate
  840. assert bru['u'] == uid # sanity check
  841. burstratios = bru['burst_ratio'] # one per "trial"
  842. meanburstratio = np.nanmean(burstratios)
  843. sponresprow['meanburstratio'] = meanburstratio
  844. ## find best movie and grating response parameters per unit, across experiments:
  845. mvilevels = [mvimsustrs, MVIKINDS, ALLST8S, OPTOS]
  846. grtlevels = [grtmsustrs, ALLST8S, OPTOS]
  847. mvinames = ['msu', 'kind', 'st8', 'opto']
  848. grtnames = ['msu', 'st8', 'opto']
  849. mvimi = pd.MultiIndex.from_product(mvilevels, names=mvinames)
  850. grtmi = pd.MultiIndex.from_product(grtlevels, names=grtnames)
  851. bestmvicolumns = modmeasuresnoblankcond
  852. bestgrtcolumns = ['meanrate', 'blankmeanrate', 'meanburstratio', 'blankmeanburstratio']
  853. bestmviresp = pd.DataFrame(index=mvimi, columns=bestmvicolumns)
  854. bestgrtresp = pd.DataFrame(index=grtmi, columns=bestgrtcolumns)
  855. stimtype2resp = {'mvi':mviresp, 'grt':grtresp}
  856. stimtype2bestresp = {'mvi':bestmviresp, 'grt':bestgrtresp}
  857. for stimtype in STIMTYPES:
  858. resp, bestresp = stimtype2resp[stimtype], stimtype2bestresp[stimtype]
  859. columns = bestresp.columns
  860. for index, row in resp.iterrows():
  861. mseu = index[0]
  862. msu = mseustr2msustr(mseu)
  863. bestindex = msu, *index[1:]
  864. for column in columns:
  865. newval = row[column]
  866. oldbestval = bestresp[column][bestindex]
  867. # the biggest value isn't always the best, e.g. lower meanpkw values are better:
  868. if column in ['meanpkw']:
  869. if pd.isna(oldbestval) or (newval < oldbestval):
  870. bestresp[column][bestindex] = newval # overwrite
  871. else:
  872. if pd.isna(oldbestval) or (newval > oldbestval):
  873. bestresp[column][bestindex] = newval # overwrite
  874. ## collect movie model fits by kind, state and opto:
  875. #levels = [mvimseustrs, MVIKINDS, ALLST8S]
  876. levels = [mvimseustrs, ['nat'], ['none']]
  877. names = ['mseu', 'kind', 'st8']
  878. mi = pd.MultiIndex.from_product(levels, names=names)
  879. fitcolumns = ['m', 'th', 'rsq', 'nbm', 'nbth', 'nbrsq', 'bm', 'bth', 'brsq']
  880. sampfitcolumns = ['ms', 'ths', 'rsqs', 'nbms', 'nbths', 'nbrsqs', 'bms', 'bths', 'brsqs',
  881. 'nrms', 'nrths', 'nrrsqs']
  882. fits = pd.DataFrame(index=mi, columns=fitcolumns)
  883. sampfits = pd.DataFrame(index=mi, columns=sampfitcolumns)
  884. nsamples = 1000
  885. np.random.seed(0) # fix random seed for identical results from np.random.choice() on each run
  886. if LOADFITSFROMPICKLE:
  887. print('Loading fits')
  888. fits = load('fits')
  889. sampfits = load('sampfits')
  890. nsamples = len(sampfits[sampfits['ms'].notnull()]['ms'].iloc[0]) # convoluted, but works
  891. mseustrfititerator = []
  892. else:
  893. print('Calculating fits')
  894. mseustrfititerator = mvimseustrs
  895. for mseustr in mseustrfititerator:
  896. for kind in ['nat']: #MVIKINDS:
  897. for st8 in ['none']: #ALLST8S:
  898. print(mseustr, kind, st8)
  899. psth = mviresp.loc[mseustr, kind, st8]['psth']
  900. if psth.isna().any(): # test both ctrl and opto
  901. continue
  902. # init sampfits lists:
  903. sampfits.loc[mseustr, kind, st8] = [], [], [], [], [], [], [], [], [], [], [], []
  904. # iterate over (narrow) firing modes (all, non-burst, burst, non-rand):
  905. for fmode in fmodes:
  906. # first do single fit of all relevant trials, fit and test data are
  907. # identical (no trial resampling):
  908. if fmode != 'nr': # don't bother filling nr column in the fits array
  909. psth = mviresp.loc[mseustr, kind, st8][fmode+'psth']
  910. ctrlpsth, optopsth = psth[False], psth[True]
  911. mm, b, rsq = fitmodel(ctrlpsth, optopsth, ctrlpsth, optopsth,
  912. model=model)
  913. th = -b / mm # threshold, i.e., x intercept
  914. fits.loc[mseustr, kind, st8][fmode+'m'] = mm
  915. fits.loc[mseustr, kind, st8][fmode+'th'] = th
  916. fits.loc[mseustr, kind, st8][fmode+'rsq'] = rsq
  917. # now sample trials multiple times, separate into fit and test data:
  918. if fmode != 'nr':
  919. ctrlraster = mviresp.loc[mseustr, kind, st8, False][fmode+'raster']
  920. optoraster = mviresp.loc[mseustr, kind, st8, True][fmode+'raster']
  921. else: # fmode == 'nr'
  922. # find number of burst spikes per trial and condition,
  923. # remove that same number randomly from each trial in each condition:
  924. ctrlraster, optoraster = [], []
  925. ctrlbraster = mviresp.loc[mseustr, kind, st8, False]['braster']
  926. optobraster = mviresp.loc[mseustr, kind, st8, True]['braster']
  927. # find number of burst spikes per trial in both conditions:
  928. bctrlcounts = [ len(ctrlbtrial) for ctrlbtrial in ctrlbraster ]
  929. boptocounts = [ len(optobtrial) for optobtrial in optobraster ]
  930. # build nr rasters from full rasters, trial by trial:
  931. ctrlfullraster = mviresp.loc[mseustr, kind, st8, False]['raster']
  932. optofullraster = mviresp.loc[mseustr, kind, st8, True]['raster']
  933. for bctrlcount, ctrltrial in zip(bctrlcounts, ctrlfullraster):
  934. nspikes = len(ctrltrial) - bctrlcount # num spikes to sample
  935. resampctrltrial = choice(ctrltrial, size=nspikes, replace=False)
  936. ctrlraster.append(resampctrltrial)
  937. for boptocount, optotrial in zip(boptocounts, optofullraster):
  938. nspikes = len(optotrial) - boptocount # num spikes to sample
  939. resampoptotrial = choice(optotrial, size=nspikes, replace=False)
  940. optoraster.append(resampoptotrial)
  941. ctrlraster, optoraster = np.array(ctrlraster), np.array(optoraster)
  942. nctrltrials = len(ctrlraster) # typically 200 trials when st8 == 'none'
  943. noptotrials = len(optoraster) # typically 200 trials when st8 == 'none'
  944. for ntrials in [nctrltrials, noptotrials]:
  945. if ntrials < MINNTRIALSAMPLETHRESH:
  946. print('Not enough trials to sample:', mseustr, kind, st8, fmode)
  947. continue # don't bother sampling, leave entry in sampfits empty
  948. ctrltrialis = np.arange(nctrltrials)
  949. optotrialis = np.arange(noptotrials)
  950. ctrlsamplesize = intround(nctrltrials / 2) # half for fitting, half for testing
  951. optosamplesize = intround(noptotrials / 2) # half for fitting, half for testing
  952. # probably doesn't matter if use ctrl or opto bins, should be the same:
  953. bins = mviresp.loc[mseustr, kind, st8, False]['bins']
  954. for samplei in range(nsamples):
  955. # randomly sample half the ctrl and opto trials, without replacement:
  956. ctrlfitis = np.sort(choice(ctrltrialis, size=ctrlsamplesize, replace=False))
  957. optofitis = np.sort(choice(optotrialis, size=optosamplesize, replace=False))
  958. ctrltestis = np.setdiff1d(ctrltrialis, ctrlfitis) # get the complement
  959. optotestis = np.setdiff1d(optotrialis, optofitis) # get the complement
  960. ctrlfitraster, ctrltestraster = ctrlraster[ctrlfitis], ctrlraster[ctrltestis]
  961. optofitraster, optotestraster = optoraster[optofitis], optoraster[optotestis]
  962. # calculate fit and test opto PSTHs, subsampled in time:
  963. ctrlfitpsth = raster2psth(ctrlfitraster, bins, binw, tres, kernel)[::ssx]
  964. optofitpsth = raster2psth(optofitraster, bins, binw, tres, kernel)[::ssx]
  965. ctrltestpsth = raster2psth(ctrltestraster, bins, binw, tres, kernel)[::ssx]
  966. optotestpsth = raster2psth(optotestraster, bins, binw, tres, kernel)[::ssx]
  967. #if np.isnan(ctrlfitpsth).any() or np.isnan(optofitpsth).any():
  968. # continue
  969. #assert not (np.isnan(ctrltestpsth).any() or np.isnan(optotestpsth).any())
  970. mm, b, rsq = fitmodel(ctrlfitpsth, optofitpsth, ctrltestpsth, optotestpsth,
  971. model=model)
  972. #if np.isnan(rsq) or rsq < RSQTHRESH: # poor fit
  973. # continue # skip this sample
  974. th = -b / mm # threshold, i.e., x intercept
  975. sampfits.loc[mseustr, kind, st8][fmode+'ms'].append(mm)
  976. sampfits.loc[mseustr, kind, st8][fmode+'ths'].append(th)
  977. sampfits.loc[mseustr, kind, st8][fmode+'rsqs'].append(rsq)
  978. ## FMI (feedback modulation index: (feedback - suppression) / (feedback + suppression)) and
  979. ## RMI (run modulation index: (run - sit) / (run + sit)) for each mseu,
  980. ## for nat mvi and grt experiments, for all measures.
  981. # set up mviFMI DataFrame:
  982. levels = [mvimseustrs, ALLST8S]
  983. names = ['mseu', 'st8']
  984. mi = pd.MultiIndex.from_product(levels, names=names)
  985. mviFMI = pd.DataFrame(index=mi, columns=modmeasures)
  986. # set up grtFMI DataFrame:
  987. levels = [grtmseustrs, ALLST8S]
  988. names = ['mseu', 'st8']
  989. mi = pd.MultiIndex.from_product(levels, names=names)
  990. grtFMI = pd.DataFrame(index=mi, columns=modmeasures) # lots of unused columns
  991. # set up mviRMI DataFrame:
  992. levels = [mvimseustrs, OPTOS]
  993. names = ['mseu', 'opto']
  994. mi = pd.MultiIndex.from_product(levels, names=names)
  995. mviRMI = pd.DataFrame(index=mi, columns=modmeasures)
  996. # set up grtRMI DataFrame:
  997. levels = [grtmseustrs, OPTOS]
  998. names = ['mseu', 'opto']
  999. mi = pd.MultiIndex.from_product(levels, names=names)
  1000. grtRMI = pd.DataFrame(index=mi, columns=modmeasures)
  1001. stimtype2FMI = {'mvi': mviFMI, 'grt': grtFMI}
  1002. stimtype2RMI = {'mvi': mviRMI, 'grt': grtRMI}
  1003. for stimtype in STIMTYPES:
  1004. if stimtype == 'mvi':
  1005. resp = mviresp.xs('nat', level='kind') # keep only nat movie measures
  1006. else:
  1007. resp = grtresp
  1008. fmidf = stimtype2FMI[stimtype]
  1009. rmidf = stimtype2RMI[stimtype]
  1010. for st8 in ALLST8S:
  1011. for measure in modmeasures:
  1012. if measure not in resp:
  1013. continue
  1014. measurevals = resp[measure][(slice(None), st8)] # mseustr, opto
  1015. feedback = measurevals[(slice(None), False)] # mseustr
  1016. suppress = measurevals[(slice(None), True)] # mseustr
  1017. assert (feedback.dropna() >= 0).all() # negative values violate -1 <= FMI <= 1
  1018. assert (suppress.dropna() >= 0).all()
  1019. # filter out cases where measureval in both conditions is 0:
  1020. keepis = (feedback != 0) | (suppress != 0)
  1021. if stimtype == 'mvi': # also filter by SNRTHRESH:
  1022. maxsnrs = resp['snr'][(slice(None), st8)].max(level='mseu') # max SNR per mseu
  1023. keepis = keepis & (maxsnrs >= SNRTHRESH)
  1024. feedback, suppress = feedback[keepis], suppress[keepis]
  1025. fmis = (feedback - suppress) / (feedback + suppress) # indexed by mseustr
  1026. for mseustr, fmi in zip(fmis.index.values, fmis):
  1027. fmidf.loc[mseustr, st8][measure] = fmi # fill in the appropriate rows
  1028. for opto in OPTOS:
  1029. for measure in modmeasures:
  1030. if measure not in resp:
  1031. continue
  1032. measurevals = resp[measure][(slice(None), slice(None), opto)] # mseustr, st8
  1033. runvals = measurevals[(slice(None), 'run')] # mseustr
  1034. sitvals = measurevals[(slice(None), 'sit')] # mseustr
  1035. assert (runvals.dropna() >= 0).all() # negative values violate -1 <= RMI <= 1
  1036. assert (sitvals.dropna() >= 0).all()
  1037. # filter out cases where measureval in both conditions is 0:
  1038. keepis = (runvals != 0) | (sitvals != 0)
  1039. if stimtype == 'mvi': # also filter by SNRTHRESH:
  1040. maxsnrs = resp['snr'][(slice(None), slice(None), opto)].max(level='mseu')
  1041. keepis = keepis & (maxsnrs >= SNRTHRESH)
  1042. runvals, sitvals = runvals[keepis], sitvals[keepis]
  1043. rmis = (runvals - sitvals) / (runvals + sitvals) # indexed by mseustr
  1044. for mseustr, rmi in zip(rmis.index.values, rmis):
  1045. rmidf.loc[mseustr, opto][measure] = rmi # fill in the appropriate rows
  ## maxFMI and maxRMI for each unit (msu), for nat mvi and grt experiments, for all measures.
## For each stimtype+st8+measure+unit combination, take the abs max across experiments. Also
## store the argmax (i.e., the mseu). Among other things, this is for comparison
## of movie and grating responses.


def _update_absmax(maxdf, midf, conds, stimtype):
    """Fold one stimtype's per-experiment modulation indices into `maxdf`, keeping the abs max.

    maxdf : DataFrame indexed by (msustr, cond, stimtype) with columns ['mseu'] + modmeasures,
        updated in place
    midf : DataFrame of modulation indices for one stimtype, rows indexed by (mseustr, cond),
        one column per measure in `modmeasures`
    conds : values of midf's 2nd index level to visit (ALLST8S for FMI, OPTOS for RMI)
    stimtype : value of maxdf's 3rd index level to write into

    For each (cond, measure, mseustr) with a non-NaN value, the matching maxdf cell is
    overwritten if it is still NaN or if the new absolute value is strictly greater (so
    ties keep the first experiment visited), and the winning mseustr is written into that
    row's 'mseu' column.
    """
    # NOTE(review): 'mseu' is a single column shared by all measures, so after the loops it
    # holds the argmax of whichever measure last updated that row, not a per-measure argmax.
    # Confirm this is what downstream code expects.
    mseustrs = midf.index.levels[0].values # loop-invariant, hoisted
    for cond in conds:
        for measure in modmeasures:
            for mseustr in mseustrs:
                modi = midf.loc[(mseustr, cond), measure]
                if pd.isna(modi):
                    continue
                maxkey = (mseustr2msustr(mseustr), cond, stimtype)
                oldmodi = maxdf.loc[maxkey, measure] # check existing val
                if pd.isna(oldmodi) or (abs(modi) > abs(oldmodi)):
                    # Write with a single .loc[row, col] call. The chained form
                    # `maxdf.loc[row][col] = val` only reaches maxdf while .loc[row] happens
                    # to return a view, and silently writes into a temporary copy under
                    # pandas Copy-on-Write:
                    maxdf.loc[maxkey, measure] = modi # overwrite
                    maxdf.loc[maxkey, 'mseu'] = mseustr # overwrite


# set up maxFMI DataFrame:
levels = [mvigrtmsustrs, ALLST8S, STIMTYPES]
names = ['msu', 'st8', 'stimtype']
mi = pd.MultiIndex.from_product(levels, names=names)
columns = ['mseu'] + modmeasures
maxFMI = pd.DataFrame(index=mi, columns=columns)
# set up maxRMI DataFrame:
levels = [mvigrtmsustrs, OPTOS, STIMTYPES]
names = ['msu', 'opto', 'stimtype']
mi = pd.MultiIndex.from_product(levels, names=names)
columns = ['mseu'] + modmeasures
maxRMI = pd.DataFrame(index=mi, columns=columns)
for stimtype in STIMTYPES:
    # FMI: run-state modulation, conditions are st8s; RMI: opto modulation, conditions are optos
    _update_absmax(maxFMI, stimtype2FMI[stimtype], ALLST8S, stimtype)
    _update_absmax(maxRMI, stimtype2RMI[stimtype], OPTOS, stimtype)
  ## calculate eye position stdev and pupil area during all 3 stimtypes,
## as a function of run state and opto conditions, store in ipos_st8 and ipos_opto DataFrames:
# Requires huxley file system to be mounted. Author: Davide Crombie
# parameters for cleaning up the pupil displacement data:
min0t = 1 # only interpolate over stretches missing data shorter than min0t (s)
min1t = 1 # only keep stretches of continuous data longer than min1t (s)
min1prop = .8 # only keep stretches of data with at least 80% non-NaN
st8_rows = [] # to be filled with dictionaries containing relevant data
opto_rows = []
print('Calculating ipos_st8 and ipos_opto DataFrames...')
for stimtype, stims in zip(['mvi', 'grt', 'spon'], [mvis, grts, spons]):
    print('...for stimtype %r' % stimtype)
    stimeye = EyeIxtract() & stims
    for key in stimeye.get_keys(): # these are mse keys
        if stimtype == 'mvi':
            moviecond = mov & key
            kind2stimis = moviecond.moviekind2stimis() # movie kind to stimi mapping
            # stim_id values considered in this analysis:
            kindstimis = tuple(kind2stimis['nat'])
            if len(kindstimis) == 0:
                continue # wrong kind of movie, e.g. PNK or SHF
        msestr = "%s_s%02d_e%02d" % (key['m'], key['s'], key['e'])
        print(msestr)
        # fetch and clean the eye position data:
        irow = EyeIxtract() & key
        ts = irow.fetch1('i_frame_times')
        eyedt = np.diff(ts).mean() # eye sampling period
        # x & y pupil positions relative to baseline (median) in deg. visual angle,
        # after subtraction of corneal reflection position:
        xpos, ypos = irow.fetch_position(cr_sub=True, deg=True) # in deg
        area = irow.fetch_area(mm=False, normalize=True) # in pix
        # position timecourses after elimination of unclean stretches, smoothing,
        # and interpolation of remaining NaN values within clean stretches:
        min0len, min1len = min0t/eyedt, min1t/eyedt # convert from s to n timepoints
        xpos_clean = i_clean(xpos, min0len=min0len, min1len=min1len, min1prop=min1prop)
        ypos_clean = i_clean(ypos, min0len=min0len, min1len=min1len, min1prop=min1prop)
        area_clean = i_clean(area, min0len=min0len, min1len=min1len, min1prop=min1prop)
        conditions = {}
        # get keys for trial table restriction based on st8:
        conditions['run'] = (st8t & key & {'st8_crit':st82crit['run']}).get_keys()
        conditions['sit'] = (st8t & key & {'st8_crit':st82crit['sit']}).get_keys()
        if len(conditions['run']) == 0: # perhaps run data wasn't acquired for this exp
            assert len(conditions['sit']) == 0 # sanity check
            del conditions['run'] # del empty list
            del conditions['sit'] # del empty list
        # get keys for trial table restriction based on opto condition:
        optochan = evcond & key & 'ev_chan="opto1"'
        if stimtype == 'mvi':
            if len(kindstimis) > 1: # only include opto conditions corresponding to nat movie
                # Build the SQL IN-list from plain ints. Formatting the tuple itself uses
                # repr() of its elements, which for NumPy >= 2 integer scalars (if that is
                # what moviekind2stimis() returns) gives e.g. 'np.int64(3)', i.e. invalid SQL:
                instimis = 'stim_id IN (%s)' % ', '.join(str(int(stimi)) for stimi in kindstimis)
                conditions[False] = (optochan & 'ec_val=0' & instimis).get_keys()
                conditions[True] = (optochan & 'ec_val=1' & instimis).get_keys()
        else:
            conditions[False] = (optochan & 'ec_val=0').get_keys()
            conditions[True] = (optochan & 'ec_val=1').get_keys()
        # compute eye position variance for all conditions:
        for condition in conditions: # iterate over all existing keys in conditions
            row = {}
            row['mse'] = msestr
            row['stimtype'] = stimtype
            # start and stop times for trials of the current state only:
            ## TODO: spons don't have trials, just periods of opto and non-opto
            trialis, t0s, t1s = (trial & conditions[condition]).fetch('trial_id',
                                                                     'trial_on_time',
                                                                     'trial_off_time')
            if len(t0s) == 0 or len(t1s) == 0:
                continue
            # sort pupil displacement data into trial matrices:
            xpos_trialmat, _ = get_trialmat(xpos_clean, ts, t0s, t1s)
            ypos_trialmat, _ = get_trialmat(ypos_clean, ts, t0s, t1s)
            area_trialmat, area_trialts = get_trialmat(area_clean, ts, t0s, t1s)
            area_trialmean = np.nanmean(area_trialmat, axis=1) # mean pupil area of each trial
            # save trial IDs:
            row['trialis'] = trialis
            # save mean pupil area for each trial:
            row['area_trialmean'] = area_trialmean
            # save pupil area signal for each trial:
            row['area_trialmat'] = area_trialmat
            row['area_trialts'] = area_trialts
            # save mean pupil area across trials:
            row['mean_area'] = np.nanmean(area_trialmean)
            # save eye position variability across trials.
            # NOTE(review): this uses .mean() (not nanmean, unlike the within-trial measure
            # below), so a single timepoint that is NaN in every trial makes the result NaN;
            # confirm that's intended:
            row['std_xpos_cross'] = np.nanstd(xpos_trialmat, axis=0).mean()
            row['std_ypos_cross'] = np.nanstd(ypos_trialmat, axis=0).mean()
            # save eye position variability within trials:
            row['std_xpos_within'] = np.nanmean(np.nanstd(xpos_trialmat, axis=1))
            row['std_ypos_within'] = np.nanmean(np.nanstd(ypos_trialmat, axis=1))
            # add dictionary to appropriate list:
            if condition in ['run', 'sit']:
                row['st8'] = condition
                st8_rows.append(row)
            elif condition in [False, True]:
                row['opto'] = condition
                opto_rows.append(row)
# convert list of dicts to DataFrame:
ipos_st8 = pd.DataFrame(st8_rows)
if len(st8_rows) > 0:
    ipos_st8.set_index(['mse', 'stimtype', 'st8'], inplace=True) # convert to MultiIndex
ipos_opto = pd.DataFrame(opto_rows)
if len(opto_rows) > 0:
    ipos_opto.set_index(['mse', 'stimtype', 'opto'], inplace=True)
  ## collect unit spatial position, store in upos DataFrame (one row per msu, columns x and y):
rows = [dict(msu=msu2msustr(unitkey), x=unitx, y=unity)
        for unitkey, unitx, unity in zip(*mviunits.fetch(dj.key, 'u_xpos', 'u_ypos'))]
upos = pd.DataFrame(rows).set_index(['msu'])
  ## collect runspeed traces, store in runspeed DataFrame:
# One row per (mse, movie kind, opto condition), holding the trial IDs, run time vectors and
# run speed traces of the trials in which the corresponding stimulus was shown.
runspeedrows = [] # to be filled with dictionaries containing relevant data
for mvi in mvis.fetch(dj.key):
    msestr = "%s_s%02d_e%02d" % (mvi['m'], mvi['s'], mvi['e'])
    print(msestr)
    # fetch run data for all trials in this mse, each is an array, sometimes some trials
    # are missing from the run data:
    runtrialis, runt, runspeeds = (rt & mvi).fetch('trial_id', 'run_t', 'run_speed')
    if len(runspeeds) == 0:
        continue # no run data for this mse
    # searchsorted() below requires runtrialis to be sorted, and fetch order isn't
    # guaranteed, so sort all three arrays together (a no-op if already sorted):
    runsortis = np.argsort(runtrialis, kind='stable')
    runtrialis, runt, runspeeds = runtrialis[runsortis], runt[runsortis], runspeeds[runsortis]
    moviecond = mov & mvi
    kind2stimis = moviecond.moviekind2stimis() # movie kind to stimi mapping
    for kind in ['nat']: #MVIKINDS:
        kindstimis = kind2stimis[kind]
        if len(kindstimis) == 0:
            continue # wrong kind of movie
        # get stimis from evcond table for this mvi and kind, and restrict to opto1 chan:
        mviopto = evcond & mvi & 'stim_id IN %s' % seq2sql(kindstimis) & {'ev_chan':'opto1'}
        mvioptostimis, mvioptovals = mviopto.fetch('stim_id', 'ec_val')
        assert len(mvioptostimis) == 2
        assert len(mvioptovals) == 2
        mvioptovals = np.bool_(mvioptovals) # convert from [0, 1] to [False, True]
        # need exactly one opto-off and one opto-on stimulus:
        assert mvioptovals.sum() == 1
        # .item() pulls out the single matching element; calling int() on a 1-element
        # array is deprecated as of NumPy 1.25:
        opto2stimi = {False: int(mvioptostimis[~mvioptovals].item()),
                      True: int(mvioptostimis[mvioptovals].item())}
        for opto in OPTOS:
            row = {}
            row['mse'] = msestr
            row['kind'] = kind
            row['opto'] = opto
            # extract run trials for the stimi corresponding to this kind and opto:
            stimi = opto2stimi[opto]
            stimtrialis = np.sort((trial & mvi & {'stim_id':stimi}).fetch('trial_id'))
            # make sure that stimtrialis is a subset of runtrialis:
            assert issubset(stimtrialis, runtrialis)
            # now it's safe to ask where stimtrialis are found in (sorted) runtrialis:
            trialiis = runtrialis.searchsorted(stimtrialis)
            row['trialis'] = runtrialis[trialiis]
            row['t'] = runt[trialiis]
            row['speed'] = runspeeds[trialiis]
            runspeedrows.append(row)
runspeed = pd.DataFrame(runspeedrows)
# only index if anything was collected: set_index() raises KeyError on the column-less
# DataFrame built from an empty list:
if len(runspeedrows) > 0:
    runspeed = runspeed.set_index(['mse', 'kind', 'opto']) # convert to MultiIndex
  ## classify units in various ways, store classification in celltype DataFrame:
# use normal Index, not MultiIndex, since there's only a single index column. Individual
# cells are read and written with `celltype.at[msustr, colname]`, which addresses the cell
# directly. The chained form `celltype.loc[msustr][colname] = foo` only reaches celltype
# while `.loc[msustr]` happens to return a view, and silently writes into a temporary copy
# under pandas Copy-on-Write.
index = pd.Index(mvigrtmsustrs, name='msu')
columns = ['sbczscore', 'sbc', 'dsi', 'depth', 'normdepth', 'onoff', 'trans',
           'mvionsetpsth', 't']
celltype = pd.DataFrame(index=index, columns=columns)
print('Classifying SbC')
# NOTE: seems that the zscores stored in the Condition.ZScores table are only
# for control conditions, i.e. opto == False, which is good
SBCZSCORETHRESH = -3 # 0.0, -1.96, -2.58, -3
# classify by SbC, keep the most negative sbczscore of each msu across grating experiments:
msustr2maxzsmseu = {} # map msustr to most negative sbczscore mseu key
for grt in grts:
    zscores = zs & grt
    uids = (up & grt).fetch('u')
    for uid in uids:
        msu = {'m':grt['m'], 's':grt['s'], 'u':uid}
        mseu = {'m':grt['m'], 's':grt['s'], 'e':grt['e'], 'u':uid}
        msustr = msufmt.format(**msu)
        mseustr = mseufmt.format(**mseu)
        unitzscore = zscores & {'u':uid}
        if len(unitzscore) == 0:
            print('No Condition.Zscores entry for', mseustr)
            continue
        sbckeys = unitzscore.SbC()
        assert len(sbckeys) == 1
        sbczscore = sbckeys[0]['sbc_zscore'] # median, across ori conditions
        if np.isnan(sbczscore):
            print('WARNING: SbC() returned NaN for', mseustr)
        meanrate = grtresp.loc[mseustr, 'none', False]['meanrate']
        if np.isnan(meanrate):
            continue # skip this non visually responsive mseu, based on z_crit and n_crit
        oldsbczscore = celltype.at[msustr, 'sbczscore']
        if np.isnan(oldsbczscore) or sbczscore < oldsbczscore: # keep most -ve sbczscore
            celltype.at[msustr, 'sbczscore'] = sbczscore # update
            msustr2maxzsmseu[msustr] = mseu
# classify units as SbC == [False, True] according to SBCZSCORETHRESH:
for msustr in mvigrtmsustrs:
    sbczscore = celltype.at[msustr, 'sbczscore']
    if np.isnan(sbczscore):
        sbc = np.nan # this is the default anyway
    elif sbczscore <= SBCZSCORETHRESH:
        sbc = True
    else:
        sbc = False
    celltype.at[msustr, 'sbc'] = sbc
    # plot rasters, for testing SBCZSCORETHRESH, and zscores in general:
    try:
        mseu = msustr2maxzsmseu[msustr]
    except KeyError: # not a single exp with an sbczscore found for this msu
        print('Did not collect an sbczscore for', msustr)
        continue
    # Disabled raster plotting, kept for reference:
    # # plot the non-blank, opto == False trials of this msu:
    # offsets = [-1, 1]
    # rasters, tranges, allstimis, opto, ukeys = (spk & mseu).get_rasters(offsets=offsets)
    # assert len(rasters) == 1
    # uid, raster = rasters.popitem()
    # stimis = allstimis[opto == False]
    # raster = raster[opto == False]
    # blankstimis = (grat & mseu & 'stim_id IN %s' % seq2sql(stimis)
    #                & 'grat_contrast = 0.0').fetch('stim_id')
    # assert len(blankstimis) == 1
    # blankstimi = blankstimis[0] # should be 24
    # raster = raster[stimis != blankstimi] # exclude blank trials
    # title = "%s, zscore=%.3f, SbC=%s" % (msustr, sbczscore, sbc)
    # ax = simpletraster(raster, dt=5, offsets=offsets, scatter=True, title=title)
    # ax.set_title(title)
# classify by direction selectivity, keep biggest DSI of each msu across grating experiments:
print('Classifying DSI')
resp = grtresp.xs('none', level='st8') # keep only none run state
resp = resp.xs(False, level='opto') # keep only control opto state
for mseustr, row in resp.iterrows():
    print(mseustr)
    tunparams = row['tun']
    if np.isnan(tunparams).all():
        continue # skip this mseustr
    dp, rp, rn, r0, sigma, spon = tunparams
    dsi = niell_DSI(rp, rn, r0)
    msustr = mseustr2msustr(mseustr)
    olddsi = celltype.at[msustr, 'dsi']
    if np.isnan(olddsi) or dsi > olddsi:
        celltype.at[msustr, 'dsi'] = dsi # update
  1321. dsi = niell_DSI(rp, rn, r0)
  1322. msustr = mseustr2msustr(mseustr)
  1323. celltyperow = celltype.loc[msustr]
  1324. olddsi = celltyperow['dsi']
  1325. if np.isnan(olddsi) or dsi > olddsi:
  1326. celltyperow['dsi'] = dsi # update
  1327. # estimate dLGN depth of each unit using sparse noise MUA RFs to help delineate
  1328. # at what channel dLGN starts and ends:
  1329. print('Estimating dLGN depth')
  1330. # first plot MUA RFs for each sparse noise experiment in each series, visually inspect them,
  1331. # and manually enter estimated start and end channel of dLGN for each series into
  1332. # msstr2dLGNchanrange:
  1333. for series in mviseries:
  1334. snkeys = (Series.Experiment() & series & 'e_name LIKE "%%Noise%%"').fetch(dj.key)
  1335. for snkey in snkeys:
  1336. print('Sparse noise exp:', snkey)
  1337. rfs, axes, chans = rfs_for_cdata(snkey)
  1338. # plot MUA RFs for inspection
  1339. simple_plot_rfs(snkey, rfs, axes, chans, interpolation=None, nrows=4)
  1340. # Now save all plots to disk in MUA_RFs folder:
  1341. saveall(format='pdf')
  1342. saveall(format='png')
  1343. # Next, manually map MOUSE_SERIES string to tuple of uppermost and lowermost clear dLGN
  1344. # channel (end inclusive) based on qualitative estimation of start and end of MUA RF
  1345. # progression of each series, potentially considering multiple sparse noise experiments
  1346. # per series:
  1347. ## TODO: switch to using LGNChans table and its ms2maxchanrange() method instead of
  1348. ## duplicating it here:
  1349. msstr2dLGNchanrange = {'PVCre_2017_0006_s03': (29, 8),
  1350. 'PVCre_2017_0008_s09': (25, 7),
  1351. 'PVCre_2017_0008_s12': (24, 2),
  1352. 'PVCre_2017_0015_s03': (28, 5),
  1353. 'PVCre_2017_0015_s07': (19, 1),
  1354. 'PVCre_2018_0001_s02': (11, 1),
  1355. 'PVCre_2018_0001_s05': (25, 6),
  1356. 'PVCre_2018_0003_s02': (20, 3),
  1357. 'PVCre_2018_0003_s03': (14, 2),
  1358. 'PVCre_2018_0003_s05': (20, 9),
  1359. 'PVCre_2019_0002_s08': (21, 1),
  1360. 'Ntsr1Cre_2020_0001_s04': (31, 5),
  1361. }
  1362. # get dLGN depth of each unit in each series, add to celltype df:
  1363. for series in mviseries:
  1364. msstr = ms2msstr(series)
  1365. hichan, lochan = msstr2dLGNchanrange[msstr] # hi and lo chans (in terms of depth)
  1366. # could use coords entered into ProbeModel table, but easier to use SiteLoc dict
  1367. # in probes.expio:
  1368. #probemodel, chans, coords = (pm & (pr & series)).fetch1('pm', 'pm_chans', 'pm_coords')
  1369. probemodel = (pr & series).fetch1('pm') # string
  1370. probe = expio.probes.getprobe(probemodel)
  1371. # y coord of upper and lowermost dLGN chan, use upper as reference and upper-lower as depth
  1372. # normalization for all units in this series:
  1373. y0 = probe.SiteLoc[hichan][1]
  1374. y1 = probe.SiteLoc[lochan][1]
  1375. uids, ys = (up & series).fetch('u', 'u_ypos')
  1376. depths = ys - y0
  1377. dLGNspan = y1 - y0
  1378. normdepths = depths / dLGNspan
  1379. for uid, depth, normdepth in zip(uids, depths, normdepths):
  1380. msustr = msstr + '_u%02d' % uid
  1381. if depth < 0:
  1382. print('WARNING: %s has depth %.1f < 0, skipping entry in celltype df'
  1383. % (msustr, depth))
  1384. continue
  1385. if normdepth > 1.5:
  1386. print('WARNING: %s has normdepth %.1f > 1.5, skipping entry in celltype df'
  1387. % (msustr, normdepth))
  1388. continue
  1389. celltype.loc[msustr, 'depth'] = depth
  1390. celltype.loc[msustr, 'normdepth'] = normdepth
  1391. '''
  1392. # classify by layer (shell/core), pretty much deprecated, don't trust the code in rf.py
  1393. # that does all of this...
  1394. # NOTE: several (13/80 as of 2020-02-11) msu end up with no layer classification because
  1395. # they fall outside of the detected range of chans with MUA envelope RFs of sufficient score
  1396. print('Classifying shell/core')
  1397. print('NOTE: something is wrong with this code, hangs on first series at full CPU')
  1398. for series in mviseries:
  1399. try:
  1400. unitlayers = rf.get_layer(series, snexpid=-1) # list of dicts
  1401. except RuntimeError as err:
  1402. print(err)
  1403. continue # no .envl.dat file?
  1404. for k in unitlayers:
  1405. msu = {'m':k['m'], 's':k['s'], 'u':k['u']}
  1406. msustr = msufmt.format(**msu)
  1407. celltype.loc[msustr]['layer'] = k['layer']
  1408. '''
  1409. # classify as ON/OFF/transient according to movie onset response. Movies are generally
  1410. # darker than mid grey, so a decrease in PSTH after onset is an ON response.
  1411. # After examining lots of raster plots, simplest and most informative may be
  1412. # +/- 0.2 sec on either side of movie onset, after taking into account 0.05 s propagation
  1413. # delay, i.e. from -0.15 to 0 as the spont (pre) rate, and from 0.05 to 0.25 as the evoked
  1414. # (post) rate. Taking just a short period of time around the movie onset helps reduce the
  1415. # impact of movie motion on the result
  1416. print('Classifying ON/OFF/transient')
  1417. resp = mviresp.xs('nat', level='kind') # keep only nat movie kind
  1418. resp = resp.xs('none', level='st8') # keep only none run state
  1419. resp = resp.xs(False, level='opto') # keep only control opto state
  1420. for mseustr, row in resp.iterrows():
  1421. print(mseustr)
  1422. wpsth = row['wpsth']
  1423. wt = row['wt']
  1424. if np.isnan(wt).any():
  1425. continue # skip this mseustr
  1426. prei0, prei1 = wt.searchsorted([PDT-WINDT, EXCLTR[0]])
  1427. posti0, posti1 = wt.searchsorted([EXCLTR[1], PDT+WINDT])
  1428. mvionsetpsth = wpsth[prei0:posti1]
  1429. t = wt[prei0:posti1]
  1430. prepsth = wpsth[prei0:prei1]
  1431. postpsth = wpsth[posti0:posti1]
  1432. insidepsth = wpsth[prei1:posti0]
  1433. outsidepsth = np.concatenate([prepsth, postpsth])
  1434. pre = prepsth.mean()
  1435. post = postpsth.mean()
  1436. ins = insidepsth.mean()
  1437. out = outsidepsth.mean()
  1438. if pre == 0.0 or post == 0.0 or ins == 0.0 or out == 0.0:
  1439. print('No pre/post/inside/outside firing rate, skipping')
  1440. continue
  1441. onoff = (pre - post)/(pre + post)
  1442. trans = (ins - out)/(ins + out)
  1443. Z = max(prepsth.max(), postpsth.max()) # normalization factor, ignore transient period
  1444. msustr = mseustr2msustr(mseustr)
  1445. celltyperow = celltype.loc[msustr]
  1446. oldonoff = celltyperow['onoff']
  1447. oldtrans = celltyperow['trans']
  1448. if np.isnan(oldonoff) or (abs(onoff) > abs(oldonoff)):
  1449. celltyperow['onoff'] = onoff # update
  1450. celltyperow['mvionsetpsth'] = mvionsetpsth / Z # normalize for plotting, update
  1451. celltyperow['t'] = t # update
  1452. if np.isnan(oldtrans) or (abs(trans) > abs(oldtrans)):
  1453. celltyperow['trans'] = trans # update
  1454. ## code for avoiding annoying dataframe pointer vs copy issues:
  1455. '''
  1456. # overwriting/resetting values of an existing column:
  1457. celltype['sbc'][:] = np.nan # I think the explicit slice makes the difference
  1458. # code for adding a new column to an existing dataframe:
  1459. celltype = celltype.copy()
  1460. celltype['dsi'] = np.nan
  1461. celltype['dsi'] = celltype['dsi'].astype(object)
  1462. celltype = celltype.copy()
  1463. celltype['onoff'] = np.nan
  1464. celltype['trans'] = np.nan
  1465. celltype['mvionsetpsth'] = np.nan
  1466. celltype['t'] = np.nan
  1467. celltype['onoff'] = celltype['onoff'].astype(object)
  1468. celltype['trans'] = celltype['trans'].astype(object)
  1469. celltype['mvionsetpsth'] = celltype['mvionsetpsth'].astype(object)
  1470. celltype['t'] = celltype['t'].astype(object)
  1471. '''
  1472. ## collect MUA envelope RF positions for each msu, to use in calculating how well centered
  1473. ## the screen was on the RF of each unit:
  1474. index = pd.Index(mvigrtmsustrs, name='msu')
  1475. columns = ['x0', 'y0']
  1476. cellscreenpos = pd.DataFrame(index=index, columns=columns)
  1477. print('Getting RF params from MUA envelope')
  1478. snexpid = -1 # use last one in each series
  1479. MUAENVLRSQTHRESH = 0.2 # consider only decent MUA envelope RF fits
  1480. series2rffitdf = {} # collect all RF fit params for all series
  1481. for series in mviseries:
  1482. erows = Series.Experiment() & series & 'e_name LIKE "%%Noise%%"'
  1483. expvals = erows.fetch(dj.key, 'e_name', 'e_displayazim', 'e_displayelev')
  1484. snkeys, enames, azims, elevs = expvals
  1485. snkey, ename, azim, elev = (snkeys[snexpid], enames[snexpid],
  1486. azims[snexpid], elevs[snexpid])
  1487. print('Sparse noise exp:', snkey, ename)
  1488. rfs, axes, chans = rfs_for_cdata(snkey)
  1489. #simple_plot_rfs(snkey, rfs, axes, chans, interpolation=None, nrows=4)
  1490. fitdf = fit_MUA_RFs(snkey, rfs, chans)
  1491. msestr = key2msestr(snkey)
  1492. msstr = mseustr2mstr(msestr)
  1493. series2rffitdf[msstr] = fitdf # save
  1494. chans = fitdf.index.levels[0].to_numpy() # all the chans in the df
  1495. chan2screenpos = {}
  1496. for chan in chans:
  1497. chanrow = fitdf.loc[chan, 'mix']
  1498. if chanrow['rsq'] < MUAENVLRSQTHRESH:
  1499. continue
  1500. # remove azim and elev correction to get RF position wrt screen center:
  1501. chan2screenpos[chan] = chanrow['x0']-azim, chanrow['y0']-elev
  1502. urows = up & series
  1503. uids, maxchans = urows.fetch('u', 'u_maxchan')
  1504. for uid, maxchan in zip(uids, maxchans):
  1505. if maxchan in chan2screenpos:
  1506. screenpos = chan2screenpos[maxchan]
  1507. msu = {'m':series['m'], 's':series['s'], 'u':uid}
  1508. msustr = msufmt.format(**msu)
  1509. cellscreenpos.loc[msustr] = screenpos
  1510. ## TODO: somehow convert series2rffitdf dict to DataFrame multiindexed by msstr, chan, rftype:
  1511. #muarfs = pd.DataFrame()
  1512. '''
  1513. # test reliability measure using sparse trials with randomly placed spikes:
  1514. ntrials, nt, nspikes = 200, 1000, 50
  1515. trials = np.zeros((ntrials, nt))
  1516. for trial in trials:
  1517. for spikei in range(nspikes):
  1518. i = np.random.randint(0, nt)
  1519. trial[i] = 1
  1520. print('reliability:', reliability(trials)[0])
  1521. '''
  1522. '''
  1523. # how to get pairs of ON and OFF values from mviresp instead of resp:
  1524. # multiindex row slicing is confusing, need the axis=0 for some reason,
  1525. # see https://pandas.pydata.org/pandas-docs/stable/advanced.html#using-slicers
  1526. li = mviresp.loc(axis=0)
  1527. sparsons = np.float64(li[:, st8, True]['spars'].values)
  1528. sparsoffs = np.float64(li[:, st8, False]['spars'].values)
  1529. nanis = np.isnan(sparsons)
  1530. if not (np.isnan(sparsoffs) == nanis).all():
  1531. raise ValueError("ON and OFF conditions have different nan indices")
  1532. sparsons = sparsons[~nanis]
  1533. sparsoffs = sparsoffs[~nanis]
  1534. '''
  1535. ## TODO: now that mseu are the row indicies instead of msu, they're all unique, and
  1536. ## the group kwarg isn't doing any grouping, I think
  1537. ## TODO: I think the MLM stats don't work because currently the various different MVIKINDS
  1538. ## and ST8S are disabled, only nat and none are being used. So comment this out for now:
  1539. '''
  1540. print('ALL MVI MixedLM')
  1541. mvirespr = mviresp.reset_index() # make a copy, convert mi to columns
  1542. mvirespr = mvirespr.dropna() # drop rows with nans, but not really necessary
  1543. mvirespr = mvirespr[mvirespr['st8'] != 'none'] # exclude 'none' state
  1544. for col in ['meanrate', 'spars', 'rel', 'meanburstratio']:
  1545. # each of these columns are object arrs for some reason, instead of float, which
  1546. # confuses statsmodels and prevents printout of summary. See:
  1547. # https://stackoverflow.com/questions/29799161/summary-not-working-for-ols-estimation
  1548. mvirespr[col] = mvirespr[col].astype(np.float64) # convert each column to float
  1549. formulas = ('meanrate ~ kind * opto * st8',
  1550. 'spars ~ kind * opto * st8',
  1551. 'rel ~ kind * opto * st8',
  1552. 'meanburstratio ~ kind * opto * st8')
  1553. for formula in formulas:
  1554. # group kwarg specifies how data should be treated, prevents measures from different units
  1555. # from being lumped together:
  1556. mlm = smf.mixedlm(formula=formula, data=mvirespr, groups=mvirespr['mseu'])
  1557. results = mlm.fit()
  1558. print(formula)
  1559. print('ngroups:', results.mlm.n_groups)
  1560. paramp = pd.concat([results.params, results.pvalues], axis=1,
  1561. keys=['param', 'p']) # Dataframe from 2 Series
  1562. print(paramp)
  1563. #print(results.summary())
  1564. print()
  1565. print('----------------------------------------')
  1566. print()
  1567. print('NAT only MixedLM')
  1568. mvirespr = mviresp.reset_index() # make a copy, convert mi to columns
  1569. natrespr = mvirespr[mvirespr['kind'] == 'nat']
  1570. natrespr = natrespr.dropna() # drop rows with nans, but not really necessary
  1571. natrespr = natrespr[natrespr['st8'] != 'none'] # exclude 'none' state
  1572. for col in ['meanrate', 'spars', 'rel', 'meanburstratio']:
  1573. # each of these columns are object arrs for some reason, instead of float, which
  1574. # confuses statsmodels and prevents printout of summary. See:
  1575. # https://stackoverflow.com/questions/29799161/summary-not-working-for-ols-estimation
  1576. natrespr[col] = natrespr[col].astype(np.float64) # convert each column to float
  1577. formulas = ('meanrate ~ opto * st8',
  1578. 'spars ~ opto * st8',
  1579. 'rel ~ opto * st8',
  1580. 'meanburstratio ~ opto * st8')
  1581. for formula in formulas:
  1582. # group kwarg specifies how data should be treated, prevents measures from different units
  1583. # from being lumped together:
  1584. mlm = smf.mixedlm(formula=formula, data=natrespr, groups=natrespr['mseu'])
  1585. results = mlm.fit()
  1586. print(formula)
  1587. print('ngroups:', results.mlm.n_groups)
  1588. paramp = pd.concat([results.params, results.pvalues], axis=1,
  1589. keys=['param', 'p']) # Dataframe from 2 Series
  1590. print(paramp)
  1591. #print(results.summary())
  1592. print()
  1593. print('----------------------------------------')
  1594. print()
  1595. print('PNK only MixedLM')
  1596. mvirespr = mviresp.reset_index() # make a copy, convert mi to columns
  1597. pnkrespr = mvirespr[mvirespr['kind'] == 'pnk']
  1598. pnkrespr = pnkrespr.dropna() # drop rows with nans, but not really necessary
  1599. pnkrespr = pnkrespr[pnkrespr['st8'] != 'none'] # exclude 'none' state
  1600. for col in ['meanrate', 'spars', 'rel', 'meanburstratio']:
  1601. # each of these columns are object arrs for some reason, instead of float, which
  1602. # confuses statsmodels and prevents printout of summary. See:
  1603. # https://stackoverflow.com/questions/29799161/summary-not-working-for-ols-estimation
  1604. pnkrespr[col] = pnkrespr[col].astype(np.float64) # convert each column to float
  1605. formulas = ('meanrate ~ opto * st8',
  1606. 'spars ~ opto * st8',
  1607. 'rel ~ opto * st8',
  1608. 'meanburstratio ~ opto * st8')
  1609. for formula in formulas:
  1610. # group kwarg specifies how data should be treated, prevents measures from different units
  1611. # from being lumped together:
  1612. mlm = smf.mixedlm(formula=formula, data=pnkrespr, groups=pnkrespr['mseu'])
  1613. results = mlm.fit()
  1614. print(formula)
  1615. print('ngroups:', results.mlm.n_groups)
  1616. paramp = pd.concat([results.params, results.pvalues], axis=1,
  1617. keys=['param', 'p']) # Dataframe from 2 Series
  1618. print(paramp)
  1619. #print(results.summary())
  1620. print()
  1621. print('----------------------------------------')
  1622. print()
  1623. print('SHF only MixedLM')
  1624. mvirespr = mviresp.reset_index() # make a copy, convert mi to columns
  1625. shfrespr = mvirespr[mvirespr['kind'] == 'shf']
  1626. shfrespr = shfrespr.dropna() # drop rows with nans, but not really necessary
  1627. shfrespr = shfrespr[shfrespr['st8'] != 'none'] # exclude 'none' state
  1628. for col in ['meanrate', 'spars', 'rel', 'meanburstratio']:
  1629. # each of these columns are object arrs for some reason, instead of float, which
  1630. # confuses statsmodels and prevents printout of summary. See:
  1631. # https://stackoverflow.com/questions/29799161/summary-not-working-for-ols-estimation
  1632. shfrespr[col] = shfrespr[col].astype(np.float64) # convert each column to float
  1633. formulas = ('meanrate ~ opto * st8',
  1634. 'spars ~ opto * st8',
  1635. 'rel ~ opto * st8',
  1636. 'meanburstratio ~ opto * st8')
  1637. for formula in formulas:
  1638. # group kwarg specifies how data should be treated, prevents measures from different units
  1639. # from being lumped together:
  1640. mlm = smf.mixedlm(formula=formula, data=shfrespr, groups=shfrespr['mseu'])
  1641. results = mlm.fit()
  1642. print(formula)
  1643. print('ngroups:', results.mlm.n_groups)
  1644. paramp = pd.concat([results.params, results.pvalues], axis=1,
  1645. keys=['param', 'p']) # Dataframe from 2 Series
  1646. print(paramp)
  1647. #print(results.summary())
  1648. print()
  1649. print('----------------------------------------')
  1650. print()
  1651. print('Grating MixedLM')
  1652. grtrespr = grtresp.reset_index() # make a copy, convert mi to columns
  1653. grtrespr = grtrespr.dropna() # drop rows with nans, but not really necessary
  1654. grtrespr = grtrespr[grtrespr['st8'] != 'none'] # exclude 'none' state
  1655. for col in ['meanrate', 'meanburstratio']:
  1656. # each of these columns are object arrs for some reason, instead of float, which
  1657. # confuses statsmodels and prevents printout of summary. See:
  1658. # https://stackoverflow.com/questions/29799161/summary-not-working-for-ols-estimation
  1659. grtrespr[col] = grtrespr[col].astype(np.float64) # convert each column to float
  1660. formulas = ('meanrate ~ opto * st8',
  1661. 'meanburstratio ~ opto * st8')
  1662. for formula in formulas:
  1663. # group kwarg specifies how data should be treated, prevents measures from different units
  1664. # from being lumped together:
  1665. mlm = smf.mixedlm(formula=formula, data=grtrespr, groups=grtrespr['mseu'])
  1666. results = mlm.fit()
  1667. print(formula)
  1668. print('ngroups:', results.mlm.n_groups)
  1669. paramp = pd.concat([results.params, results.pvalues], axis=1,
  1670. keys=['param', 'p']) # Dataframe from 2 Series
  1671. print(paramp)
  1672. #print(results.summary())
  1673. print()
  1674. print('----------------------------------------')
  1675. print()
  1676. '''