fig6.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. """Figure 6 plots, use run -i fig6.py"""
  2. allmseustrs = np.union1d(mvimseustrs, grtmseustrs)
  3. modis = ['fmi', 'rmi', 'suppressionrmi', 'feedbackrmi', 'runfmi', 'sitfmi']
  4. mi = pd.MultiIndex.from_product([allmseustrs, STIMTYPES, modis],
  5. names=['mseu', 'stimtype', 'mi'])
  6. columns = ['meanrate', 'meanburstratio', 'spars', 'rel', 'meanpkw', 'snr']
  7. fig6 = pd.DataFrame(index=mi, columns=columns)
  8. stimtype2FMI = {'mvi': mviFMI, 'grt': grtFMI}
  9. stimtype2RMI = {'mvi': mviRMI, 'grt': grtRMI}
  10. loss = 'soft_l1' # loss function, less sensitive to outliers than ordinary least squares
  11. f_scale = 0.25 # residual value at which to start treating points as outliers
  12. # scatter plot run modulation index for feedback vs. suppression trials,
  13. # for movies and gratings, for all measures:
  14. figsize = DEFAULTFIGURESIZE
  15. linmin, linmax, linstep = -1, 1, 1
  16. ticks = np.arange(linmin, linmax+linstep, linstep)
  17. for measure in modmeasuresnoblank:
  18. axislabel = measure2axislabel[measure]
  19. axislabel = short2longaxislabel.get(axislabel, axislabel)
  20. if axislabel.islower():
  21. axislabel = axislabel.capitalize()
  22. for stimtype in STIMTYPES:
  23. FMI = stimtype2FMI[stimtype]
  24. RMI = stimtype2RMI[stimtype]
  25. mseustrs = {'mvi':mvimseustrs, 'grt':grtmseustrs}[stimtype]
  26. mseu2exmpli = {'mvi':mvimseu2exmpli, 'grt':grtmseu2exmpli}[stimtype]
  27. stimtypelabel = stimtype2axislabel[stimtype]
  28. supprmis, fbrmis, exmplis, exmplmseustrs, normlis = [], [], [], [], []
  29. keptmseui = 0 # manually init and increment instead of using enumerate()
  30. for mseustr in mseustrs:
  31. supprmi = RMI[measure][mseustr, True]
  32. fbrmi = RMI[measure][mseustr, False]
  33. if np.isnan(supprmi) or np.isnan(fbrmi): # missing one or both mod values
  34. continue
  35. supprmis.append(supprmi)
  36. fbrmis.append(fbrmi)
  37. fig6.loc[mseustr, stimtype, 'suppressionrmi'][measure] = supprmi
  38. fig6.loc[mseustr, stimtype, 'feedbackrmi'][measure] = fbrmi
  39. if mseustr in mseu2exmpli:
  40. exmplis.append(keptmseui)
  41. exmplmseustrs.append(mseustr)
  42. else:
  43. normlis.append(keptmseui)
  44. keptmseui += 1 # manually increment
  45. supprmis = np.asarray(supprmis)
  46. fbrmis = np.asarray(fbrmis)
  47. if len(supprmis) == 0 or len(fbrmis) == 0:
  48. continue # nothing to plot
  49. f, a = plt.subplots(figsize=figsize)
  50. ## scatter plot feedback vs suppression RMI
  51. wintitle('fbsupp RMI %s %s' % (measure, stimtypelabel))
  52. # plot x=0 and y=0 lines:
  53. a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
  54. a.axvline(x=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
  55. # plot y=x line:
  56. xyline = [linmin, linmax], [linmin, linmax]
  57. a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-1)
  58. # plot normal (non-example) points:
  59. a.scatter(supprmis[normlis], fbrmis[normlis], clip_on=False,
  60. marker='.', c='None', edgecolor='black', s=DEFSZ)
  61. # plot example points, one at a time:
  62. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  63. marker = exmpli2mrk[mseu2exmpli[mseustr]]
  64. c = exmpli2clr[mseu2exmpli[mseustr]]
  65. sz = exmpli2sz[mseu2exmpli[mseustr]]
  66. lw = exmpli2lw[mseu2exmpli[mseustr]]
  67. a.scatter(supprmis[exmpli], fbrmis[exmpli], clip_on=False,
  68. marker=marker, c=c, s=sz, lw=lw)
  69. # get fname for relevant figure panels:
  70. relev_measures = ['meanrate', 'meanburstratio', 'spars', 'rel']
  71. if (measure in relev_measures) and stimtype == 'mvi':
  72. # if measure is relevant, get parameters for model from csv file:
  73. if measure == 'meanrate':
  74. fname = os.path.join('stats', 'figure_6a1_coefs.csv')
  75. elif measure == 'meanburstratio':
  76. fname = os.path.join('stats', 'figure_6a2_coefs.csv')
  77. elif measure == 'spars':
  78. fname = os.path.join('stats', 'figure_6a3_coefs.csv')
  79. elif measure == 'rel':
  80. fname = os.path.join('stats', 'figure_6a4_coefs.csv')
  81. try:
  82. df = pd.read_csv(fname)
  83. foundregression = True
  84. except FileNotFoundError:
  85. print('Missing file: %s' % fname)
  86. foundregression = False
  87. if foundregression:
  88. df = pd.read_csv(fname)
  89. mm = df['slope'][0]
  90. b = df['intercept'][0]
  91. x = np.array([np.nanmin(supprmis), np.nanmax(supprmis)])
  92. y = mm * x + b
  93. a.plot(x, y, '-', color='red') # plot linregress fit
  94. #a.set_title('%s RMI' % axislabel) # typeface too big by default, decreases panel size
  95. a.set_xlabel('Suppression')
  96. a.set_ylabel('Feedback')
  97. a.set_xlim(linmin, linmax)
  98. a.set_ylim(linmin, linmax)
  99. a.set_xticks(ticks)
  100. a.set_yticks(a.get_xticks()) # make log scale y ticks the same as x ticks
  101. a.minorticks_off()
  102. a.set_aspect('equal')
  103. a.spines['left'].set_position(('outward', 4))
  104. a.spines['bottom'].set_position(('outward', 4))
  105. # scatter plot feedback modulation index for run vs. sit trials,
  106. # for movies and gratings, for all measures:
  107. figsize = DEFAULTFIGURESIZE
  108. linmin, linmax, linstep = -1, 1, 1
  109. ticks = np.arange(linmin, linmax+linstep, linstep)
  110. for measure in modmeasuresnoblank:
  111. axislabel = measure2axislabel[measure]
  112. if axislabel.islower():
  113. axislabel = axislabel.capitalize()
  114. for stimtype in STIMTYPES:
  115. FMI = stimtype2FMI[stimtype]
  116. mseustrs = {'mvi':mvimseustrs, 'grt':grtmseustrs}[stimtype]
  117. mseu2exmpli = {'mvi':mvimseu2exmpli, 'grt':grtmseu2exmpli}[stimtype]
  118. stimtypelabel = stimtype2axislabel[stimtype]
  119. runfmis, sitfmis, exmplis, exmplmseustrs, normlis = [], [], [], [], []
  120. keptmseui = 0 # manually init and increment instead of using enumerate()
  121. for mseustr in mseustrs:
  122. runfmi = FMI[measure][mseustr, 'run']
  123. sitfmi = FMI[measure][mseustr, 'sit']
  124. if np.isnan(runfmi) or np.isnan(sitfmi): # missing one or both mod values
  125. continue
  126. runfmis.append(runfmi)
  127. sitfmis.append(sitfmi)
  128. fig6.loc[mseustr, stimtype, 'runfmi'][measure] = runfmi
  129. fig6.loc[mseustr, stimtype, 'sitfmi'][measure] = sitfmi
  130. if mseustr in mseu2exmpli:
  131. exmplis.append(keptmseui)
  132. exmplmseustrs.append(mseustr)
  133. else:
  134. normlis.append(keptmseui)
  135. keptmseui += 1 # manually increment
  136. runfmis = np.asarray(runfmis)
  137. sitfmis = np.asarray(sitfmis)
  138. if len(runfmis) == 0 or len(sitfmis) == 0:
  139. continue # nothing to plot
  140. f, a = plt.subplots(figsize=figsize)
  141. wintitle('runsit FMI %s %s' % (measure, stimtypelabel))
  142. # plot x=0 and y=0 lines:
  143. a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
  144. a.axvline(x=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
  145. # plot y=x line:
  146. xyline = [linmin, linmax], [linmin, linmax]
  147. a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-1)
  148. # plot normal (non-example) points:
  149. a.scatter(sitfmis[normlis], runfmis[normlis], clip_on=False,
  150. marker='.', c='None', edgecolor='black', s=DEFSZ)
  151. # plot example points, one at a time:
  152. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  153. marker = exmpli2mrk[mseu2exmpli[mseustr]]
  154. c = exmpli2clr[mseu2exmpli[mseustr]]
  155. sz = exmpli2sz[mseu2exmpli[mseustr]]
  156. lw = exmpli2lw[mseu2exmpli[mseustr]]
  157. a.scatter(sitfmis[exmpli], runfmis[exmpli], clip_on=False,
  158. marker=marker, c=c, s=sz, lw=lw)
  159. # get fname for relevant figure panels:
  160. if stimtype == 'mvi':
  161. relev_measures = ['meanrate', 'meanburstratio', 'spars', 'rel']
  162. elif stimtype == 'grt':
  163. relev_measures = ['meanrate', 'meanburstratio']
  164. if measure in relev_measures:
  165. # if measure is relevant, get parameters for model from csv file:
  166. if stimtype == 'mvi':
  167. if measure == 'meanrate':
  168. fname = os.path.join('stats', 'figure_6b1_coefs.csv')
  169. elif measure == 'meanburstratio':
  170. fname = os.path.join('stats', 'figure_6b2_coefs.csv')
  171. elif measure == 'spars':
  172. fname = os.path.join('stats', 'figure_6b3_coefs.csv')
  173. elif measure == 'rel':
  174. fname = os.path.join('stats', 'figure_6b4_coefs.csv')
  175. elif stimtype == 'grt':
  176. if measure == 'meanrate':
  177. fname = os.path.join('stats', 'figure_6_S1b1_coefs.csv')
  178. elif measure == 'meanburstratio':
  179. fname = os.path.join('stats', 'figure_6_S1b2_coefs.csv')
  180. try:
  181. df = pd.read_csv(fname)
  182. foundregression = True
  183. except FileNotFoundError:
  184. print('Missing file: %s' % fname)
  185. foundregression = False
  186. if foundregression:
  187. df = pd.read_csv(fname)
  188. mm = df['slope'][0]
  189. b = df['intercept'][0]
  190. x = np.array([np.nanmin(sitfmis), np.nanmax(sitfmis)])
  191. y = mm * x + b
  192. a.plot(x, y, '-', color='red') # plot linregress fit
  193. #a.set_title('%s FMI' % axislabel) # typeface too big by default, decreases panel size
  194. a.set_xlabel('Sit')
  195. a.set_ylabel('Run')
  196. a.set_xlim(linmin, linmax)
  197. a.set_ylim(linmin, linmax)
  198. a.set_xticks(ticks)
  199. a.set_yticks(a.get_xticks()) # make log scale y ticks the same as x ticks
  200. a.minorticks_off()
  201. a.set_aspect('equal')
  202. a.spines['left'].set_position(('outward', 4))
  203. a.spines['bottom'].set_position(('outward', 4))
  204. # scatter plot feedback modulation index vs. run modulation index,
  205. # for movies and gratings, for all measures:
  206. figsize = DEFAULTFIGURESIZE
  207. linmin, linmax, linstep = -1, 1, 1
  208. ticks = np.arange(linmin, linmax+linstep, linstep)
  209. for measure in modmeasuresnoblank:
  210. axislabel = measure2axislabel[measure]
  211. if axislabel.islower():
  212. axislabel = axislabel.capitalize()
  213. for stimtype in STIMTYPES:
  214. FMI = stimtype2FMI[stimtype]
  215. RMI = stimtype2RMI[stimtype]
  216. mseustrs = {'mvi':mvimseustrs, 'grt':grtmseustrs}[stimtype]
  217. mseu2exmpli = {'mvi':mvimseu2exmpli, 'grt':grtmseu2exmpli}[stimtype]
  218. stimtypelabel = stimtype2axislabel[stimtype]
  219. fmis, rmis, exmplis, exmplmseustrs, normlis = [], [], [], [], []
  220. keptmseui = 0 # manually init and increment instead of using enumerate()
  221. for mseustr in mseustrs:
  222. fmi = FMI[measure][mseustr, 'none'] # ignore run condition for FMI
  223. rmi = RMI[measure][mseustr, False] # use only feedback condition for RMI
  224. if np.isnan(fmi) or np.isnan(rmi): # missing one or both mod values
  225. continue
  226. fmis.append(fmi)
  227. rmis.append(rmi)
  228. fig6.loc[mseustr, stimtype, 'fmi'][measure] = fmi
  229. fig6.loc[mseustr, stimtype, 'rmi'][measure] = rmi
  230. if mseustr in mseu2exmpli:
  231. exmplis.append(keptmseui)
  232. exmplmseustrs.append(mseustr)
  233. else:
  234. normlis.append(keptmseui)
  235. keptmseui += 1 # manually increment
  236. fmis = np.asarray(fmis)
  237. rmis = np.asarray(rmis)
  238. if len(fmis) == 0 or len(rmis) == 0:
  239. continue # nothing to plot
  240. f, a = plt.subplots(figsize=figsize)
  241. wintitle('FMI vs RMI %s %s none False' % (measure, stimtypelabel))
  242. # plot x=0 and y=0 lines:
  243. a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
  244. a.axvline(x=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
  245. # plot y=x line:
  246. xyline = [linmin, linmax], [linmin, linmax]
  247. a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-1)
  248. # plot normal (non-example) points:
  249. a.scatter(rmis[normlis], fmis[normlis], clip_on=False,
  250. marker='.', c='None', edgecolor='black', s=DEFSZ)
  251. # plot example points, one at a time:
  252. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  253. marker = exmpli2mrk[mseu2exmpli[mseustr]]
  254. c = exmpli2clr[mseu2exmpli[mseustr]]
  255. sz = exmpli2sz[mseu2exmpli[mseustr]]
  256. lw = exmpli2lw[mseu2exmpli[mseustr]]
  257. a.scatter(rmis[exmpli], fmis[exmpli], clip_on=False,
  258. marker=marker, c=c, s=sz, lw=lw)
  259. # get fname for relevant figure panels:
  260. if stimtype == 'mvi':
  261. relev_measures = ['meanrate', 'meanburstratio', 'spars', 'rel']
  262. elif stimtype == 'grt':
  263. relev_measures = ['meanrate', 'meanburstratio']
  264. if measure in relev_measures:
  265. # if measure is relevant, get parameters for model from csv file:
  266. if stimtype == 'mvi':
  267. if measure == 'meanrate':
  268. fname = os.path.join('stats', 'figure_6c1_coefs.csv')
  269. elif measure == 'meanburstratio':
  270. fname = os.path.join('stats', 'figure_6c2_coefs.csv')
  271. elif measure == 'spars':
  272. fname = os.path.join('stats', 'figure_6c3_coefs.csv')
  273. elif measure == 'rel':
  274. fname = os.path.join('stats', 'figure_6c4_coefs.csv')
  275. elif stimtype == 'grt':
  276. if measure == 'meanrate':
  277. fname = os.path.join('stats', 'figure_6_S1c1_coefs.csv')
  278. elif measure == 'meanburstratio':
  279. fname = os.path.join('stats', 'figure_6_S1c2_coefs.csv')
  280. try:
  281. df = pd.read_csv(fname)
  282. foundregression = True
  283. except FileNotFoundError:
  284. print('Missing file: %s' % fname)
  285. foundregression = False
  286. if foundregression:
  287. df = pd.read_csv(fname)
  288. mm = df['slope'][0]
  289. b = df['intercept'][0]
  290. x = np.array([np.nanmin(rmis), np.nanmax(rmis)])
  291. y = mm * x + b
  292. a.plot(x, y, '-', color='red') # plot linregress fit
  293. #a.set_title('%s' % axislabel) # typeface too big by default, decreases panel size
  294. a.set_xlabel('RMI')
  295. a.set_ylabel('FMI')
  296. a.set_xlim(linmin, linmax)
  297. a.set_ylim(linmin, linmax)
  298. a.set_xticks(ticks)
  299. a.set_yticks(a.get_xticks())
  300. a.minorticks_off()
  301. a.set_aspect('equal')
  302. a.spines['left'].set_position(('outward', 4))
  303. a.spines['bottom'].set_position(('outward', 4))