123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
"""Figure 4 and some 4S1 plots, use run -i fig4.py

This script is meant to be run inside an already-populated interactive namespace. It
uses, without defining them here: pd, np, plt, sns, os, maxFMI, mvigrtmsustrs, STIMTYPES,
DEFAULTFIGURESIZE, measure2axislabel, short2longaxislabel, msu2exmpli and wintitle.

For each (movie measure, grating measure) pair it collects the paired max FMI values per
msustr, stores the ones that map onto a column of the `fig4` DataFrame, and draws a strip
plot with lines connecting each pair.
"""

# movie measures, and the grating measure each one is paired with (same position):
mvimeasures = ['meanrate', 'meanrate02', 'meanrate35', 'blankmeanrate',
               'meanburstratio', 'blankmeanburstratio']
grtmeasures = ['meanrate', 'meanrate', 'meanrate', 'blankmeanrate',
               'meanburstratio', 'blankmeanburstratio']

# results table, one row per (msu, stimtype, blank) combination:
mi = pd.MultiIndex.from_product([mvigrtmsustrs, STIMTYPES, [False, 'prestim', 'cond']],
                                names=['msu', 'stimtype', 'blank'])
fig4 = pd.DataFrame(index=mi, columns=['meanrate', 'meanburstratio'])

# LMM predicted-means .csv file for each (mvimeasure, grtmeasure) pair that has one:
lmmfnames = {
    ('meanrate', 'meanrate'): 'figure_4a_pred_means.csv',
    ('blankmeanrate', 'blankmeanrate'): 'figure_4b_pred_means.csv',
    ('meanburstratio', 'meanburstratio'): 'figure_4_S1e_pred_means.csv',
    ('blankmeanburstratio', 'blankmeanburstratio'): 'figure_4_S1f_pred_means.csv',
}

