fig1S33S1.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  # Setup shared by all the FMI plots below.
"""Figure 1S3 and 3S1 unit classification plots, use run -i fig1S33S1.py"""
# Names such as pd, np and mvigrtmsustrs are not defined in this file; per the docstring
# it is meant to be run with `run -i`, so they presumably come from the IPython namespace.
from scipy.stats import gmean
import matplotlib as mpl
from matplotlib.patches import Rectangle

# Thresholds. Not used by the plots in this section; presumably referenced elsewhere in
# the file -- confirm before removing.
TRANSTHRESH = 0.2
ONOFFTHRESH = 0.2

# Results table, one row per unit ('msu'). The plot loops below fill in the FMI columns
# ('<stimtype>_<measure>'), the unit properties plotted against FMI (sbc, depth, dsi,
# rfdist), and the control-condition raw values ('<stimtype>_<measure>_raw').
# Note: 'normdepth' is not written by any loop in this section.
columns = [
    'mvi_meanrate', 'mvi_meanburstratio', 'grt_meanrate', 'grt_meanburstratio',
    'sbc', 'depth', 'normdepth', 'dsi', 'rfdist',
    'mvi_meanrate_raw', 'mvi_meanburstratio_raw',
    'grt_meanrate_raw', 'grt_meanburstratio_raw',
]
index = pd.Index(mvigrtmsustrs, name='msu')
fig1S33S1 = pd.DataFrame(index=index, columns=columns)
  # stripplot FMI by SbC/non-SbC, mvi and grt, meanrate and meanburstratio:
np.random.seed(0) # to get identical horizontal jitter in strip plots on every run
figsize = DEFAULTFIGURESIZE
# LMM output (.csv files in 'stats') holding the model-predicted mean FMI of each group,
# keyed by (stimtype, measure). An unhandled combination raises a KeyError naming it,
# instead of silently passing fname=None on to pd.read_csv:
stimmeasure2meanscsv = {
    ('mvi', 'meanrate'): 'figure_1_S3a_pred_means.csv',
    ('mvi', 'meanburstratio'): 'figure_1_S3f_pred_means.csv',
    ('grt', 'meanrate'): 'figure_3_S1a_pred_means.csv',
    ('grt', 'meanburstratio'): 'figure_3_S1f_pred_means.csv',
}
for stimtype in STIMTYPES:
    stimtypelabel = stimtype2axislabel[stimtype]
    for measure in ['meanrate', 'meanburstratio']:
        axislabel = measure2axislabel[measure]
        axislabel = short2longaxislabel.get(axislabel, axislabel)
        if axislabel.islower():
            axislabel = axislabel.capitalize()
        fmis, sbcs = [], []
        colname = '_'.join([stimtype, measure]) # e.g. 'mvi_meanrate'
        for msustr in mvigrtmsustrs:
            # save sbc to fig1S33S1 df regardless of FMI val; it's rewritten with the same
            # value on every (stimtype, measure) pass, which is harmless. Always write with
            # a single .loc[row, col]: the chained form .loc[row][col] = val assigns into
            # an intermediate row object and can leave fig1S33S1 unchanged
            # (SettingWithCopyWarning; never takes effect under pandas Copy-on-Write):
            sbc = celltype.loc[msustr]['sbc'] # bool
            fig1S33S1.loc[msustr, 'sbc'] = sbc
            fmi = maxFMI.loc[msustr, 'none', stimtype][measure] # ignore run condition for FMI
            if np.isnan(fmi):
                continue # no FMI for this unit/stimtype/measure, leave it out of the plot
            fmis.append(fmi)
            sbcs.append(sbc)
            fig1S33S1.loc[msustr, colname] = fmi
        fmis = np.asarray(fmis)
        sbcs = np.asarray(sbcs)
        # split FMIs by cell type (elementwise ==, so this also works for object arrays):
        sbcfmis = fmis[sbcs == True]
        nonsbcfmis = fmis[sbcs == False]
        f, a = plt.subplots(figsize=figsize)
        wintitle('FMI SbC %s %s strip' % (stimtypelabel, measure))
        # plot y=0 line:
        a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
        # one column per group; orient='index' + transpose lets the groups differ in
        # length (the shorter column is NaN-padded):
        data = pd.DataFrame.from_dict({'SbC':sbcfmis, 'Non-SbC':nonsbcfmis},
                                      orient='index').transpose()
        sns.stripplot(ax=a, data=data, clip_on=False, marker='.',
                      color='None', edgecolor='black', size=np.sqrt(50))
        # fetch LMM predicted group means from .csv:
        fname = os.path.join('stats', stimmeasure2meanscsv[stimtype, measure])
        df = pd.read_csv(fname)
        meannonsbcfmi = df['non_sbc'][0]
        meansbcfmi = df['sbc'][0]
        # plot means as short horizontal lines centered on each strip: SbC at x=0 and
        # Non-SbC at x=1, following the column order of data:
        a.plot([-0.25, 0.25], [meansbcfmi, meansbcfmi], '-', lw=2, c='red', zorder=np.inf)
        a.plot([0.75, 1.25], [meannonsbcfmi, meannonsbcfmi], '-', lw=2, c='red', zorder=np.inf)
        a.set_ylabel('%s FMI' % axislabel)
        a.set_ylim(-1, 1)
        a.set_yticks([-1, 0, 1])
        a.tick_params(bottom=False)
        a.spines['bottom'].set_position(('outward', 5))
        a.spines['bottom'].set_visible(False)
  # scatter plot FMI vs. depth by mvi and grt, meanrate and meanburstratio:
