{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# The distinct iterative prior updating process in ASD and TD individuals \n", "\n", "Research has shown that individuals with autism spectrum disorder (ASD) display unique patterns in predictive processing, yet what causes these atypical behaviors remains controversial. Both ASD individuals and their typically developing (TD) counterparts participated in a task in which they reproduced time durations over two sessions, one characterized by a highly volatile sequence and the other by a predictable sequence. Both sessions involved the same time durations, but the sequences differed in volatility. A visual stimulus (a disk) appeared for a given duration, and participants were asked to reproduce the duration by pressing a key.\n", "\n", "This repository contains the data and analysis scripts for this study. The code and data are organized as follows:\n", "\n", "## 1. Folder Structure\n", "\n", "1. `/experiments`: Experimental code and instructions\n", "\n", "This sub-folder contains the MATLAB code and instructions for the duration reproduction task. The duration sequences for the reproduction task are stored in the sub-folder `/experiments/seqs`. These sequences were used for matched participants. \n", "\n", "2. `/data`: Raw data files\n", "\n", "- `rawdata.csv`: Raw reproduction trials of all participants\n", "- `parinfo.csv`: Participant information, including the measured scores (AQ, EQ, SQ, IQ, etc.) \n", "\n", "3. `/figures`: Output figures \n", "4. Analysis scripts\n", "- `analysis-notebook.ipynb`: Jupyter notebook for data analysis\n", "- `kmodelY.py`: Python script for the Kalman filter two-state model\n", "\n", "## 2. Data Analysis\n", "\n", "### 2.1 Import raw data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Nothing done.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Unnamed: 0 Duration Volatility trlNo blkNo dur1 pdur production \\\n", "0 1 0.400 Low Vola. 1 1 0.400 0.400 0.409 \n", "1 2 0.500 Low Vola. 2 1 0.500 0.506 0.507 \n", "2 3 0.400 Low Vola. 3 1 0.400 0.400 0.406 \n", "3 4 0.400 Low Vola. 4 1 0.400 0.400 0.397 \n", "4 5 0.500 Low Vola. 5 1 0.500 0.506 0.501 \n", "\n", " vrep Reproduction sub group itd rep_err sequence preDuration \\\n", "0 0.377 0.374 A31 ASD NaN -0.026 31 NaN \n", "1 0.353 0.342 A31 ASD 0.100 -0.158 31 0.400 \n", "2 0.342 0.326 A31 ASD -0.100 -0.074 31 0.506 \n", "3 0.412 0.406 A31 ASD 0.000 0.006 31 0.400 \n", "4 0.412 0.398 A31 ASD 0.100 -0.102 31 0.400 \n", "\n", " blk1st Order preErr outlier \n", "0 1 LV First NaN True \n", "1 0 LV First -0.026 False \n", "2 0 LV First -0.158 False \n", "3 0 LV First -0.074 False \n", "4 0 LV First 0.006 False " ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# load data analysis packages\n", "%reset\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import os\n", "import statsmodels.api as sm\n", "import scipy.stats as stats\n", "from scipy.optimize import least_squares\n", "# pingouin ANOVA\n", "import pingouin as pg\n", "# keep output precision to 3 decimal places\n", "pd.options.display.float_format = '{:,.3f}'.format\n", "\n", "# read data from ./data/rawdata.csv\n", "rawdata = pd.read_csv('./data/rawdata.csv')\n", "# change rawdata.group to upper case\n", "rawdata['group'] = rawdata['group'].str.upper()\n", "# add a new column preErr, indicating the previous trial's error\n", "rawdata['preErr'] = rawdata.groupby(['sub', 'Volatility'])['rep_err'].shift(1)\n", "# mark the outliers that exceed [Duration/3, Duration * 3] or preErr is nan\n", "rawdata['outlier'] = (rawdata['Reproduction'] < rawdata['Duration']/3) | (rawdata['Reproduction'] > rawdata['Duration']*3) | (rawdata['preErr'].isna())\n", "\n", "# show the first 5 rows of 
rawdata\n", "rawdata.head()\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The key columns in `rawdata.csv` are:\n", "1. Duration: the duration of the visual stimulus\n", "2. Reproduction: the reproduced duration\n", "3. rep_err: the reproduction error (Reproduction - Duration)\n", "4. preDuration: the duration of the previous visual stimulus\n", "5. sub: subject ID\n", "6. group: ASD or TD\n", "7. Volatility: Low Vola. or High Vola." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The proportion of trials flagged as outliers (reproductions outside [Duration/3, Duration * 3], or trials without a valid previous-trial error) was low in both groups: 2.3% in the ASD group and 1.0% in the TD group. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "group\n", "ASD 0.023\n", "TD 0.010\n", "Name: outlier, dtype: float64" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# show the percentage of outliers in each group (averaged across participants)\n", "rawdata.groupby(['sub', 'group'])['outlier'].mean().reset_index().groupby('group')['outlier'].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Duration sequences\n", "\n", "The experiment was structured into two sessions, one with high and one with low volatility. Each session included 500 trials drawn from the same distribution of durations; the sessions differed only in the sequential volatility of those durations. The figure below illustrates a typical sequence of durations for a single participant.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# set the seaborn style and color palette\n", "sns.set(style='ticks', context='paper', rc={'figure.figsize': (10, 5)})\n", "sns.set_palette(\"Dark2\")\n", "\n", "# select a subset of rawdata (sub == 'ara27') for illustration\n", "subdata = rawdata.query('sub == \"ara27\"')\n", "# plot Duration as a function of trlNo, using different color for Volatility\n", "sns.lineplot(x='trlNo', y='Duration', hue='Volatility', data=subdata)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3 Central tendency and autocorrelation\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# calculate the mean and standard deviation of the reproduction error for each sub, group, Volatility, and Duration\n", "mdata = rawdata.query('outlier == False').groupby(\n", "    ['sub', 'group', 'Volatility', 'Duration','sequence','Order']).agg(\n", "    {'rep_err': ['mean', 'std']}).reset_index()\n", "mdata.columns = ['sub', 'group', 'Volatility', 'Duration', 'sequence','Order','rep_err', 'rep_err_std']\n", "\n", "# visualize the mean reproduction error as a function of Duration\n", "ax = sns.lmplot(x='Duration', y='rep_err', hue='Volatility', col='group', \n", "                scatter_kws = {'alpha':0.3, 's':5}, legend = False, height = 3.5, \n", "                data=mdata, hue_order = ['Low Vola.', 'High Vola.'])\n", "\n", "# add dashed line 0 to each subplot\n", "for ax1 in ax.axes.flat:\n", "    ax1.axhline(0, ls='--')\n", "plt.ylim(-0.5, 1.)\n", "# save the figure to ./figures/rep_err_vs_Duration_b.png\n", "plt.savefig('./figures/rep_err_vs_Duration_b.png', dpi=300)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In both groups, the high-volatility session showed a stronger central tendency than the low-volatility session. Next, we fit a linear regression of the reproduction error on Duration for each participant and session, and compute the Durbin-Watson statistic on the residuals to assess autocorrelation in the two environments." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " sub Volatility sequence group level_4 r2 intercept slope cti \\\n", "0 A31 High Vola. 31 ASD 0 0.389 0.284 -0.451 0.451 \n", "1 A31 Low Vola. 31 ASD 0 0.459 0.279 -0.506 0.506 \n", "2 A32 High Vola. 32 ASD 0 0.606 0.550 -0.904 0.904 \n", "3 A32 Low Vola. 32 ASD 0 0.020 0.222 -0.227 0.227 \n", "4 A33 High Vola. 33 ASD 0 0.182 0.180 -0.239 0.239 \n", "\n", " ar_dw aic \n", "0 1.710 9.836 \n", "1 1.017 -4.688 \n", "2 1.405 -86.745 \n", "3 1.701 166.866 \n", "4 1.783 -113.065 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from statsmodels.stats.stattools import durbin_watson\n", "\n", "# regress rep_err on x (default: Duration) for one participant and session\n", "def reg_func(data, x='Duration'):\n", "    # remove NaNs and infinite values\n", "    data = data.replace([np.inf, -np.inf], np.nan).dropna(subset=[x, 'rep_err'])\n", "    \n", "    if data.empty:\n", "        # no valid trials: return a row of missing values\n", "        return pd.DataFrame({'r2': [None], 'intercept': [None], 'slope': [None],\n", "                             'cti': [None], 'ar_dw': [None], 'aic': [None]})\n", "\n", "    reg = sm.OLS(data['rep_err'], sm.add_constant(data[x])).fit()\n", "    # Durbin-Watson statistic of the residuals (2 indicates no autocorrelation)\n", "    dw = durbin_watson(reg.resid)\n", "    # return goodness of fit, intercept, slope, central tendency index (-slope), DW, and AIC\n", "    return pd.DataFrame({'r2': reg.rsquared, 'intercept': reg.params.iloc[0], 'slope':reg.params.iloc[1],\n", "                         'cti': -reg.params.iloc[1], 'ar_dw': dw, 'aic':reg.aic}, index=[0])\n", "# apply the regression function to each sub, Volatility, and group\n", "df_coef = rawdata.query('outlier == False').groupby(['sub', 'Volatility', 'sequence', 'group']).apply(reg_func).reset_index()\n", "# show the first 5 rows of df_coef\n", "df_coef.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.003 1 62 0.003 0.071 0.790 0.001 NaN\n", "1 Volatility 1.611 1 62 1.611 80.105 0.000 0.564 1.000\n", "2 Interaction 0.005 1 62 0.005 0.228 0.634 0.004 NaN" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mixed ANOVA on the central tendency effect cti\n", "# use pingouin to perform mixed ANOVA\n", "aov = pg.mixed_anova(data=df_coef, dv='cti', within='Volatility', between='group', subject='sub')\n", "# show the ANOVA table\n", "aov" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.000 1 62 0.000 0.009 0.926 0.000 NaN\n", "1 Volatility 1.370 1 62 1.370 84.976 0.000 0.578 1.000\n", "2 Interaction 0.010 1 62 0.010 0.650 0.423 0.010 NaN" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mixed ANOVA on the intercept\n", "aov = pg.mixed_anova(data=df_coef, \n", "                     dv='intercept', within='Volatility', between='group', subject='sub')\n", "# show the ANOVA table\n", "aov" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Neither the central tendency index nor the intercept differed between groups. However, the Durbin-Watson (DW) statistic of the regression residuals differed significantly between the two groups, even with the outlier participants included (main effect of group in the ANOVA below: F(1, 62) = 4.50, p = .038). This suggests that we need to consider the inter-trial updating of the prior information." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.314 1 62 0.314 4.498 0.038 0.068 NaN\n", "1 Volatility 0.059 1 62 0.059 2.142 0.148 0.033 1.000\n", "2 Interaction 0.060 1 62 0.060 2.183 0.145 0.034 NaN" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mixed ANOVA on the DW statistic (ar_dw)\n", "pg.mixed_anova(data=df_coef, \n", "               dv='ar_dw', within='Volatility', between = 'group', subject='sub')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# horizontal boxplot of the DW index (ar_dw) as a function of group and Volatility\n", "sns.boxplot(data=df_coef, y='group', x='ar_dw', hue='Volatility', orient = 'h', hue_order=['Low Vola.', 'High Vola.'])\n", "# add a dashed line at 2, which indicates zero autocorrelation\n", "plt.axvline(2, ls='--', c='k')\n", "# remove box around the plot\n", "sns.despine()\n", "plt.xlabel('DW index')\n", "\n", "# save the figure to vector file ./figures/ar_dw.pdf\n", "plt.savefig('./figures/ar_dw.pdf', dpi=300)\n", "# save the figure to ./figures/ar_dw.png\n", "plt.savefig('./figures/ar_dw.png', dpi=300)\n", "plt.show()\n", "\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " T dof alternative p-val CI95% cohen-d BF10 power\n", "T-test -7.913 63 two-sided 0.000 [1.72, 1.83] 0.989 1.893e+08 1.000" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test whether ar_dw differs from 2 (no autocorrelation) in each group,\n", "# using a two-sided one-sample t-test of the mean of ar_dw against 2.\n", "# ASD group\n", "pg.ttest(df_coef.query(\"group == 'ASD'\")['ar_dw'], 2)\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " T dof alternative p-val CI95% cohen-d BF10 \\\n", "T-test -11.877 63 two-sided 0.000 [1.62, 1.73] 1.485 5.748e+14 \n", "\n", " power \n", "T-test 1.000 " ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# for TD group\n", "pg.ttest(df_coef.query(\"group == 'TD'\")['ar_dw'], 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.4 Outliers \n", "The central tendency index (CTI) did not differ between groups. This may partly reflect three outlying participants in the ASD group, whose CTI exceeded 0.9 in the high-volatility session. " ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " sub Volatility sequence group level_4 r2 intercept slope cti \\\n", "2 A32 High Vola. 32 ASD 0 0.606 0.550 -0.904 0.904 \n", "36 aril02 High Vola. 2 ASD 0 0.647 1.051 -0.997 0.997 \n", "50 arm13 High Vola. 13 ASD 0 0.617 0.669 -0.907 0.907 \n", "\n", " ar_dw aic \n", "2 1.405 -86.745 \n", "36 1.313 -131.330 \n", "50 1.650 -133.904 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_coef.query(\"cti > 0.9\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize these three outliers. " ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "outliers_regress = df_coef.query(\"cti > 0.9\").sequence\n", "# from rawdata, select the valid trials of the ASD participants whose sequence is in outliers_regress,\n", "# and plot Reproduction vs. Duration separately for each sub and Volatility\n", "moutliers = rawdata.query('sequence in @outliers_regress and group == \"ASD\" and outlier == False')\n", "sns.set_palette('Dark2')\n", "g = sns.lmplot(x='Duration', y='Reproduction', hue='Volatility', col='sub',\n", "               scatter_kws = {'alpha':0.3, 's':5}, legend = False, height = 3.5,\n", "               data=moutliers)\n", "g.set(ylabel = 'Reproduction (s)')\n", "# remove subplots titles\n", "g.set_titles('')\n", "# add diagonal dashed line to each subplot\n", "for ax in g.axes.flat:\n", "    ax.plot(ax.get_xlim(), ax.get_xlim(), ls='--', c='k')\n", "\n", "# put legend in the upper right corner\n", "plt.legend(loc='upper right')\n", "# save the figure to ./figures/outliers.png\n", "plt.savefig('./figures/outliers.png', dpi=300)\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.047 1 56 0.047 1.734 0.193 0.030 NaN\n", "1 Volatility 1.204 1 56 1.204 81.366 0.000 0.592 1.000\n", "2 Interaction 0.049 1 56 0.049 3.324 0.074 0.056 NaN" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's see what happens if we remove the outliers from the regression analysis\n", "# (rows are excluded by sequence, so participants sharing those sequences are dropped as well).\n", "# ANOVA for cti with group as between-subject factor and Volatility as within-subject factor\n", "pg.mixed_anova(data=df_coef.query(\"sequence not in @outliers_regress\"), \n", "               dv='cti', within='Volatility', between = 'group', subject='sub')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The group effect on the CTI remains non-significant (F(1, 56) = 1.73, p = .193), suggesting that both groups acquired the prior (volatility) information similarly. Next, let's check the residual autocorrelation.\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.559 1 56 0.559 8.357 0.005 0.130 NaN\n", "1 Volatility 0.087 1 56 0.087 3.338 0.073 0.056 1.000\n", "2 Interaction 0.081 1 56 0.081 3.103 0.084 0.052 NaN" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pg.mixed_anova(data=df_coef.query(\"sequence not in @outliers_regress\"), \n", "               dv='ar_dw', within='Volatility', between = 'group', subject='sub')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After excluding the outliers, the group difference in the DW statistic became more pronounced, F(1, 56) = 8.36, p = .005. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.5 Visualize the behavioral results \n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# now we combine plots together for the paper\n", "sns.set(style='ticks', context='paper')\n", "fig, axes = plt.subplots(nrows = 2, ncols = 2, figsize = (7, 6))\n", "sns.set_palette('Dark2')\n", "# for each group, on a separate subplot, plot rep_err as a function of Duration, using a different color for each Volatility with regplot \n", "for i, group in enumerate(df_coef.group.unique()):\n", "    for j, vol in enumerate(df_coef.Volatility.unique()):\n", "        sns.regplot(data=mdata.query(\"group == @group and Volatility == @vol\"), \n", "                    x='Duration', y='rep_err', ax = axes[0,i], color = sns.color_palette('Dark2')[1-j], label = vol, \n", "                    ci = None, scatter_kws = {'alpha':0.3, 's':5})\n", "    if i == 1:\n", "        axes[0,i].legend(loc='lower left')\n", "    else:\n", "        # remove legend\n", "        axes[0,i].legend().remove()\n", "    # add dashed line 0 to each subplot\n", "    axes[0,i].axhline(0, ls='--', c='k')\n", "    # set y axis limit\n", "    axes[0,i].set_ylim(-0.6, 0.6)\n", "    # set x label to 'Duration (s)', y label to 'Reproduction error (s)'\n", "    axes[0,i].set(xlabel = 'Duration (s)', ylabel = 'Reproduction error (s)')\n", "    # set title to group\n", "    axes[0,i].set_title(group)\n", "# add horizontal boxplot for cti from df_coef on the second row first column\n", "sns.boxplot(data=df_coef, \n", "            y='group', x='cti', hue='Volatility', ax = axes[1,0], width = 0.5, \n", "            orient = 'h', hue_order=['Low Vola.', 'High Vola.'])\n", "axes[1,0].set(xlabel = 'Central Tendency Index')\n", "axes[1,0].legend().remove()\n", "# add a vertical line at 0.0\n", "axes[1,0].axvline(0, ls='--', c='k')\n", "# remove y axis label\n", "axes[1,0].set(ylabel = '')\n", "# add horizontal boxplot for ar_dw from df_coef on the second row second column\n", "sns.boxplot(data=df_coef, y='group', x='ar_dw', hue='Volatility', width = 0.5, \n", "            ax = axes[1,1], orient = 'h', hue_order=['Low Vola.', 'High Vola.'])\n", "# place the legend in the lower right corner\n", "axes[1,1].legend(loc='lower right')\n", "axes[1,1].set(xlabel = 'Autocorrelation (DW) Index')\n", "# add a vertical line at 2.0\n", "axes[1,1].axvline(2, ls='--', c='k')\n", "# x axis from 0.9 to 3\n", "axes[1,1].set_xlim(0.9, 3)\n", "axes[1,1].set(ylabel = '')\n", "# remove box around the plot\n", "sns.despine()\n", "# add labels to subplots a, b, c, d\n", "for i, label in enumerate(['a', 'b', 'c', 'd']):\n", "    axes[i//2, i%2].text(-0.1, 1.1, label, transform=axes[i//2, i%2].transAxes, \n", "                         fontsize=16, fontweight='bold', va='top', ha='right')\n", "\n", "# Adjust layout and show plot\n", "plt.tight_layout()\n", "\n", "# save the figure to ./figures/rep_err_vs_Duration.png\n", "plt.savefig('./figures/rep_err_vs_Duration.png', dpi=300, bbox_inches='tight', facecolor='white')\n", "\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.6 General bias (over- or under-reproduction) \n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.005 1 62 0.005 0.385 0.537 0.006 NaN\n", "1 Volatility 0.010 1 62 0.010 5.573 0.021 0.082 1.000\n", "2 Interaction 0.001 1 62 0.001 0.698 0.407 0.011 NaN" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# average the reproduction errors for general biases\n", "mrep_err = rawdata.query(\"outlier == False\").groupby(['sub','group', 'Volatility'])['rep_err'].mean().reset_index()\n", "# pingouin mixed ANOVA on mrep_err\n", "aov = pg.mixed_anova(data=mrep_err, dv='rep_err', within='Volatility', between='group', subject='sub')\n", "# show the ANOVA table\n", "aov" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " group mean std\n", "0 ASD 0.024 0.099\n", "1 TD 0.036 0.062" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# calculate the mean and standard deviation of the reproduction error for each group\n", "mrep_err.groupby(['group'])['rep_err'].agg(['mean', 'std']).reset_index()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.7 Two-state iterative model\n", "\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# included here so that the notebook is self-contained (the following code is from kmodelY.py)\n", "from scipy.optimize import least_squares\n", "from statsmodels.stats.stattools import durbin_watson\n", "import numpy as np\n", "\n", "def fitKmodel(subdata, nolog=None, pfit=None, p0=None):\n", "    \"\"\"\n", "    Parameters:\n", "    - subdata : subject data (needs the columns Duration and Reproduction)\n", "    - nolog : accepted for compatibility but not used here; kmodelY is always called with nolog=1 (no log transformation)\n", "    - pfit : A list of logical values to indicate which parameters to fit (default is [True, True, True])\n", "    - p0 : Initial parameters (must always have length 3)\n", "    \n", "    Returns:\n", "    - px : array with the fitted parameters, followed by the two steady-state Kalman gains, the AIC, and the Durbin-Watson statistic of the residuals\n", "\n", "    By S.Glasauer 2019 (matlab), translated to Python by Strongway\n", "    (AIC and DW added)\n", "\n", "    \"\"\"\n", "    \n", "    # Handle default arguments\n", "    if p0 is None:\n", "        p0 = [1., 1, 0]\n", "    if pfit is None:\n", "        pfit = [True, True, True]\n", "    if nolog is None:\n", "        nolog = 0\n", "\n", "    # Convert pfit to logical and filter p0\n", "    pfit = np.array(pfit, dtype=bool)\n", "    p0 = np.array(p0)[pfit]\n", "\n", "    # Lower bounds (lb) for the optimization\n", "    lb = np.array([0, 0, -np.inf])[pfit]\n", "    \n", "\n", "    # extract Duration and Reproduction from subdata\n", "    x = subdata['Duration'].values\n", "    # copy, so that the NaN replacement below does not modify subdata\n", "    y = subdata['Reproduction'].to_numpy(dtype=float, copy=True)\n", "    # replace extreme y (y > 3 * x or y < x/3) with nan\n", "    y[(y > 3 * x) | (y < x/3)] = np.nan\n", "    # combine x,y as 2d array\n", "    stimrep = 
np.vstack([x,y]).T\n", " # Perform the optimization using least_squares (equivalent to lsqnonlin in MATLAB)\n", " result = least_squares(kmodelY, p0, args = (stimrep, 1),\n", " bounds=(lb, np.inf), method='trf')\n", " \n", " # calculate kalmann filter parameters\n", " q11 = result.x[0]\n", " q22 = result.x[1]\n", " r = 1\n", " # calculate residual sum of squares\n", " rss = np.sum(result.fun**2)\n", " dw = durbin_watson(result.fun)\n", " # number of parameters\n", " k = len(result.x)\n", " # number of observations\n", " n = len(stimrep)\n", " # calculate the log-likelihood\n", " ll = -n/2*(np.log(2*np.pi) + np.log(rss/n) + 1)\n", " # calculate the Akaike information criterion (AIC)\n", " aic = 2*k - 2*ll\n", " # steady state solution\n", " p22 = (q22+np.sqrt(q22*q22+4*(q11+r)*q22))/2\n", " K = np.array([p22 + q11, p22])/(p22+q11+r)\n", " # return the optimized parameters, steady state solution, and AIC\n", " return np.append(np.append(result.x, K), [aic, dw]) # Optimized parameters\n", "\n", "\n", "def kmodelY(par, stimrep, nolog=1, pfit=[1, 1, 1]):\n", " \"\"\"\n", " Function to perform Kalman filter-based estimation.\n", " \n", " Parameters:\n", " - par: Model parameters (if pfit = [1,1,1], then par = [q1/r, q2/r, cost-related parameter (0 for median)])\n", " - stimrep: Stimulus representation\n", " - nolog: Flag to decide if logarithm transformation is needed\n", " - pfit: Parameter fitting list (note: len(par) = sum(pfit))\n", " \n", " Returns:\n", " - sres: Stimulus residuals\n", " - xest: Estimated state\n", " - pest: Estimate error covariance\n", " - resp: Response\n", " - perr: Prediction error\n", "\n", " S.Glasauer 2019/2023, translated to Python by Strongway\n", " \"\"\"\n", "\n", " # Convert pfit to a boolean array\n", " pfit = np.array(pfit, dtype=bool)\n", "\n", " # Adjust pfit based on the size of par\n", " if len(par) < 3:\n", " pfit[len(par):] = False\n", " \n", " # Adjust stimrep's shape for further processing\n", " if stimrep.shape[1] == 1:\n", " 
stimrep = np.tile(stimrep, (1, 2))\n", " # the first column is the stimulus, the second column is the response, \n", " # and add the third column to indicate the start of a new sequence\n", " #if stimrep.shape[1] == 2:\n", " # stimrep = np.hstack((stimrep, np.zeros((stimrep.shape[0], 1))))\n", " # stimrep[0, 2] = 1\n", " \n", " # Initialize pars and overwrite with provided parameters based on pfit\n", " pars = np.array([0.0, 0.0, 0.0])\n", " pars[pfit] = par\n", " par = pars\n", "\n", " # Constants for the model\n", " a = 10.0\n", " off = 1.\n", " r = 1.\n", " q1 = par[0] * r\n", " q2 = par[1] * r\n", "\n", " # Define matrices Q, P, H, and F for the Kalman filter of two-state model\n", " # details see Glasauer & Shi, 2022, Sci. Rep., https://doi.org/10.1038/s41598-022-14939-8\n", " Q = np.array([[q1, 0], [0, q2]])\n", " P = np.array([[r, 0], [0, r]])\n", " H = np.array([[1., 0]])\n", " F = np.array([[0, 1.], [0, 1.]])\n", "\n", " # Apply logarithm transformation if nolog is false\n", " if nolog:\n", " z = stimrep[:, 0]\n", " else: # log transformation\n", " z = np.log(a * stimrep[:, 0] + off)\n", "\n", " # Initialize state vector x\n", " x = np.array([[z[0]], [z[0]]])\n", "\n", " # Initialize matrices for storing results\n", " xest = np.zeros((len(z), 2))\n", " pest = np.zeros((len(z), 2))\n", " perr = np.zeros(len(z))\n", "\n", " # Kalman filter estimation loop\n", " for i in range(len(z)):\n", " \n", " x = F@x\n", " P = F@P@F.T + Q\n", " K = P@H.T/(H@P@H.T + r)\n", " # H@x is a 1x1 array; extract the scalar explicitly (avoids the NumPy >= 1.25 DeprecationWarning)\n", " perr[i] = z[i] - (H@x).item()\n", " x = x + K*perr[i]\n", " P = (np.eye(2) - K@H)@P\n", "\n", " pest[i, :] = np.diag(P)\n", " xest[i, :] = x.reshape(-1)\n", "\n", " # Adjust for third parameter, if present\n", " if len(par) == 3:\n", " sh = par[2]\n", " else:\n", " sh = 0\n", "\n", " # Compute response, adjusting for logarithm if needed\n", " if nolog:\n", " resp = xest[:, 0] + sh\n", " else: # log transformation\n", " resp = (np.exp(xest[:, 0] + sh) - off)/a \n", " \n", "\n", " # Calculate stimulus 
residuals\n", " sres = stimrep[:, 1] - resp\n", "\n", " # Remove NaNs from sres\n", " sres = sres[np.isfinite(sres)]\n", "\n", " return sres\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/dc/hksrz0yj5bb8n7f4_yptkcmw0000gn/T/ipykernel_74932/1675682148.py:146: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " perr[i] = z[i] - H@x\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " sub Volatility sequence group \\\n", "0 A31 High Vola. 31 ASD \n", "1 A31 Low Vola. 31 ASD \n", "2 A32 High Vola. 32 ASD \n", "3 A32 Low Vola. 32 ASD \n", "4 A33 High Vola. 33 ASD \n", "\n", " par p1 p2 tau K1 \\\n", "0 [0.9185124640252414, 0.05243938464301662, -0.1... 0.919 0.052 -0.170 0.558 \n", "1 [0.7563162466196017, 2.049556565459247e-05, -0... 0.756 0.000 -0.290 0.433 \n", "2 [0.09834201160694395, 1.1492562435182896e-18, ... 0.098 0.000 -0.330 0.090 \n", "3 [3.342491279429646e-10, 0.4012756813450186, -0... 0.000 0.401 -0.001 0.464 \n", "4 [1.0307068777913497, 1.3110276998645982, -0.06... 1.031 1.311 -0.061 0.775 \n", "\n", " K2 AIC_2S DW_2S Kd kseq \n", "0 0.152 -1.177 1.728 -0.406 0.067 \n", "1 0.003 -42.879 1.181 -0.429 0.002 \n", "2 0.000 -132.762 1.382 -0.090 0.000 \n", "3 0.464 158.794 1.669 -0.000 0.249 \n", "4 0.543 -134.099 1.838 -0.232 0.122 " ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fit two-state model\n", "# from kmodelY import * # (when using kmodelY.py)\n", "\n", "# subset the data by subject, volatility, sequence, and group, then estimate the two-state model parameters\n", "df_kmodel = rawdata.groupby(\n", " ['sub', 'Volatility', 'sequence', 'group']).apply(\n", " fitKmodel).reset_index()\n", "df_kmodel.columns = ['sub', 'Volatility', 'sequence', 'group', 'par']\n", "# split the parameters to columns\n", "df_kmodel[['p1','p2','tau','K1', 'K2','AIC_2S','DW_2S']] = pd.DataFrame(df_kmodel['par'].tolist(), index=df_kmodel.index)\n", "# add a column for the difference between the two gains (K2 - K1)\n", "df_kmodel['Kd'] = df_kmodel['K2'] - df_kmodel['K1']\n", "# analytical sequential dependence for randomized sequences\n", "df_kmodel['kseq'] = df_kmodel['K2']*(1-df_kmodel['K1'])\n", "df_kmodel.head()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " sub Volatility sequence group \\\n", "0 A31 High Vola. 31 ASD \n", "1 A31 Low Vola. 31 ASD \n", "2 A32 High Vola. 32 ASD \n", "3 A32 Low Vola. 32 ASD \n", "4 A33 High Vola. 33 ASD \n", "\n", " par p1 p2 tau K1 \\\n", "0 [0.9185124640252414, 0.05243938464301662, -0.1... 0.919 0.052 -0.170 0.558 \n", "1 [0.7563162466196017, 2.049556565459247e-05, -0... 0.756 0.000 -0.290 0.433 \n", "2 [0.09834201160694395, 1.1492562435182896e-18, ... 0.098 0.000 -0.330 0.090 \n", "3 [3.342491279429646e-10, 0.4012756813450186, -0... 0.000 0.401 -0.001 0.464 \n", "4 [1.0307068777913497, 1.3110276998645982, -0.06... 1.031 1.311 -0.061 0.775 \n", "\n", " K2 ... Kd kseq level_4 r2 intercept slope cti ar_dw \\\n", "0 0.152 ... -0.406 0.067 0 0.389 0.284 -0.451 0.451 1.710 \n", "1 0.003 ... -0.429 0.002 0 0.459 0.279 -0.506 0.506 1.017 \n", "2 0.000 ... -0.090 0.000 0 0.606 0.550 -0.904 0.904 1.405 \n", "3 0.464 ... -0.000 0.249 0 0.020 0.222 -0.227 0.227 1.701 \n", "4 0.543 ... -0.232 0.122 0 0.182 0.180 -0.239 0.239 1.783 \n", "\n", " aic dAIC \n", "0 9.836 -11.013 \n", "1 -4.688 -38.192 \n", "2 -86.745 -46.017 \n", "3 166.866 -8.072 \n", "4 -113.065 -21.033 \n", "\n", "[5 rows x 22 columns]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# combine df_kmodel and df_coef by sub, Volatility, sequence, and group\n", "kpars = df_kmodel.merge(df_coef, on=['sub', 'Volatility', 'sequence', 'group'])\n", "# AIC difference: two-state model minus the model in df_coef (negative values favor the two-state model)\n", "kpars['dAIC'] = kpars['AIC_2S'] - kpars['aic']\n", "kpars.head()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-10.75263200035032" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# mean benefit of the two-state model in terms of AIC\n", "kpars['dAIC'].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On average, the AIC of the two-state model was about 10.75 lower than that of the model summarized in `df_coef` (mean dAIC = -10.75), suggesting that the two-state model provides a better fit to the data." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(ncols = 2, figsize = (7, 3.5))\n", "# barplot for DW_2S as a function of group, Volatility\n", "sns.barplot(data=kpars, x='group', y='DW_2S', hue='Volatility', capsize = .1,\n", " zorder = 5, errorbar=('ci', 68), ax = axes[0])\n", "# change y axis to 1.5 to 2\n", "axes[0].set_ylim(1.5, 2)\n", "# add dashed line 2 to indicate the 0 autocorrelation\n", "axes[0].axhline(2, ls='--', c='k')\n", "axes[0].axhline(1.7, ls='--', c='k')\n", "# second subplot for the ar_dw as a function of group, Volatility\n", "sns.barplot(data=kpars, x='group', y='ar_dw', hue='Volatility', capsize = .1,\n", " zorder = 5, errorbar=('ci', 68), ax = axes[1])\n", "# change y axis to 1.5 to 2\n", "axes[1].set_ylim(1.5, 2)\n", "# add dashed line 2 to indicate the 0 autocorrelation\n", "axes[1].axhline(2, ls='--', c='k')\n", "axes[1].axhline(1.7, ls='--', c='k')" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " K1 K2 kseq tau \n", " mean sem mean sem mean sem mean sem\n", "group Volatility \n", "ASD High Vola. 0.681 0.043 0.446 0.047 0.100 0.013 0.013 0.019\n", " Low Vola. 0.642 0.036 0.344 0.057 0.089 0.016 0.041 0.017\n", "TD High Vola. 0.682 0.026 0.396 0.039 0.113 0.012 0.031 0.010\n", " Low Vola. 0.568 0.035 0.198 0.038 0.088 0.016 0.057 0.010" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "# average K1, K2, and tau for each Volatility, group, with standard error\n", "df_kmodel.groupby(['group', 'Volatility']).agg({'K1': ['mean', 'sem'], 'K2': ['mean', 'sem'], 'kseq': ['mean', 'sem'], 'tau': ['mean', 'sem']})\n", "\n", "# save kmodel_v to csv file ./data/kmodel_v.csv\n", "#kmodel_v.to_csv('./data/kmodel_v.csv', index=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Statistics for those parameters" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.043 1 62 0.043 0.685 0.411 0.011 NaN\n", "1 Volatility 0.188 1 62 0.188 10.613 0.002 0.146 1.000\n", "2 Interaction 0.046 1 62 0.046 2.620 0.111 0.041 NaN" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# statistics for K1\n", "pg.mixed_anova(data=df_kmodel, dv='K1', within='Volatility', between = 'group', subject='sub')" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.307 1 62 0.307 4.232 0.044 0.064 NaN\n", "1 Volatility 0.725 1 62 0.725 11.754 0.001 0.159 1.000\n", "2 Interaction 0.075 1 62 0.075 1.209 0.276 0.019 NaN" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# statistics for K2\n", "pg.mixed_anova(data=df_kmodel, dv='K2', within='Volatility', between = 'group', subject='sub')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "K2 shows a significant main effect of group (F(1, 62) = 4.232, p = .044), with a larger mean K2 in the ASD group than in the TD group in both volatility conditions." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.009 1 62 0.009 0.794 0.376 0.013 NaN\n", "1 Volatility 0.024 1 62 0.024 11.835 0.001 0.160 1.000\n", "2 Interaction 0.000 1 62 0.000 0.013 0.911 0.000 NaN" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# statistics for tau\n", "pg.mixed_anova(data=df_kmodel, dv='tau', within='Volatility', between = 'group', subject='sub')" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "# exclude outliers\n", "kmodel_v = df_kmodel.query(\"sequence not in @outliers_regress\")\n", "#kmodel_v = df_kmodel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize the parameters" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# visualize the mean and standard error of K1, K2, and tau as a function of group, Volatility\n", "sns.set(style='ticks', context='paper')\n", "sns.set_palette('Dark2')\n", "\n", "# three subplots side by side\n", "fig, axes = plt.subplots(ncols = 3, figsize = (9, 3.5))\n", "# barplot for K1 as a function of group, Volatility\n", "sns.barplot(data=df_kmodel, x='group', y='K1', hue='Volatility', capsize = .1,\n", " zorder = 5, errorbar=('se'), ax = axes[0], hue_order=['Low Vola.', 'High Vola.'])\n", "# y axis from 0.5 to 1\n", "axes[0].set_ylim(0.1, .8)\n", "# legend top right\n", "axes[0].legend(loc='upper right')\n", "axes[0].legend().remove()\n", "# remove x axis label\n", "axes[0].set(xlabel = '')\n", "# barplot for K2 as a function of group, Volatility\n", "sns.barplot(data=df_kmodel, x='group', y='K2', hue='Volatility', capsize = .1,\n", " zorder = 5, errorbar=('se'), ax = axes[1], hue_order=['Low Vola.', 'High Vola.'])\n", "# legend off\n", "#axes[1].legend().remove()\n", "axes[1].set_ylim(0.1, .8)\n", "axes[1].set(xlabel = '')\n", "\n", "# barplot for tau as a function of group, Volatility\n", "# change the tau from seconds to milliseconds\n", "df_kmodel['Tau'] = df_kmodel['tau']*1000\n", "sns.barplot(data=df_kmodel, x='group', y='Tau', hue='Volatility', capsize = .1,\n", " zorder = 5, errorbar=('se'), ax = axes[2], hue_order=['Low Vola.', 'High Vola.'])\n", "# y label to 'tau (ms)'\n", "axes[2].set_ylabel('General Bias (ms)')\n", "# legend off\n", "axes[2].legend().remove()\n", "axes[2].set(xlabel = '')\n", "axes[2].set_ylim(0, 70)\n", "\n", "# remove box around the plot\n", "sns.despine()\n", "# add labels to subplots a, b, c, d\n", "for i, label in enumerate(['a', 'b', 'c']):\n", " axes[i].text(-0.1, 1.1, label, transform=axes[i].transAxes, \n", " fontsize=16, fontweight='bold', va='top', ha='right')\n", "# tight layout\n", "plt.tight_layout()\n", "\n", "# save the figure to vector 
file ./figures/kmodel.png\n", "plt.savefig('./figures/kmodel.png', dpi=300)\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " K1 K2 kseq tau \n", " mean sem mean sem mean sem mean sem\n", "group Volatility \n", "ASD High Vola. 0.681 0.043 0.446 0.047 0.100 0.013 0.013 0.019\n", " Low Vola. 0.642 0.036 0.344 0.057 0.089 0.016 0.041 0.017\n", "TD High Vola. 0.682 0.026 0.396 0.039 0.113 0.012 0.031 0.010\n", " Low Vola. 0.568 0.035 0.198 0.038 0.088 0.016 0.057 0.010" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# mean and standard error of K1, K2, kseq, and tau by group and Volatility (all participants, outliers included)\n", "df_kmodel.groupby(['group', 'Volatility']).agg({'K1': ['mean', 'sem'], 'K2': ['mean', 'sem'], 'kseq': ['mean', 'sem'], 'tau': ['mean', 'sem']})\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2.8 Would a large K2 lead to slow updating?\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given that K2 was significantly larger in the ASD group, we would expect slower updating of the prior in this group. To test this, we split the trials of each session into a first half and a second half. Because halving reduces the number of trials available for each fit, we excluded the outliers from this analysis." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "# let's check the first half of the trials from each session from the rawdata\n", "# select the first half of the trials (trlNo 1-125 for the first session, 251-375 for the second) \n", "firsthalf_raw = rawdata.query('trlNo <= 125 or (trlNo >= 251 and trlNo <= 375)')\n", "# and second half of the trials (trlNo 126-250 for the first session, 376-500 for the second)\n", "# note: the upper bound 250 is needed; 'trlNo >= 126' alone would also select trials 251-375\n", "secondhalf_raw = rawdata.query('(trlNo >= 126 and trlNo <= 250) or trlNo >= 376')\n", "# calculate coefficient for the first half of the trials\n", "df_coef1 = firsthalf_raw.query('outlier == False and sequence not in @outliers_regress').groupby(['sub', 'Volatility', 'sequence', 'group']).apply(reg_func).reset_index()\n", "# calculate coefficient for the second half of the trials\n", "df_coef2 = secondhalf_raw.query('outlier == False and sequence not in @outliers_regress').groupby(['sub', 'Volatility', 'sequence', 'group']).apply(reg_func).reset_index()\n" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n", "sns.set(style='ticks', context='paper')\n", "sns.set_palette('Dark2')\n", "# two subplots side by side\n", "fig, axes = plt.subplots(ncols = 2, figsize = (7, 3.5))\n", "# barplot for cti as a function of group, Volatility, and half\n", "sns.barplot(data=df_coef1, x='group', y='cti', hue='Volatility', capsize = .1,\n", " zorder = 5, errorbar=('ci', 68), ax = axes[0], hue_order=['Low Vola.', 'High Vola.'])\n", "# legend off\n", "axes[0].legend().remove()\n", "# remove x axis label\n", "axes[0].set(xlabel = '')\n", "axes[0].set(ylabel = 'Central Tendency Index (CTI)')\n", "# add title to the subplot\n", "axes[0].set_title('The first half of trials')\n", "# second subplot for the ar_dw as a function of group, Volatility, and half\n", "sns.barplot(data=df_coef1, x='group', y='ar_dw', hue='Volatility', capsize = .1,\n", " zorder = 5, errorbar=('ci', 68), ax = axes[1], hue_order=['Low Vola.', 'High Vola.'])\n", "axes[1].set(xlabel = '')\n", "axes[1].set(ylabel = 'DW index')\n", "# y axis from 1.5 to 2.5\n", "axes[1].set_ylim(1.5, 2.1)\n", "# add dashed line 2 to indicate the 0 autocorrelation\n", "axes[1].axhline(2, ls='--', c='k')\n", "axes[1].set_title('The first half of trials')\n", "# add title to the subplot\n", "#axes[1].set_title('Changes in CTI')\n", "# remove box around the plot\n", "sns.despine()\n", "# add labels to subplots a, b, c, d\n", "for i, label in enumerate(['a', 'b']):\n", " axes[i].text(-0.1, 1.1, label, transform=axes[i].transAxes, \n", " fontsize=16, fontweight='bold', va='top', ha='right')\n", "# tight layout\n", "plt.tight_layout()\n", "# save the figure to vector file ./figures/cti_half.png\n", "plt.savefig('./figures/cti_half.png', dpi=300)\n" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.222 1 56 0.222 6.236 0.015 0.100 NaN\n", "1 Volatility 0.480 1 56 0.480 24.738 0.000 0.306 1.000\n", "2 Interaction 0.013 1 56 0.013 0.682 0.412 0.012 NaN" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pg.mixed_anova(data=df_coef1, \n", " dv='cti', within='Volatility', between = 'group', subject='sub')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above statistics show that the central tendency differed significantly between groups in the first half of the trials (F(1, 56) = 6.236, p = .015). This suggests that the ASD group updated the prior more slowly, so that the prior might not have been fully updated within the first half." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.544 1 56 0.544 7.202 0.010 0.114 NaN\n", "1 Volatility 0.036 1 56 0.036 0.694 0.408 0.012 1.000\n", "2 Interaction 0.027 1 56 0.027 0.522 0.473 0.009 NaN" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pg.mixed_anova(data=df_coef1, \n", " dv='ar_dw', within='Volatility', between = 'group', subject='sub')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The autocorrelation index (Durbin-Watson statistic, DW) for the first half of the trials also differed significantly between the two groups (F(1, 56) = 7.202, p = .010)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For comparison, the same analyses on the entire sessions:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.047 1 56 0.047 1.734 0.193 0.030 NaN\n", "1 Volatility 1.204 1 56 1.204 81.366 0.000 0.592 1.000\n", "2 Interaction 0.049 1 56 0.049 3.324 0.074 0.056 NaN" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pg.mixed_anova(data=df_coef.query(\"sequence not in @outliers_regress\"), \n", " dv='cti', within='Volatility', between = 'group', subject='sub')\n" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.559 1 56 0.559 8.357 0.005 0.130 NaN\n", "1 Volatility 0.087 1 56 0.087 3.338 0.073 0.056 1.000\n", "2 Interaction 0.081 1 56 0.081 3.103 0.084 0.052 NaN" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pg.mixed_anova(data=df_coef.query(\"sequence not in @outliers_regress\"), \n", " dv='ar_dw', within='Volatility', between = 'group', subject='sub')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fit the two-state model to the first half of the trials" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/dc/hksrz0yj5bb8n7f4_yptkcmw0000gn/T/ipykernel_74932/1675682148.py:146: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " perr[i] = z[i] - H@x\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " sub Volatility sequence group \\\n", "0 A31 High Vola. 31 ASD \n", "1 A31 Low Vola. 31 ASD \n", "2 A33 High Vola. 33 ASD \n", "3 A33 Low Vola. 33 ASD \n", "4 A34 High Vola. 34 ASD \n", "\n", " par p1 p2 tau \\\n", "0 [0.9714549830363363, 0.01693385718297349, -0.1... 0.971 0.017 -0.199 \n", "1 [2.0532000286870844e-22, 12.169718779060531, -... 0.000 12.170 -0.438 \n", "2 [1.204485814584274, 0.7122651732928083, -0.058... 1.204 0.712 -0.058 \n", "3 [5.735340853415004e-12, 0.4293106313521714, -0... 0.000 0.429 -0.028 \n", "4 [0.06397846678007811, 0.4849299971316463, 0.00... 0.064 0.485 0.005 \n", "\n", " K1 K2 AIC_2S DW_2S \n", "0 0.538 0.088 -7.870 1.933 \n", "1 0.929 0.929 66.842 0.850 \n", "2 0.741 0.429 -60.205 1.666 \n", "3 0.475 0.475 -117.713 2.117 \n", "4 0.516 0.485 -133.310 1.814 " ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# now fit two-state model to the first half of the trials\n", "# subset the data by subject, volatility, sequence, and group, then estimate the two-state model parameters\n", "df_kmodel1 = firsthalf_raw.query(\"sequence not in @outliers_regress\").groupby(\n", " ['sub', 'Volatility', 'sequence', 'group']).apply(\n", " fitKmodel).reset_index()\n", "df_kmodel1.columns = ['sub', 'Volatility', 'sequence', 'group', 'par']\n", "# split the parameters to columns\n", "df_kmodel1[['p1','p2','tau','K1', 'K2','AIC_2S','DW_2S']] = pd.DataFrame(df_kmodel1['par'].tolist(), \n", " index=df_kmodel1.index)\n", "df_kmodel1.head()" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
K1K2tau
meansemmeansemmeansem
groupVolatility
ASDHigh Vola.0.7500.0290.4660.0570.0280.016
Low Vola.0.6100.0460.3330.0580.0340.022
TDHigh Vola.0.6850.0260.3960.0480.0120.011
Low Vola.0.5540.0420.1780.0470.0520.013
\n", "
" ], "text/plain": [ " K1 K2 tau \n", " mean sem mean sem mean sem\n", "group Volatility \n", "ASD High Vola. 0.750 0.029 0.466 0.057 0.028 0.016\n", " Low Vola. 0.610 0.046 0.333 0.058 0.034 0.022\n", "TD High Vola. 0.685 0.026 0.396 0.048 0.012 0.011\n", " Low Vola. 0.554 0.042 0.178 0.047 0.052 0.013" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# show the average K1, K2 and tau for each group and Volatility\n", "df_kmodel1.groupby(['group', 'Volatility']).agg({'K1': ['mean', 'sem'], 'K2': ['mean', 'sem'], 'tau': ['mean', 'sem']})\n" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
K1K2tau
meansemmeansemmeansem
groupVolatility
ASDHigh Vola.0.6810.0430.4460.0470.0130.019
Low Vola.0.6420.0360.3440.0570.0410.017
TDHigh Vola.0.6820.0260.3960.0390.0310.010
Low Vola.0.5680.0350.1980.0380.0570.010
\n", "
" ], "text/plain": [ " K1 K2 tau \n", " mean sem mean sem mean sem\n", "group Volatility \n", "ASD High Vola. 0.681 0.043 0.446 0.047 0.013 0.019\n", " Low Vola. 0.642 0.036 0.344 0.057 0.041 0.017\n", "TD High Vola. 0.682 0.026 0.396 0.039 0.031 0.010\n", " Low Vola. 0.568 0.035 0.198 0.038 0.057 0.010" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# compared to df_kmodel\n", "df_kmodel.groupby(['group', 'Volatility']).agg({'K1': ['mean', 'sem'], 'K2': ['mean', 'sem'], 'tau': ['mean', 'sem']})" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dK1dK2dtau
meansemmeansemmeansem
groupVolatility
ASDHigh Vola.0.0060.010-0.0220.035-0.0040.004
Low Vola.-0.0580.034-0.0190.065-0.0120.010
TDHigh Vola.0.0190.0090.0070.040-0.0170.006
Low Vola.-0.0070.023-0.0150.047-0.0050.007
\n", "
" ], "text/plain": [ " dK1 dK2 dtau \n", " mean sem mean sem mean sem\n", "group Volatility \n", "ASD High Vola. 0.006 0.010 -0.022 0.035 -0.004 0.004\n", " Low Vola. -0.058 0.034 -0.019 0.065 -0.012 0.010\n", "TD High Vola. 0.019 0.009 0.007 0.040 -0.017 0.006\n", " Low Vola. -0.007 0.023 -0.015 0.047 -0.005 0.007" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# combine df_kmodel and df_kmodel1 by sub, Volatility, and group, for K1, K2, and tau, and add differences between two \n", "# K1, K2, and tau\n", "kmodel_v1 = df_kmodel1.merge(df_kmodel, on=['sub', 'Volatility', 'sequence', 'group'])\n", "kmodel_v1['dK1'] = kmodel_v1['K1_x'] - kmodel_v1['K1_y']\n", "kmodel_v1['dK2'] = kmodel_v1['K2_x'] - kmodel_v1['K2_y']\n", "kmodel_v1['dtau'] = kmodel_v1['tau_x'] - kmodel_v1['tau_y']\n", "# show the average dk1, dk2, and dtau for each group and Volatility\n", "kmodel_v1.groupby(['group', 'Volatility']).agg({'dK1': ['mean', 'sem'], 'dK2': ['mean', 'sem'], 'dtau': ['mean', 'sem']})" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SourceSSDF1DF2MSFp-uncnp2eps
0group0.0301560.0302.2250.1410.038NaN
1Volatility0.0581560.0584.3050.0430.0711.000
2Interaction0.0101560.0100.7720.3830.014NaN
\n", "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.030 1 56 0.030 2.225 0.141 0.038 NaN\n", "1 Volatility 0.058 1 56 0.058 4.305 0.043 0.071 1.000\n", "2 Interaction 0.010 1 56 0.010 0.772 0.383 0.014 NaN" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# anova test for dK1\n", "pg.mixed_anova(data=kmodel_v1, \n", " dv='dK1', within='Volatility', between = 'group', subject='sub')\n" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SourceSSDF1DF2MSFp-uncnp2eps
0group0.1051560.1051.9140.1720.033NaN
1Volatility0.5321560.53221.8870.0000.2811.000
2Interaction0.0011560.0010.0210.8850.000NaN
\n", "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.105 1 56 0.105 1.914 0.172 0.033 NaN\n", "1 Volatility 0.532 1 56 0.532 21.887 0.000 0.281 1.000\n", "2 Interaction 0.001 1 56 0.001 0.021 0.885 0.000 NaN" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pg.mixed_anova(data=df_kmodel1, \n", " dv='K1', within='Volatility', between = 'group', subject='sub')" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SourceSSDF1DF2MSFp-uncnp2eps
0group0.0081560.0080.1060.7450.002NaN
1Volatility0.0021560.0020.0390.8450.0011.000
2Interaction0.0041560.0040.0700.7920.001NaN
\n", "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.008 1 56 0.008 0.106 0.745 0.002 NaN\n", "1 Volatility 0.002 1 56 0.002 0.039 0.845 0.001 1.000\n", "2 Interaction 0.004 1 56 0.004 0.070 0.792 0.001 NaN" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# anova test for dK2\n", "pg.mixed_anova(data=kmodel_v1, \n", " dv='dK2', within='Volatility', between = 'group', subject='sub')" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SourceSSDF1DF2MSFp-uncnp2eps
0group0.3691560.3694.5140.0380.075NaN
1Volatility0.8931560.89311.1790.0010.1661.000
2Interaction0.0531560.0530.6610.4200.012NaN
\n", "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.369 1 56 0.369 4.514 0.038 0.075 NaN\n", "1 Volatility 0.893 1 56 0.893 11.179 0.001 0.166 1.000\n", "2 Interaction 0.053 1 56 0.053 0.661 0.420 0.012 NaN" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# anova test for dK1\n", "pg.mixed_anova(data=df_kmodel1, \n", " dv='K2', within='Volatility', between = 'group', subject='sub')" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SourceSSDF1DF2MSFp-uncnp2eps
0group0.0001560.0000.2170.6430.004NaN
1Volatility0.0001560.0000.0780.7810.0011.000
2Interaction0.0031560.0031.6260.2080.028NaN
\n", "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.000 1 56 0.000 0.217 0.643 0.004 NaN\n", "1 Volatility 0.000 1 56 0.000 0.078 0.781 0.001 1.000\n", "2 Interaction 0.003 1 56 0.003 1.626 0.208 0.028 NaN" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# anova test for dtau\n", "pg.mixed_anova(data=kmodel_v1, \n", " dv='dtau', within='Volatility', between = 'group', subject='sub')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2.9 General Linear Model with the current and previous trial durations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given that the low-volatility session used the random walk sequence, the current and previous durations were highly correlated. So we use the difference between the current and previous durations as a regressor. This won't affect the fully random sequence in the high-volatility session.\n" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/dc/hksrz0yj5bb8n7f4_yptkcmw0000gn/T/ipykernel_74932/4187532016.py:34: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n", " glm_results = pd.concat([glm_results, res], ignore_index=True)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
subVolatilitygroupInterceptcurDurdDurp_Interceptp_CTIp_dDurr2sequenceaic
0A31High Vola.ASD-0.167-0.393-0.0570.0000.0000.1220.39031.00015.538
1A31Low Vola.ASD-0.224-0.487-0.1300.0000.0000.4500.44731.000-4.166
4A33High Vola.ASD-0.056-0.131-0.1040.0000.0070.0020.22233.000-117.297
5A33Low Vola.ASD-0.032-0.061-0.3820.0010.0230.0040.06033.000-216.504
6A34High Vola.ASD-0.023-0.379-0.1380.0070.0000.0000.66434.000-284.612
\n", "
" ], "text/plain": [ " sub Volatility group Intercept curDur dDur p_Intercept p_CTI \\\n", "0 A31 High Vola. ASD -0.167 -0.393 -0.057 0.000 0.000 \n", "1 A31 Low Vola. ASD -0.224 -0.487 -0.130 0.000 0.000 \n", "4 A33 High Vola. ASD -0.056 -0.131 -0.104 0.000 0.007 \n", "5 A33 Low Vola. ASD -0.032 -0.061 -0.382 0.001 0.023 \n", "6 A34 High Vola. ASD -0.023 -0.379 -0.138 0.007 0.000 \n", "\n", " p_dDur r2 sequence aic \n", "0 0.122 0.390 31.000 15.538 \n", "1 0.450 0.447 31.000 -4.166 \n", "4 0.002 0.222 33.000 -117.297 \n", "5 0.004 0.060 33.000 -216.504 \n", "6 0.000 0.664 34.000 -284.612 " ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# import OLS from statsmodels\n", "from statsmodels.regression.linear_model import OLS\n", "# creating an autorergressive model with lag 1, store the results in a new dataframe\n", "glm_results = pd.DataFrame(columns = ['sub', 'Volatility', 'group', \n", " 'Intercept', 'curDur', 'dDur',\n", " 'p_Intercept', 'p_CTI', 'p_dDur','r2'])\n", "\n", "# drop nan values in preDuration, preErr from vdata\n", "vdata = rawdata.query(\"outlier == False\").dropna(subset=['preDuration', 'preErr'])\n", "# center Duration to 1\n", "vdata['cDuration'] = vdata['Duration'] - 1\n", "vdata['cpreDuration'] = vdata['preDuration'] - 1\n", "vdata['dDuration'] = vdata['Duration'] - vdata['preDuration']\n", "groups = vdata.groupby(['sub', 'Volatility', 'group', 'sequence'])\n", "# drop nan values in preDuration, preErr\n", "\n", "for name, group in groups:\n", " glm_mod = OLS(group['rep_err'], sm.add_constant(group[['cDuration', 'dDuration']])).fit()\n", " summary = glm_mod.summary2()\n", " r2 = glm_mod.rsquared\n", " aic = glm_mod.aic\n", " # estimate the prediction reproduced error\n", " res = pd.DataFrame({\n", " 'sub': name[0], 'Volatility': name[1], 'group': name[2], 'sequence': name[3],\n", " 'Intercept': glm_mod.params['const'], \n", " 'curDur': glm_mod.params['cDuration'], #as CTI = -beta\n", " 'dDur': 
glm_mod.params['dDuration'],\n", " 'p_Intercept': summary.tables[1]['P>|t|']['const'],\n", " 'p_CTI': summary.tables[1]['P>|t|']['cDuration'],\n", " 'p_dDur': summary.tables[1]['P>|t|']['dDuration'],\n", " 'r2': r2,\n", " 'aic': aic\n", " }, index=[0])\n", " glm_results = pd.concat([glm_results, res], ignore_index=True)\n", " \n", "# remove outliers\n", "glm_v = glm_results.query(\"sequence not in @outliers_regress\" )\n", "glm_v.head()" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
groupVolatilityInterceptcurDurdDur
meansemmeansemmeansem
0ASDHigh Vola.0.0320.015-0.1660.026-0.1040.016
1ASDLow Vola.0.0420.015-0.0920.030-0.2810.041
2TDHigh Vola.0.0290.011-0.2240.023-0.1230.012
3TDLow Vola.0.0380.013-0.0860.021-0.3940.049
\n", "
" ], "text/plain": [ " group Volatility Intercept curDur dDur \n", " mean sem mean sem mean sem\n", "0 ASD High Vola. 0.032 0.015 -0.166 0.026 -0.104 0.016\n", "1 ASD Low Vola. 0.042 0.015 -0.092 0.030 -0.281 0.041\n", "2 TD High Vola. 0.029 0.011 -0.224 0.023 -0.123 0.012\n", "3 TD Low Vola. 0.038 0.013 -0.086 0.021 -0.394 0.049" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# some unknown error with sub column in groupby and agg, so drop it first\n", "# show the mean and standard errors of Intercept, Duration, preErr for each group, Volatility\n", "glm_v.drop('sub', axis =1).groupby(['group', 'Volatility']).agg(['mean', 'sem'])[['Intercept', 'curDur', 'dDur']].reset_index()" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SourceSSDF1DF2MSFp-uncnp2eps
0group0.0001560.0000.0390.8450.001NaN
1Volatility0.0031560.0033.3880.0710.0571.000
2Interaction0.0001560.0000.0100.9220.000NaN
\n", "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.000 1 56 0.000 0.039 0.845 0.001 NaN\n", "1 Volatility 0.003 1 56 0.003 3.388 0.071 0.057 1.000\n", "2 Interaction 0.000 1 56 0.000 0.010 0.922 0.000 NaN" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pg.mixed_anova(data=glm_v, dv='Intercept', within='Volatility', between='group', subject='sub')\n" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SourceSSDF1DF2MSFp-uncnp2eps
0group0.0201560.0200.9060.3450.016NaN
1Volatility0.3261560.32624.5700.0000.3051.000
2Interaction0.0291560.0292.2180.1420.038NaN
\n", "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.020 1 56 0.020 0.906 0.345 0.016 NaN\n", "1 Volatility 0.326 1 56 0.326 24.570 0.000 0.305 1.000\n", "2 Interaction 0.029 1 56 0.029 2.218 0.142 0.038 NaN" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# pingouin ANOVA test for CTI\n", "pg.mixed_anova(data=glm_v, dv='curDur', within='Volatility', between='group', subject='sub')" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SourceSSDF1DF2MSFp-uncnp2eps
0group0.1261560.1263.3350.0730.056NaN
1Volatility1.4601561.46053.4880.0000.4891.000
2Interaction0.0631560.0632.3230.1330.040NaN
\n", "
" ], "text/plain": [ " Source SS DF1 DF2 MS F p-unc np2 eps\n", "0 group 0.126 1 56 0.126 3.335 0.073 0.056 NaN\n", "1 Volatility 1.460 1 56 1.460 53.488 0.000 0.489 1.000\n", "2 Interaction 0.063 1 56 0.063 2.323 0.133 0.040 NaN" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "# pingouin ANOVA test for preDur\n", "pg.mixed_anova(data=glm_v, dv='dDur', within='Volatility', between='group', subject='sub')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2.10 Correlation with EQ, AQ\n" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Tdofalternativep-valCI95%cohen-dBF10power
T-test-8.878120.558two-sided0.000[-39.01, -24.78]1.5777.799e+111.000
\n", "
" ], "text/plain": [ " T dof alternative p-val CI95% cohen-d \\\n", "T-test -8.878 120.558 two-sided 0.000 [-39.01, -24.78] 1.577 \n", "\n", " BF10 power \n", "T-test 7.799e+11 1.000 " ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import statsmodels.api as sm\n", "import statsmodels.formula.api as smf\n", "# load data/parinfo.csv as a dataframe\n", "pinfo = pd.read_csv('data/parinfo.csv')\n", "# add difference between EQ and SQ as a new column ES\n", "pinfo['ES'] = pinfo['EQ'] - pinfo['SQ'] # positive 'female brain', negative 'male brain'\n", "# merge glm_v and aq on sub\n", "res = pd.merge(kpars, pinfo, on=['group','sequence'])\n", "# t-test for DQ between ASD and TD\n", "pg.ttest(res.query(\"group == 'ASD'\")['ES'], res.query(\"group == 'TD'\")['ES'])" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K1
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0184
Min. group size: 2 Log-Likelihood: 23.2472
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.801 0.129 6.218 0.000 0.548 1.053
group[T.TD] -0.097 0.083 -1.164 0.245 -0.261 0.066
Volatility[T.Low Vola.] -0.078 0.024 -3.224 0.001 -0.125 -0.031
AQ -0.003 0.003 -0.804 0.421 -0.009 0.004
Group Var 0.023 0.060

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K2
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0611
Min. group size: 2 Log-Likelihood: -17.5335
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.545 0.138 3.952 0.000 0.275 0.815
group[T.TD] -0.148 0.089 -1.666 0.096 -0.321 0.026
Volatility[T.Low Vola.] -0.158 0.044 -3.587 0.000 -0.244 -0.072
AQ -0.002 0.004 -0.538 0.590 -0.009 0.005
Group Var 0.005 0.037

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: tau
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0020
Min. group size: 2 Log-Likelihood: 143.1968
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.111 0.054 2.038 0.042 0.004 0.218
group[T.TD] -0.036 0.035 -1.024 0.306 -0.106 0.033
Volatility[T.Low Vola.] 0.027 0.008 3.366 0.001 0.011 0.043
AQ -0.003 0.001 -1.849 0.064 -0.005 0.000
Group Var 0.005 0.033

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: cti
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0202
Min. group size: 2 Log-Likelihood: 32.4618
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.286 0.106 2.706 0.007 0.079 0.493
group[T.TD] 0.018 0.068 0.262 0.793 -0.116 0.152
Volatility[T.Low Vola.] -0.225 0.025 -8.897 0.000 -0.275 -0.176
AQ 0.001 0.003 0.467 0.641 -0.004 0.007
Group Var 0.011 0.038

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import display, HTML\n", "# fit the models with AQ\n", "display(HTML(smf.mixedlm(\"K1 ~ AQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"K2 ~ AQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"tau ~ AQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"cti ~ AQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "AQ was only marginally associated with tau (p = 0.064). Given the multiple comparisons involved, we did not consider this association further." ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K1
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0184
Min. group size: 2 Log-Likelihood: 23.2224
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.636 0.056 11.287 0.000 0.526 0.747
group[T.TD] -0.095 0.059 -1.617 0.106 -0.211 0.020
Volatility[T.Low Vola.] -0.078 0.024 -3.224 0.001 -0.125 -0.031
EQ 0.002 0.002 1.415 0.157 -0.001 0.006
Group Var 0.022 0.059

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K2
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0611
Min. group size: 2 Log-Likelihood: -18.0032
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.432 0.063 6.874 0.000 0.309 0.556
group[T.TD] -0.143 0.063 -2.266 0.023 -0.266 -0.019
Volatility[T.Low Vola.] -0.158 0.044 -3.587 0.000 -0.244 -0.072
EQ 0.002 0.002 0.854 0.393 -0.002 0.005
Group Var 0.005 0.037

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: tau
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0020
Min. group size: 2 Log-Likelihood: 141.0893
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept -0.000 0.024 -0.010 0.992 -0.048 0.048
group[T.TD] 0.007 0.026 0.272 0.785 -0.044 0.058
Volatility[T.Low Vola.] 0.027 0.008 3.366 0.001 0.011 0.043
EQ 0.001 0.001 0.698 0.485 -0.001 0.002
Group Var 0.005 0.034

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: cti
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0202
Min. group size: 2 Log-Likelihood: 31.7572
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.349 0.047 7.387 0.000 0.257 0.442
group[T.TD] 0.005 0.049 0.098 0.922 -0.091 0.100
Volatility[T.Low Vola.] -0.225 0.025 -8.897 0.000 -0.275 -0.176
EQ -0.001 0.001 -0.429 0.668 -0.003 0.002
Group Var 0.011 0.038

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(HTML(smf.mixedlm(\"K1 ~ EQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"K2 ~ EQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"tau ~ EQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"cti ~ EQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K1
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0184
Min. group size: 2 Log-Likelihood: 22.2862
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.701 0.072 9.754 0.000 0.560 0.841
group[T.TD] -0.041 0.048 -0.846 0.398 -0.135 0.053
Volatility[T.Low Vola.] -0.078 0.024 -3.224 0.001 -0.125 -0.031
SQ 0.000 0.002 0.001 1.000 -0.004 0.004
Group Var 0.023 0.061

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K2
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0611
Min. group size: 2 Log-Likelihood: -14.2098
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.660 0.073 8.988 0.000 0.516 0.803
group[T.TD] -0.155 0.047 -3.262 0.001 -0.248 -0.062
Volatility[T.Low Vola.] -0.158 0.044 -3.587 0.000 -0.244 -0.072
SQ -0.005 0.002 -2.967 0.003 -0.009 -0.002
Group Var 0.001 0.033

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: tau
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0020
Min. group size: 2 Log-Likelihood: 141.3026
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.038 0.031 1.248 0.212 -0.022 0.099
group[T.TD] 0.013 0.021 0.610 0.542 -0.028 0.053
Volatility[T.Low Vola.] 0.027 0.008 3.366 0.001 0.011 0.043
SQ -0.001 0.001 -0.906 0.365 -0.002 0.001
Group Var 0.005 0.034

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: cti
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0202
Min. group size: 2 Log-Likelihood: 31.7151
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.336 0.059 5.686 0.000 0.220 0.452
group[T.TD] -0.010 0.039 -0.246 0.805 -0.086 0.067
Volatility[T.Low Vola.] -0.225 0.025 -8.897 0.000 -0.275 -0.176
SQ -0.000 0.001 -0.053 0.958 -0.003 0.003
Group Var 0.011 0.038

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(HTML(smf.mixedlm(\"K1 ~ SQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"K2 ~ SQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"tau ~ SQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"cti ~ SQ + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SQ was a significant predictor of K2 in the mixed model (coef. = -0.005, z = -2.967, p = 0.003)." ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# scatterplot for K2 as a function of SQ, separated by group (different colors), Volatility (different symbols)\n", "#sns.scatterplot(data=res, x='SQ', y='K2', hue='group', style='Volatility', alpha=0.5)\n", "palette = sns.color_palette('Dark2')\n", "plt.figure(figsize=(6,4))\n", "# subplot 1\n", "#plt.subplot(1,2,1)\n", "plt.xlabel('SQ')\n", "plt.ylabel('K2')\n", "# now add regression lines for each group, Volatility\n", "sns.regplot(data=res.query(\"group == 'ASD' and Volatility == 'Low Vola.'\"), \n", " x='SQ', y='K2', color = palette[0], line_kws={'ls':'--'}, \n", " scatter_kws ={'facecolors':'none','alpha':0.3}, marker='o', label = \"ASD/Low Vola.\") \n", "sns.regplot(data=res.query(\"group == 'ASD' and Volatility == 'High Vola.'\"),\n", " x='SQ', y='K2', color = palette[1], line_kws={'ls':'--'}, \n", " scatter_kws ={'alpha':0.3}, marker='o', label = \"ASD/High Vola.\")\n", "sns.regplot(data=res.query(\"group == 'TD' and Volatility == 'Low Vola.'\"), \n", " x='SQ', y='K2', color = palette[2], line_kws={'ls':'-'}, \n", " scatter_kws ={'facecolors':'none','alpha':0.3}, marker = 's', label = \"TD/Low Vola.\")\n", "sns.regplot(data=res.query(\"group == 'TD' and Volatility == 'High Vola.'\"), \n", " x='SQ', y='K2', color = palette[3], line_kws={'ls':'-'}, \n", " scatter_kws ={'alpha':0.3}, marker='s', label=\"TD/High Vola.\")\n", "# add legend\n", "plt.legend()\n", "# plot a label 'a' on the top left corner\n", "#plt.text(-0.1, 1.1, 'a', transform=plt.gca().transAxes, \n", "# fontsize=16, fontweight='bold', va='top', ha='right')\n", "# remove box around the plot\n", "sns.despine()\n", "# subplot 1\n", "#plt.subplot(1,2,2)\n", "#plt.xlabel('E-S')\n", "#plt.ylabel('K2')\n", "# now add regression lines for each group, Volatility\n", "#sns.regplot(data=res.query(\"group == 'ASD' and Volatility == 'Low Vola.'\"), \n", "# x='ES', y='K2', color = palette[0], line_kws={'ls':'--'}, \n", "# scatter_kws 
={'facecolors':'none','alpha':0.3}, marker='o', label = \"ASD/Low Vola.\") \n", "#sns.regplot(data=res.query(\"group == 'ASD' and Volatility == 'High Vola.'\"),\n", "# x='ES', y='K2', color = palette[1], line_kws={'ls':'--'}, \n", "# scatter_kws ={'alpha':0.3}, marker='o', label = \"ASD/High Vola.\")\n", "#sns.regplot(data=res.query(\"group == 'TD' and Volatility == 'Low Vola.'\"), \n", "# x='ES', y='K2', color = palette[2], line_kws={'ls':'-'}, \n", "# scatter_kws ={'facecolors':'none','alpha':0.3}, marker = 's', label = \"TD/Low Vola.\")\n", "#sns.regplot(data=res.query(\"group == 'TD' and Volatility == 'High Vola.'\"), \n", "# x='ES', y='K2', color = palette[3], line_kws={'ls':'-'}, \n", "# scatter_kws ={'alpha':0.3}, marker='s', label=\"TD/High Vola.\")\n", "# add legend\n", "#plt.legend()\n", "# plot a label 'b' on the top left corner\n", "#plt.text(-0.1, 1.1, 'b', transform=plt.gca().transAxes, \n", "# fontsize=16, fontweight='bold', va='top', ha='right')\n", "# remove box around the plot\n", "sns.despine()\n", "# tight layout\n", "plt.tight_layout()\n", "# save the figure as a 300-dpi PNG to ./figures/K2_vs_SQ.png\n", "plt.savefig('./figures/K2_vs_SQ.png', dpi=300)\n" ] }, { "cell_type": "code", "execution_count": 187, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K1
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0184
Min. group size: 2 Log-Likelihood: 23.3700
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.735 0.045 16.190 0.000 0.646 0.823
group[T.TD] -0.057 0.047 -1.214 0.225 -0.150 0.035
Volatility[T.Low Vola.] -0.078 0.024 -3.224 0.001 -0.125 -0.031
BDI -0.003 0.003 -1.112 0.266 -0.009 0.002
Group Var 0.022 0.059

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K2
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0611
Min. group size: 2 Log-Likelihood: -16.1921
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.532 0.051 10.531 0.000 0.433 0.631
group[T.TD] -0.136 0.049 -2.776 0.006 -0.232 -0.040
Volatility[T.Low Vola.] -0.158 0.044 -3.587 0.000 -0.244 -0.072
BDI -0.005 0.003 -1.844 0.065 -0.011 0.000
Group Var 0.004 0.035

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: tau
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0020
Min. group size: 2 Log-Likelihood: 141.6521
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.004 0.019 0.186 0.852 -0.035 0.042
group[T.TD] 0.024 0.021 1.158 0.247 -0.016 0.064
Volatility[T.Low Vola.] 0.027 0.008 3.366 0.001 0.011 0.043
BDI 0.001 0.001 0.756 0.450 -0.001 0.003
Group Var 0.005 0.034

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: cti
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0202
Min. group size: 2 Log-Likelihood: 32.6026
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.310 0.038 8.183 0.000 0.236 0.385
group[T.TD] 0.002 0.039 0.060 0.952 -0.073 0.078
Volatility[T.Low Vola.] -0.225 0.025 -8.897 0.000 -0.275 -0.176
BDI 0.002 0.002 0.916 0.360 -0.002 0.007
Group Var 0.011 0.038

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(HTML(smf.mixedlm(\"K1 ~ BDI + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"K2 ~ BDI + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"tau ~ BDI + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"cti ~ BDI + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n" ] }, { "cell_type": "code", "execution_count": 188, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K1
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0184
Min. group size: 2 Log-Likelihood: 22.2399
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.709 0.035 20.292 0.000 0.641 0.778
group[T.TD] -0.073 0.057 -1.281 0.200 -0.186 0.039
Volatility[T.Low Vola.] -0.078 0.024 -3.224 0.001 -0.125 -0.031
ES 0.001 0.001 0.921 0.357 -0.001 0.003
Group Var 0.022 0.060

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: K2
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0611
Min. group size: 2 Log-Likelihood: -16.0307
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.496 0.040 12.412 0.000 0.418 0.575
group[T.TD] -0.195 0.058 -3.336 0.001 -0.309 -0.080
Volatility[T.Low Vola.] -0.158 0.044 -3.587 0.000 -0.244 -0.072
ES 0.003 0.001 2.402 0.016 0.001 0.005
Group Var 0.002 0.034

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/strongway/miniforge3/envs/py312/lib/python3.12/site-packages/statsmodels/regression/mixed_linear_model.py:2238: ConvergenceWarning: The MLE may be on the boundary of the parameter space.\n", " warnings.warn(msg, ConvergenceWarning)\n" ] }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: tau
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0020
Min. group size: 2 Log-Likelihood: 140.9529
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.018 0.015 1.207 0.228 -0.011 0.047
group[T.TD] 0.003 0.025 0.120 0.904 -0.046 0.052
Volatility[T.Low Vola.] 0.027 0.008 3.366 0.001 0.011 0.043
ES 0.000 0.000 1.029 0.303 -0.000 0.001
Group Var 0.005 0.034

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Model: MixedLM Dependent Variable: cti
No. Observations: 126 Method: REML
No. Groups: 63 Scale: 0.0202
Min. group size: 2 Log-Likelihood: 31.2769
Max. group size: 2 Converged: Yes
Mean group size: 2.0
\n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "\n", " \n", "\n", "
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.331 0.030 11.144 0.000 0.273 0.390
group[T.TD] -0.002 0.047 -0.035 0.972 -0.094 0.090
Volatility[T.Low Vola.] -0.225 0.025 -8.897 0.000 -0.275 -0.176
ES -0.000 0.001 -0.249 0.804 -0.002 0.002
Group Var 0.011 0.038

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(HTML(smf.mixedlm(\"K1 ~ ES + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"K2 ~ ES + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"tau ~ ES + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n", "display(HTML(smf.mixedlm(\"cti ~ ES + group + Volatility\", res, groups=res[\"sub\"]).fit().summary().as_html()))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "ES was a significant predictor of K2 in the mixed model (coef. = 0.003, z = 2.402, p = 0.016). Let's visualize it." ] }, { "cell_type": "code", "execution_count": 395, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sub sequence SQ
2 A32 32 44
3 A32 32 44
12 C32 32 10
13 C32 32 10
36 aril02 2 8
37 aril02 2 8
50 arm13 13 32
51 arm13 13 32
114 crs02 2 21
115 crs02 2 21
124 crw13 13 51
125 crw13 13 51
\n", "
" ], "text/plain": [ " sub sequence SQ\n", "2 A32 32 44\n", "3 A32 32 44\n", "12 C32 32 10\n", "13 C32 32 10\n", "36 aril02 2 8\n", "37 aril02 2 8\n", "50 arm13 13 32\n", "51 arm13 13 32\n", "114 crs02 2 21\n", "115 crs02 2 21\n", "124 crw13 13 51\n", "125 crw13 13 51" ] }, "execution_count": 395, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# show those outliers SQ scores\n", "res.query(\"sequence in @outliers_regress\")[['sub', 'sequence', 'SQ']]" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" } }, "nbformat": 4, "nbformat_minor": 2 }