# strip plot movie vs grating FMI for applicable measures:
np.random.seed(0) # to get identical horizontal jitter in strip plots on every run
figsize = DEFAULTFIGURESIZE
linmin, linmax, linstep = -1, 1, 1
ticks = np.arange(linmin, linmax+linstep, linstep)
for mvimeasure, grtmeasure in zip(mvimeasures, grtmeasures):
    # `measure` is the movie measure name with any leading 'blank' removed:
    measure = mvimeasure
    isblank = mvimeasure.startswith('blank')
    if isblank:
        measure = mvimeasure.split('blank')[1]
    axislabel = measure2axislabel[mvimeasure]
    axislabel = short2longaxislabel.get(axislabel, axislabel)
    if axislabel[0].islower():
        Axislabel = axislabel.capitalize()
    else:
        Axislabel = axislabel
    for st8 in ['none']:#ALLST8S:
        mvimaxfmis, grtmaxfmis, grtmaxfmi2s = [], [], []
        exmplis, exmplmsustrs, normlis = [], [], []
        # collect paired movie and grating max FMIs:
        keptmsui = 0 # index into the kept (non-skipped) pairs
        for msustr in mvigrtmsustrs:
            mvimaxfmi = maxFMI[mvimeasure][msustr, st8, 'mvi']
            grtmaxfmi = maxFMI[grtmeasure][msustr, st8, 'grt']
            if pd.isna(mvimaxfmi) or pd.isna(grtmaxfmi): # missing one or both maxFMI values
                continue
            if isblank: # add a second grtmaxfmi measure for blank cond
                blankname = 'blankcond' + measure
                grtmaxfmi2 = maxFMI[blankname][msustr, st8, 'grt']
                if pd.isna(grtmaxfmi2):
                    continue
                grtmaxfmi2s.append(grtmaxfmi2)
            mvimaxfmis.append(mvimaxfmi)
            grtmaxfmis.append(grtmaxfmi)
            if measure in fig4.columns:
                assert mvimeasure == grtmeasure
                # Assign with a single .loc[row, column] call. The previous chained form
                # .loc[row][column] = value wrote into an intermediate object, which is
                # not guaranteed to propagate back into fig4.
                if isblank: # save both kinds of grt blank measures
                    fig4.loc[(msustr, 'mvi', 'prestim'), measure] = mvimaxfmi # save
                    fig4.loc[(msustr, 'grt', 'prestim'), measure] = grtmaxfmi # save
                    fig4.loc[(msustr, 'grt', 'cond'), measure] = grtmaxfmi2 # save
                else:
                    fig4.loc[(msustr, 'mvi', False), measure] = mvimaxfmi # save
                    fig4.loc[(msustr, 'grt', False), measure] = grtmaxfmi # save
            if msustr in msu2exmpli:
                exmplis.append(keptmsui)
                exmplmsustrs.append(msustr)
            else:
                normlis.append(keptmsui)
            keptmsui += 1 # manually increment
        mvimaxfmis = np.asarray(mvimaxfmis)
        grtmaxfmis = np.asarray(grtmaxfmis)
        grtmaxfmi2s = np.asarray(grtmaxfmi2s)
        assert len(mvimaxfmis) == len(grtmaxfmis)
        npairs = len(mvimaxfmis)

        ## strip plot paired movie and grating max FMIs:
        if mvimeasure == 'meanrate': # 2 column wide strip plot
            f, a = plt.subplots(figsize=(figsize[0]*1.35, figsize[1]*1.5))
        elif mvimeasure == 'blankmeanrate': # 3 column wide strip plot
            f, a = plt.subplots(figsize=(figsize[0]*1.9, figsize[1]*1.5))
        elif mvimeasure == 'blankmeanburstratio': # 3 column normal width strip plot
            f, a = plt.subplots(figsize=(figsize[0]*1.4, figsize[1]))
        else:
            f, a = plt.subplots(figsize=figsize)
        wintitle('maxFMI %s movie grating stripplot' % mvimeasure)
        # plot y=0 line:
        a.axhline(y=0, ls='--', marker='', color='lightgray', zorder=-np.inf)
        datad = {'Movie':mvimaxfmis, 'Grating':grtmaxfmis}
        if isblank: # third column for the blank cond grating values
            datad['GratingCond'] = grtmaxfmi2s
        data = pd.DataFrame.from_dict(datad, orient='index').transpose()
        sns.stripplot(ax=a, data=data, clip_on=False, marker='.',
                      color='None', edgecolor='black', size=np.sqrt(50))

        # get fname of appropriate LMM .csv file:
        basefname = lmmfnames.get((mvimeasure, grtmeasure))
        if basefname is None:
            print("WARNING: No LMM stats for (mvimeasure=%s, grtmeasure=%s)"
                  % (mvimeasure, grtmeasure))
            fname = None
        else:
            fname = os.path.join('stats', basefname)
        # Reset on every pass, so that a measure without LMM stats neither raises a
        # NameError nor reuses the flag (and means) left over from the previous measure:
        foundregression = False
        if fname:
            try:
                df = pd.read_csv(fname)
                foundregression = True
            except FileNotFoundError:
                print('Missing file: %s' % fname)
        if foundregression:
            # fetch LMM means from .csv:
            meanmvimaxfmi = df['mvi'][0]
            meangrtmaxfmi = df['grt'][0]
            # plot mean with short horizontal lines:
            a.plot([-0.25, 0.25], [meanmvimaxfmi, meanmvimaxfmi], '-', lw=2, c='red',
                   zorder=np.inf)
            a.plot([0.75, 1.25], [meangrtmaxfmi, meangrtmaxfmi], '-', lw=2, c='red',
                   zorder=np.inf)
            if isblank:
                meangrtmaxfmi2 = df['grt0c'][0]
                a.plot([1.75, 2.25], [meangrtmaxfmi2, meangrtmaxfmi2], '-', lw=2, c='red',
                       zorder=np.inf)
        a.set_ylabel('%s FMI' % Axislabel)
        a.set_ylim(-1, 1)
        a.set_yticks([-1, -0.5, 0, 0.5, 1])
        a.spines['bottom'].set_position(('outward', 5))
        a.spines['bottom'].set_visible(False)
        a.tick_params(bottom=False)

        # connect the dots:
        if isblank:
            x = np.array([[0]*npairs, [1]*npairs, [2]*npairs])
            y = np.array([mvimaxfmis, grtmaxfmis, grtmaxfmi2s])
        else:
            x = np.array([[0]*npairs, [1]*npairs])
            y = np.array([mvimaxfmis, grtmaxfmis])
        # Classify each pair by the signs of its Movie (row 0) and Grating (row 1) values
        # only, even when a third GratingCond row is present. A pair whose signs differ
        # because one of the two values is exactly 0 matches none of the three masks and
        # so gets no connecting line.
        signs = np.sign(y)
        nonsignchangeis = signs[0] == signs[1]
        posslopeis = (signs[0] < 0) & (signs[1] > 0)
        negslopeis = (signs[0] > 0) & (signs[1] < 0)
        #signchangeis = signs[0] != signs[1]
        a.plot(x[:, nonsignchangeis], y[:, nonsignchangeis], '-', c='k', alpha=0.2, lw=1)
        a.plot(x[:, posslopeis], y[:, posslopeis], '--', c='k', alpha=1.0, lw=1)
        a.plot(x[:, negslopeis], y[:, negslopeis], '-', c='k', alpha=1.0, lw=1)
        #a.plot(x[:, signchangeis], y[:, signchangeis], '-', c='k', alpha=1.0, lw=1)
        # due to jitter, dots don't perfectly connect. Can get actual data using:
        #mvixy, grtxy = a.collections[0].get_offsets(), a.collections[1].get_offsets()
|