figsize = DEFAULTFIGURESIZE
# LMM output (.csv files in 'stats') holding the linear fit params (slope, intercept) of
# FMI vs. depth, keyed by (stimtype, measure). An unhandled combination raises a KeyError
# naming it, instead of silently passing fname=None on to pd.read_csv:
stimmeasure2depthcsv = {
    ('mvi', 'meanrate'): 'figure_1_S3b_coefs.csv',
    ('mvi', 'meanburstratio'): 'figure_1_S3g_coefs.csv',
    ('grt', 'meanrate'): 'figure_3_S1b_coefs.csv',
    ('grt', 'meanburstratio'): 'figure_3_S1g_coefs.csv',
}
for stimtype in STIMTYPES:
    stimtypelabel = stimtype2axislabel[stimtype]
    for measure in ['meanrate', 'meanburstratio']:
        axislabel = measure2axislabel[measure]
        axislabel = short2longaxislabel.get(axislabel, axislabel)
        if axislabel.islower():
            axislabel = axislabel.capitalize()
        fmis, depths = [], []
        colname = '_'.join([stimtype, measure]) # e.g. 'mvi_meanrate'
        for msustr in mvigrtmsustrs:
            # save depth to fig1S33S1 df regardless of FMI val; it's rewritten with the
            # same value on every (stimtype, measure) pass, which is harmless. Write with a
            # single .loc[row, col]: chained .loc[row][col] = val can leave fig1S33S1
            # unchanged (SettingWithCopyWarning; never takes effect under Copy-on-Write):
            depth = celltype.loc[msustr]['depth'] # float
            fig1S33S1.loc[msustr, 'depth'] = depth
            fmi = maxFMI.loc[msustr, 'none', stimtype][measure] # ignore run condition for FMI
            if np.isnan(fmi):
                continue # no FMI for this unit/stimtype/measure, leave it out of the plot
            fmis.append(fmi)
            depths.append(depth)
            fig1S33S1.loc[msustr, colname] = fmi # might overwrite w/ identical values
        fmis = np.asarray(fmis)
        depths = np.asarray(depths)
        f, a = plt.subplots(figsize=figsize)
        wintitle('FMI depth %s %s' % (stimtypelabel, measure))
        # plot y=0 line:
        a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
        a.scatter(depths, fmis, clip_on=False, marker='.', c='None', edgecolor='k', s=DEFSZ)
        # fetch LMM linear fit params from .csv, draw the fit across the plotted depths:
        fname = os.path.join('stats', stimmeasure2depthcsv[stimtype, measure])
        df = pd.read_csv(fname)
        mm = df['slope'][0]
        b = df['intercept'][0]
        x = np.array([np.nanmin(depths), np.nanmax(depths)])
        y = mm * x + b
        a.plot(x, y, '-', color='red') # plot linregress fit
        # raw string: the mathtext backslashes must reach matplotlib verbatim, and '\m'
        # is an invalid Python escape sequence (SyntaxWarning on Python 3.12+):
        a.set_xlabel(r'Depth ($\mathregular{\mu}$m)')
        a.set_ylabel('%s FMI' % axislabel)
        a.set_xlim(0, 500)
        a.set_ylim(-1, 1)
        a.set_yticks([-1, 0, 1])
        a.spines['left'].set_position(('outward', 4))
        a.spines['bottom'].set_position(('outward', 4))
  # scatter plot FMI vs. DSI by mvi and grt, meanrate and meanburstratio:
