# fig4.py (7.2 KB, 143 lines)

  1. """Figure 4 and some 4S1 plots, use run -i fig4.py"""
  2. mvimeasures = ['meanrate', 'meanrate02', 'meanrate35', 'blankmeanrate',
  3. 'meanburstratio', 'blankmeanburstratio']
  4. grtmeasures = ['meanrate', 'meanrate', 'meanrate', 'blankmeanrate',
  5. 'meanburstratio', 'blankmeanburstratio']
  6. mi = pd.MultiIndex.from_product([mvigrtmsustrs, STIMTYPES, [False, 'prestim', 'cond']],
  7. names=['msu', 'stimtype', 'blank'])
  8. fig4 = pd.DataFrame(index=mi, columns=['meanrate', 'meanburstratio'])
  9. # strip plot movie vs grating FMI for applicable measures:
  10. np.random.seed(0) # to get identical horizontal jitter in strip plots on every run
  11. figsize = DEFAULTFIGURESIZE
  12. linmin, linmax, linstep = -1, 1, 1
  13. ticks = np.arange(linmin, linmax+linstep, linstep)
  14. for mvimeasure, grtmeasure in zip(mvimeasures, grtmeasures):
  15. measure = mvimeasure
  16. isblank = mvimeasure.startswith('blank')
  17. if isblank:
  18. measure = mvimeasure.split('blank')[1]
  19. axislabel = measure2axislabel[mvimeasure]
  20. axislabel = short2longaxislabel.get(axislabel, axislabel)
  21. if axislabel[0].islower():
  22. Axislabel = axislabel.capitalize()
  23. else:
  24. Axislabel = axislabel
  25. for st8 in ['none']:#ALLST8S:
  26. mvimaxfmis, grtmaxfmis, grtmaxfmi2s = [], [], []
  27. exmplis, exmplmsustrs, normlis = [], [], []
  28. # collect all movie FMIs:
  29. # collect paired movie and grating max FMIs:
  30. keptmsui = 0
  31. for msustr in mvigrtmsustrs:
  32. mvimaxfmi = maxFMI[mvimeasure][msustr, st8, 'mvi']
  33. grtmaxfmi = maxFMI[grtmeasure][msustr, st8, 'grt']
  34. if pd.isna(mvimaxfmi) or pd.isna(grtmaxfmi): # missing one or both maxFMI values
  35. continue
  36. if isblank: # add a second grtmaxfmi measure for blank cond
  37. blankname = 'blankcond' + measure
  38. grtmaxfmi2 = maxFMI[blankname][msustr, st8, 'grt']
  39. if pd.isna(grtmaxfmi2):
  40. continue
  41. else:
  42. grtmaxfmi2s.append(grtmaxfmi2)
  43. mvimaxfmis.append(mvimaxfmi)
  44. grtmaxfmis.append(grtmaxfmi)
  45. if measure in fig4.columns:
  46. assert mvimeasure == grtmeasure
  47. if isblank: # save both kinds of grt blank measures
  48. fig4.loc[msustr, 'mvi', 'prestim'][measure] = mvimaxfmi # save
  49. fig4.loc[msustr, 'grt', 'prestim'][measure] = grtmaxfmi # save
  50. fig4.loc[msustr, 'grt', 'cond'][measure] = grtmaxfmi2 # save
  51. else:
  52. fig4.loc[msustr, 'mvi', False][measure] = mvimaxfmi # save
  53. fig4.loc[msustr, 'grt', False][measure] = grtmaxfmi # save
  54. if msustr in msu2exmpli:
  55. exmplis.append(keptmsui)
  56. exmplmsustrs.append(msustr)
  57. else:
  58. normlis.append(keptmsui)
  59. keptmsui += 1 # manually increment
  60. mvimaxfmis = np.asarray(mvimaxfmis)
  61. grtmaxfmis = np.asarray(grtmaxfmis)
  62. grtmaxfmi2s = np.asarray(grtmaxfmi2s)
  63. assert len(mvimaxfmis) == len(grtmaxfmis)
  64. npairs = len(mvimaxfmis)
  65. ## strip plot paired movie and grating max FMIs:
  66. if mvimeasure == 'meanrate': # 2 column wide strip plot
  67. f, a = plt.subplots(figsize=(figsize[0]*1.35, figsize[1]*1.5))
  68. elif mvimeasure == 'blankmeanrate': # 3 column wide strip plot
  69. f, a = plt.subplots(figsize=(figsize[0]*1.9, figsize[1]*1.5))
  70. elif mvimeasure == 'blankmeanburstratio': # 3 column normal width strip plot
  71. f, a = plt.subplots(figsize=(figsize[0]*1.4, figsize[1]))
  72. else:
  73. f, a = plt.subplots(figsize=figsize)
  74. wintitle('maxFMI %s movie grating stripplot' % mvimeasure)
  75. # plot y=0 line:
  76. a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
  77. datad = {'Movie':mvimaxfmis, 'Grating':grtmaxfmis}
  78. if isblank:
  79. datad = {'Movie':mvimaxfmis, 'Grating':grtmaxfmis,
  80. 'GratingCond':grtmaxfmi2s}
  81. data = pd.DataFrame.from_dict(datad, orient='index').transpose()
  82. sns.stripplot(ax=a, data=data, clip_on=False, marker='.',
  83. color='None', edgecolor='black', size=np.sqrt(50))
  84. # get fname of appropriate LMM .cvs file:
  85. if mvimeasure == 'meanrate' and grtmeasure == 'meanrate':
  86. fname = os.path.join('stats', 'figure_4a_pred_means.csv')
  87. elif mvimeasure == 'blankmeanrate' and grtmeasure == 'blankmeanrate':
  88. fname = os.path.join('stats', 'figure_4b_pred_means.csv')
  89. elif mvimeasure == 'meanburstratio' and grtmeasure == 'meanburstratio':
  90. fname = os.path.join('stats', 'figure_4_S1e_pred_means.csv')
  91. elif mvimeasure == 'blankmeanburstratio' and grtmeasure == 'blankmeanburstratio':
  92. fname = os.path.join('stats', 'figure_4_S1f_pred_means.csv')
  93. else:
  94. print("WARNING: No LMM stats for (mvimeasure=%s, grtmeasure=%s)"
  95. % (mvimeasure, grtmeasure))
  96. fname = None
  97. if fname:
  98. try:
  99. df = pd.read_csv(fname)
  100. foundregression = True
  101. except FileNotFoundError:
  102. print('Missing file: %s' % fname)
  103. foundregression = False
  104. if foundregression:
  105. # fetch LMM means from .csv:
  106. meanmvimaxfmi = df['mvi'][0]
  107. meangrtmaxfmi = df['grt'][0]
  108. # plot mean with short horizontal lines:
  109. a.plot([-0.25, 0.25], [meanmvimaxfmi, meanmvimaxfmi], '-', lw=2, c='red',
  110. zorder=np.inf)
  111. a.plot([0.75, 1.25], [meangrtmaxfmi, meangrtmaxfmi], '-', lw=2, c='red',
  112. zorder=np.inf)
  113. if isblank:
  114. meangrtmaxfmi2 = df['grt0c'][0]
  115. a.plot([1.75, 2.25], [meangrtmaxfmi2, meangrtmaxfmi2], '-', lw=2, c='red',
  116. zorder=np.inf)
  117. a.set_ylabel('%s FMI' % Axislabel)
  118. a.set_ylim(-1, 1)
  119. a.set_yticks([-1, -0.5, 0, 0.5, 1])
  120. a.spines['bottom'].set_position(('outward', 5))
  121. a.spines['bottom'].set_visible(False)
  122. a.tick_params(bottom=False)
  123. # connect the dots:
  124. if isblank:
  125. x = np.array([[0]*npairs, [1]*npairs, [2]*npairs])
  126. y = np.array([mvimaxfmis, grtmaxfmis, grtmaxfmi2s])
  127. else:
  128. x = np.array([[0]*npairs, [1]*npairs])
  129. y = np.array([mvimaxfmis, grtmaxfmis])
  130. signs = np.sign(y)
  131. nonsignchangeis = signs[0] == signs[1]
  132. posslopeis = (signs[0] < 0) & (signs[1] > 0)
  133. negslopeis = (signs[0] > 0) & (signs[1] < 0)
  134. #signchangeis = signs[0] != signs[1]
  135. a.plot(x[:, nonsignchangeis], y[:, nonsignchangeis], '-', c='k', alpha=0.2, lw=1)
  136. a.plot(x[:, posslopeis], y[:, posslopeis], '--', c='k', alpha=1.0, lw=1)
  137. a.plot(x[:, negslopeis], y[:, negslopeis], '-', c='k', alpha=1.0, lw=1)
  138. #a.plot(x[:, signchangeis], y[:, signchangeis], '-', c='k', alpha=1.0, lw=1)
  139. # due to jitter, dots don't perfectly connect. Can get actual data using:
  140. #mvixy, grtxy = a.collections[0].get_offsets(), a.collections[1].get_offsets()