figsize = DEFAULTFIGURESIZE
# LMM output (.csv files in 'stats') holding the linear fit params (slope, intercept) of
# FMI vs. DSI, keyed by (stimtype, measure). An unhandled combination raises a KeyError
# naming it, instead of silently passing fname=None on to pd.read_csv:
stimmeasure2dsicsv = {
    ('mvi', 'meanrate'): 'figure_1_S3c_coefs.csv',
    ('mvi', 'meanburstratio'): 'figure_1_S3h_coefs.csv',
    ('grt', 'meanrate'): 'figure_3_S1c_coefs.csv',
    ('grt', 'meanburstratio'): 'figure_3_S1h_coefs.csv',
}
for stimtype in STIMTYPES:
    stimtypelabel = stimtype2axislabel[stimtype]
    for measure in ['meanrate', 'meanburstratio']:
        axislabel = measure2axislabel[measure]
        axislabel = short2longaxislabel.get(axislabel, axislabel)
        if axislabel.islower():
            axislabel = axislabel.capitalize()
        fmis, dsis = [], []
        colname = '_'.join([stimtype, measure]) # e.g. 'mvi_meanrate'
        for msustr in mvigrtmsustrs:
            # save DSI to fig1S33S1 df regardless of FMI val; it's rewritten with the same
            # value on every (stimtype, measure) pass, which is harmless. Write with a
            # single .loc[row, col]: chained .loc[row][col] = val can leave fig1S33S1
            # unchanged (SettingWithCopyWarning; never takes effect under Copy-on-Write):
            dsi = celltype.loc[msustr]['dsi'] # float
            fig1S33S1.loc[msustr, 'dsi'] = dsi
            fmi = maxFMI.loc[msustr, 'none', stimtype][measure] # ignore run condition for FMI
            if np.isnan(fmi):
                continue # no FMI for this unit/stimtype/measure, leave it out of the plot
            fmis.append(fmi)
            dsis.append(dsi)
            fig1S33S1.loc[msustr, colname] = fmi # might overwrite w/ identical values
        fmis = np.asarray(fmis)
        dsis = np.asarray(dsis)
        f, a = plt.subplots(figsize=figsize)
        wintitle('FMI DSI %s %s' % (stimtypelabel, measure))
        # plot y=0 line:
        a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
        a.scatter(dsis, fmis, clip_on=False, marker='.', c='None', edgecolor='k', s=DEFSZ)
        # fetch LMM linear fit params from .csv, draw the fit across the plotted DSIs:
        fname = os.path.join('stats', stimmeasure2dsicsv[stimtype, measure])
        df = pd.read_csv(fname)
        mm = df['slope'][0]
        b = df['intercept'][0]
        x = np.array([np.nanmin(dsis), np.nanmax(dsis)])
        y = mm * x + b
        a.plot(x, y, '-', color='red') # plot linregress fit
        a.set_xlabel('DSI')
        a.set_ylabel('%s FMI' % axislabel)
        a.set_xlim(0, 1)
        a.set_ylim(-1, 1)
        a.set_yticks([-1, 0, 1])
        a.spines['left'].set_position(('outward', 4))
        a.spines['bottom'].set_position(('outward', 4))
  # scatter plot FMI vs. distance of MUA envl RF from screen center, by mvi and grt,
# meanrate and meanburstratio:
figsize = DEFAULTFIGURESIZE
# LMM output (.csv files in 'stats') holding the linear fit params (slope, intercept) of
# FMI vs. RF distance, keyed by (stimtype, measure). An unhandled combination raises a
# KeyError naming it, instead of silently passing fname=None on to pd.read_csv:
stimmeasure2rfdistcsv = {
    ('mvi', 'meanrate'): 'figure_1_S3d_coefs.csv',
    ('mvi', 'meanburstratio'): 'figure_1_S3i_coefs.csv',
    ('grt', 'meanrate'): 'figure_3_S1d_coefs.csv',
    ('grt', 'meanburstratio'): 'figure_3_S1i_coefs.csv',
}
for stimtype in STIMTYPES:
    stimtypelabel = stimtype2axislabel[stimtype]
    for measure in ['meanrate', 'meanburstratio']:
        axislabel = measure2axislabel[measure]
        axislabel = short2longaxislabel.get(axislabel, axislabel)
        if axislabel.islower():
            axislabel = axislabel.capitalize()
        fmis, ds = [], []
        colname = '_'.join([stimtype, measure]) # e.g. 'mvi_meanrate'
        for msustr in mvigrtmsustrs:
            # save rfdist to fig1S33S1 df regardless of FMI val; it's rewritten with the
            # same value on every (stimtype, measure) pass, which is harmless. Write with a
            # single .loc[row, col]: chained .loc[row][col] = val can leave fig1S33S1
            # unchanged (SettingWithCopyWarning; never takes effect under Copy-on-Write):
            x0, y0 = cellscreenpos.loc[msustr]
            d = np.sqrt(x0**2 + y0**2) # distance from screen center, deg
            fig1S33S1.loc[msustr, 'rfdist'] = d
            fmi = maxFMI.loc[msustr, 'none', stimtype][measure] # ignore run condition for FMI
            if np.isnan(fmi):
                continue # no FMI for this unit/stimtype/measure, leave it out of the plot
            fmis.append(fmi)
            ds.append(d)
            fig1S33S1.loc[msustr, colname] = fmi # might overwrite w/ identical values
        fmis = np.asarray(fmis)
        ds = np.asarray(ds)
        f, a = plt.subplots(figsize=figsize)
        wintitle('FMI rfdist %s %s' % (stimtypelabel, measure))
        # plot y=0 line:
        a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
        a.scatter(ds, fmis, clip_on=False, marker='.', c='None', edgecolor='k', s=DEFSZ)
        # fetch LMM linear fit params from .csv, draw the fit across the plotted distances:
        fname = os.path.join('stats', stimmeasure2rfdistcsv[stimtype, measure])
        df = pd.read_csv(fname)
        mm = df['slope'][0]
        b = df['intercept'][0]
        x = np.array([np.nanmin(ds), np.nanmax(ds)])
        y = mm * x + b
        a.plot(x, y, '-', color='red') # plot linregress fit
        # raw string: the mathtext backslash must reach matplotlib verbatim, and '\d' is
        # an invalid Python escape sequence (SyntaxWarning on Python 3.12+):
        a.set_xlabel(r'RF dist. from center ($\degree$)')
        a.set_ylabel('%s FMI' % axislabel)
        a.set_xlim(0, 40)
        a.set_ylim(-1, 1)
        a.set_yticks([-1, 0, 1])
        a.spines['left'].set_position(('outward', 4))
        a.spines['bottom'].set_position(('outward', 4))
  # scatter plot FMI vs raw measure, for rate and burst ratio during control condition,
# by mvi and grt:
stimtype2resp = {'mvi':mviresp, 'grt':grtresp}
figsize = DEFAULTFIGURESIZE
# LMM output (.csv files in 'stats') holding the linear fit params (slope, intercept) of
# FMI vs. raw measure, keyed by (stimtype, measure). An unhandled combination raises a
# KeyError naming it, instead of silently passing fname=None on to pd.read_csv:
stimmeasure2rawcsv = {
    ('mvi', 'meanrate'): 'figure_1_S3e_coefs.csv',
    ('mvi', 'meanburstratio'): 'figure_1_S3j_coefs.csv',
    ('grt', 'meanrate'): 'figure_3_S1e_coefs.csv',
    ('grt', 'meanburstratio'): 'figure_3_S1j_coefs.csv',
}
for stimtype in STIMTYPES:
    stimtypelabel = stimtype2axislabel[stimtype]
    resp = stimtype2resp[stimtype]
    if stimtype == 'mvi':
        # keep only the 'nat' rows of the movie 'kind' index level; xs drops that level
        # so resp is indexed the same way for both stimtypes below:
        resp = resp.xs('nat', level='kind')
    for measure in ['meanrate', 'meanburstratio']:
        axislabel = measure2axislabel[measure]
        axislabel = short2longaxislabel.get(axislabel, axislabel)
        axisunits = measure2axisunits.get(measure, '')
        if axislabel.islower():
            axislabel = axislabel.capitalize()
        fmis, msrs = [], []
        fmicolname = '_'.join([stimtype, measure]) # e.g. 'mvi_meanrate'
        msrcolname = '_'.join([stimtype, measure, 'raw']) # e.g. 'mvi_meanrate_raw'
        for msustr in mvigrtmsustrs:
            # 'mseu' is the key used to look up this unit's raw measure in resp:
            mseustr, fmi = maxFMI.loc[msustr, 'none', stimtype][['mseu', measure]]
            if pd.isna(mseustr) or pd.isna(fmi):
                continue # nothing to plot for this unit/stimtype/measure
            # raw measure in the control condition; index levels presumably
            # (mseu, <condition>, <bool flag>) -- confirm against resp's construction:
            msr = resp.loc[mseustr, 'none', False][measure]
            fmis.append(fmi)
            msrs.append(msr)
            # single .loc[row, col] writes: chained .loc[row][col] = val can leave
            # fig1S33S1 unchanged (SettingWithCopyWarning; no-op under Copy-on-Write):
            fig1S33S1.loc[msustr, fmicolname] = fmi # might overwrite w/ identical values
            fig1S33S1.loc[msustr, msrcolname] = msr
        fmis, msrs = np.asarray(fmis), np.asarray(msrs)
        # scatter plot FMI vs raw measure:
        f, a = plt.subplots(figsize=figsize)
        wintitle('FMI raw %s %s' % (stimtypelabel, measure))
        # plot y=0 line:
        a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
        a.scatter(msrs, fmis, clip_on=False, marker='.', c='None', edgecolor='k', s=DEFSZ)
        # fetch LMM linear fit params from .csv, draw the fit across the plotted raw values:
        fname = os.path.join('stats', stimmeasure2rawcsv[stimtype, measure])
        df = pd.read_csv(fname)
        mm = df['slope'][0]
        b = df['intercept'][0]
        x = np.array([np.nanmin(msrs), np.nanmax(msrs)])
        y = mm * x + b
        a.plot(x, y, '-', color='red') # plot linregress fit
        a.set_xlabel(axislabel + axisunits)
        a.set_ylabel('%s FMI' % axislabel)
        a.set_xlim(xmin=0)
        a.set_ylim(-1, 1)
        a.set_yticks([-1, 0, 1])
        a.spines['left'].set_position(('outward', 4))
        a.spines['bottom'].set_position(('outward', 4))