08_controls_rsa.R 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. library(readr)
  2. library(dplyr)
  3. library(purrr)
  4. library(tidyr)
  5. library(brms)
  6. library(ggplot2)
  7. library(ggdist)
  8. # library(fda)
  9. library(patchwork)
  10. library(parallel)
  # Number of logical CPU cores. Used as `cores =` for the full model and as
  # the number of future workers for the time-resolved models.
  # parallel::detectCores() returns NA when it cannot determine the count,
  # which would break those arguments, so fall back to a single core.
  n_cores <- parallel::detectCores(all.tests = FALSE, logical = TRUE)
  if (is.na(n_cores)) {
    n_cores <- 1L
  }
  # n_cores <- 14
  # options(loo.cores = n_cores)
  set.seed(1)
  # Pairwise stimulus-similarity predictors: one row per (char1, char2) pair.
  # Each CSV is read in turn and successively left-joined onto the first
  # (Jaccard) table by the character pair.
  stim_sim <- file.path(
    "stim_sim",
    c(
      file.path("preregistered", "jacc.csv"),
      file.path("preregistered", "ot.csv"),
      file.path("complexity", "complexity.csv"),
      file.path("frequency", "frequency.csv"),
      file.path("phonology", "dominant_phonemes.csv"),
      file.path("phonology", "letter_names.csv")
    )
  ) |>
    map(read_csv) |>
    reduce(left_join, by = c("char1", "char2"))

  # rank() puts NAs last by default (na.last = TRUE). A pair missing from any
  # joined table would therefore silently get the largest ranks, i.e. look
  # maximally dissimilar. Fail loudly instead of ranking incomplete data.
  local({
    dist_cols <- c(
      "jacc", "ot", "comp_dist", "freq_dist", "phon_dist", "name_phon_dist"
    )
    has_na <- vapply(stim_sim[dist_cols], anyNA, logical(1))
    if (any(has_na)) {
      stop(
        "Missing values in stim_sim after joining (column(s): ",
        paste(dist_cols[has_na], collapse = ", "),
        "); check that every predictor file covers every char1/char2 pair.",
        call. = FALSE
      )
    }
  })

  stim_sim <- stim_sim |>
    mutate(
      rank_jacc = rank(jacc),
      rank_ot = rank(ot),
      rank_comp_dist = rank(comp_dist),
      rank_freq_dist = rank(freq_dist),
      rank_phon_dist = rank(phon_dist),
      rank_name_phon_dist = rank(name_phon_dist),
      # order-independent pair label: (a, b) and (b, a) both give "a_b"
      pair_id = map2_chr(
        char1, char2,
        function(a, b) paste(sort(c(a, b)), collapse = "_")
      )
    )
  # Time-resolved EEG RDMs: every CSV in rdm_data/time_resolved is stacked
  # (subj_id forced to character), joined to the stimulus predictors, and the
  # EEG dissimilarities are ranked separately within each subject x time point.
  rdm <- list.files(
    file.path("rdm_data", "time_resolved"),
    pattern = ".*\\.csv",
    full.names = TRUE
  ) |>
    map_df(read_csv, col_types = c("subj_id" = "c")) |>
    left_join(stim_sim, by = c("char1", "char2")) |>
    mutate(rank_eeg_dissim = rank(eeg_dissim), .by = c(subj_id, time)) |>
    arrange(time)
  # distinct time points, ascending
  times <- sort(unique(rdm$time))
  # Period-of-interest EEG RDMs: every CSV in rdm_data/period_of_interest is
  # stacked, joined to the stimulus predictors, and the EEG dissimilarities
  # are ranked separately within each subject.
  rdm_poi <- list.files(
    file.path("rdm_data", "period_of_interest"),
    pattern = ".*\\.csv",
    full.names = TRUE
  ) |>
    map_df(read_csv, col_types = c("subj_id" = "c")) |>
    left_join(stim_sim, by = c("char1", "char2")) |>
    mutate(rank_eeg_dissim = rank(eeg_dissim), .by = subj_id)
  # fit full model -------------------------------------------
  # A multivariate model that estimates the (residual) correlations between
  # all rank variables; partial correlations are derived from these.

  # Population standard deviation, i.e. dividing by n rather than n - 1.
  #   x:     numeric vector
  #   na.rm: drop NA values first (default FALSE, in which case any NA -> NA)
  # Returns 0 for a single non-missing value and NA for an empty vector. For
  # n >= 2 it uses the same formula as before, so existing values are
  # unchanged bit-for-bit.
  pop_sd <- function(x, na.rm = FALSE) {
    if (na.rm) {
      x <- x[!is.na(x)]
    }
    n <- length(x)
    if (n == 0L) {
      return(NA_real_)
    }
    if (n == 1L) {
      return(if (is.na(x)) NA_real_ else 0)
    }
    sqrt((n - 1) / n) * sd(x)
  }
  # Priors for the multivariate rank model.
  # LKJ shape for the residual correlations: values > 1 favour correlations
  # nearer 0, since lower correlations are more likely in RSA of EEG data.
  lkj_prior <- 1.5
  # Each response's residual SD is fixed (constant prior) to the population SD
  # of the ranks 1..max(rank) in rdm_poi, with max(rank) as the upper bound.
  # NOTE(review): with tied (averaged) ranks the empirical SD is a little
  # smaller than pop_sd(1:max), and a non-integer max is truncated by `:`;
  # confirm this approximation is intended.
  m_rho_prior_full <- local({
    # brms response name -> rank column in rdm_poi
    sigma_resp <- c(
      rankeegdissim = "rank_eeg_dissim",
      rankjacc = "rank_jacc",
      rankot = "rank_ot",
      rankcompdist = "rank_comp_dist",
      rankfreqdist = "rank_freq_dist",
      rankphondist = "rank_phon_dist",
      ranknamephondist = "rank_name_phon_dist"
    )
    sigma_priors <- Map(
      function(resp, col) {
        top_rank <- max(rdm_poi[[col]])
        set_prior(
          sprintf("constant(%s)", pop_sd(1:top_rank)),
          class = "sigma", resp = resp, ub = top_rank
        )
      },
      names(sigma_resp),
      sigma_resp
    )
    rescor_prior <- set_prior(sprintf("lkj(%s)", lkj_prior), class = "rescor")
    do.call(c, c(list(rescor_prior), unname(sigma_priors)))
  })
  # Period-of-interest model: a multivariate Gaussian over the seven rank
  # variables with no intercepts (`~ 0`) and estimated residual correlations.
  # Every rank column is mean-centred first because no intercept is modelled.
  m_rho_full_poi <- brm(
    formula = bf(
      mvbind(
        rank_eeg_dissim, rank_jacc, rank_ot, rank_comp_dist,
        rank_freq_dist, rank_phon_dist, rank_name_phon_dist
      ) ~ 0
    ) + set_rescor(rescor = TRUE),
    data = rdm_poi |>
      mutate(across(starts_with("rank"), \(x) x - mean(x))),
    family = brmsfamily("gaussian"),
    prior = m_rho_prior_full,
    sample_prior = FALSE,
    save_pars = save_pars(all = TRUE),
    chains = 8,
    iter = 24000,
    warmup = 6000,
    cores = n_cores,
    seed = 1,
    control = list(adapt_delta = 0.99)
  )
  # Response variable names in model order; assumed to match the row/column
  # order of the Lrescor parameter.
  var_names <- colnames(get_y(m_rho_full_poi))
  # Per-draw correlation and partial-correlation matrices, plus the EEG row of
  # each as separate columns (cor_* = zero-order, pcor_* = partial).
  #
  # In brms's Stan code, Lrescor is the cholesky_factor_corr parameter: the
  # lower-triangular Cholesky factor L of the residual correlation matrix. The
  # correlation matrix is therefore R = L %*% t(L). L itself is not a
  # correlation or covariance matrix, so mirroring its lower triangle and
  # calling cov2cor() on it mis-states the correlations.
  cor_samps <- as_draws_df(m_rho_full_poi, "^Lrescor", regex = TRUE) |>
    # long format: one row per draw x matrix entry, "Lrescor[i,j]" -> (i, j)
    pivot_longer(
      cols = starts_with("Lrescor"), names_to = "idx", values_to = "r"
    ) |>
    mutate(idx = sub("Lrescor", "", idx, fixed = TRUE)) |>
    mutate(idx = gsub("(\\[)|(\\])", "", idx)) |>
    separate(idx, c("idx1", "idx2"), sep = ",") |>
    mutate(idx1 = as.integer(idx1), idx2 = as.integer(idx2)) |>
    # one nested data frame of matrix entries per posterior draw
    group_by(.chain, .iteration, .draw) |>
    nest() |>
    mutate(
      cor_mat = map(data, function(x) {
        chol_l <- matrix(0, nrow = max(x$idx1), ncol = max(x$idx2))
        chol_l[cbind(x$idx1, x$idx2)] <- x$r
        # L is lower-triangular; anything above the diagonal is discarded
        chol_l[upper.tri(chol_l)] <- 0
        # R = L L'; cov2cor() only snaps the diagonal back to exactly 1
        r_mat <- cov2cor(tcrossprod(chol_l))
        dimnames(r_mat) <- list(var_names, var_names)
        r_mat
      }),
      pcor_mat = map(cor_mat, function(x) {
        p_mat <- corpcor::cor2pcor(x)
        dimnames(p_mat) <- list(var_names, var_names)
        p_mat
      }),
      cor_rankeegdissim_rankjacc = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankjacc"]),
      cor_rankeegdissim_rankot = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankot"]),
      cor_rankeegdissim_rankcompdist = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankcompdist"]),
      cor_rankeegdissim_rankfreqdist = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankfreqdist"]),
      cor_rankeegdissim_rankphondist = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankphondist"]),
      cor_rankeegdissim_ranknamephondist = map_dbl(cor_mat, function(x) x["rankeegdissim", "ranknamephondist"]),
      pcor_rankeegdissim_rankjacc = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankjacc"]),
      pcor_rankeegdissim_rankot = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankot"]),
      pcor_rankeegdissim_rankcompdist = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankcompdist"]),
      pcor_rankeegdissim_rankfreqdist = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankfreqdist"]),
      pcor_rankeegdissim_rankphondist = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankphondist"]),
      pcor_rankeegdissim_ranknamephondist = map_dbl(pcor_mat, function(x) x["rankeegdissim", "ranknamephondist"])
    )
  # Evidence ratios: (# posterior draws supporting the hypothesis) /
  # (# draws against it).
  #   bf_1  - EEG ~ Wasserstein correlation is positive
  #   bf_2a - EEG ~ Wasserstein correlation exceeds EEG ~ Jaccard (zero-order)
  #   bf_2b - as bf_2a, but for the partial correlations
  bf_1 <- with(
    cor_samps,
    sum(cor_rankeegdissim_rankot > 0) / sum(cor_rankeegdissim_rankot < 0)
  )
  bf_2a <- with(
    cor_samps,
    sum(cor_rankeegdissim_rankot > cor_rankeegdissim_rankjacc) /
      sum(cor_rankeegdissim_rankot <= cor_rankeegdissim_rankjacc)
  )
  bf_2b <- with(
    cor_samps,
    sum(pcor_rankeegdissim_rankot > pcor_rankeegdissim_rankjacc) /
      sum(pcor_rankeegdissim_rankot <= pcor_rankeegdissim_rankjacc)
  )
  # Long format of the EEG correlations: one row per draw x predictor x
  # correlation type. Each row is labelled with a readable predictor name
  # (`model`) and whether it is a partial (pcor_*) or zero-order (cor_*)
  # correlation. Both kinds are summarised and saved.
  cor_samps_long <- cor_samps |>
    pivot_longer(
      c(starts_with("pcor_rankeegdissim"), starts_with("cor_rankeegdissim")),
      names_to = "cor_lab",
      values_to = "Rho"
    ) |>
    mutate(
      model = unname(c(
        rankjacc = "Jaccard Distance",
        rankot = "Wasserstein Distance",
        rankcompdist = "Complexity Distance",
        rankfreqdist = "Frequency Distance",
        rankphondist = "Phonological Distance",
        ranknamephondist = "Letter Name Phonological Distance"
      )[sub("^p?cor_rankeegdissim_", "", cor_lab)]),
      is_partial = startsWith(cor_lab, "pcor")
    ) |>
    dplyr::select(!c(data, cor_mat, pcor_mat))
  # posterior median and 89% HDI per predictor and correlation type (printed)
  cor_samps_long |>
    group_by(model, is_partial) |>
    median_hdi(Rho, .width = .89)
  saveRDS(cor_samps_long, file.path("estimates", "controls_cor_samps_long.rds"))
  # fit time-resolved models -----------------------------------------------
  # One model per time point, each fitted to all subjects' data (EEG
  # dissimilarities ranked within subject). Note: this gives roughly
  # equivalent peak posteriors to averaging over per-subject Rho estimates.
  # Parallelism is configured through future (n_cores workers, one core per
  # model). NOTE(review): multicore futures are not supported on Windows or in
  # RStudio, where future runs sequentially instead.
  library(future)
  plan(multicore, workers = n_cores)
  message("Fitting time-resolved models in parallel")
  m_rho <- brm_multiple(
    formula = bf(
      mvbind(
        rank_eeg_dissim, rank_jacc, rank_ot, rank_comp_dist,
        rank_freq_dist, rank_phon_dist, rank_name_phon_dist
      ) ~ 0
    ) + set_rescor(rescor = TRUE),
    # mean-centre every rank column within each time point (no intercept is
    # modelled), then split into one data frame per time point
    data = rdm |>
      mutate(across(starts_with("rank"), \(x) x - mean(x)), .by = time) |>
      group_split(time),
    family = brmsfamily("gaussian"),
    prior = m_rho_prior_full,
    combine = FALSE,
    save_pars = save_pars(all = TRUE),
    chains = 4,
    iter = 10000,
    warmup = 5000,
    cores = 1,
    seed = 1,
    control = list(adapt_delta = 0.99)
  )
  # get posteriors for each time point --------------------------------------
  # m_rho[[i]] was fitted to the i-th element of group_split(time). Those
  # groups come out in ascending time order, so model i pairs with times[[i]];
  # check the lengths agree before relying on that.
  stopifnot(length(m_rho) == length(times))
  tc <- map_df(seq_along(m_rho), function(i) {
    t_i <- times[[i]]
    m_i <- m_rho[[i]]
    # response names in model order (assumed row/column order of Lrescor)
    var_names_i <- colnames(get_y(m_i))
    as_draws_df(m_i, "^Lrescor", regex = TRUE) |>
      # long format: one row per draw x matrix entry, "Lrescor[i,j]" -> (i, j)
      pivot_longer(
        cols = starts_with("Lrescor"), names_to = "idx", values_to = "r"
      ) |>
      mutate(idx = sub("Lrescor", "", idx, fixed = TRUE)) |>
      mutate(idx = gsub("(\\[)|(\\])", "", idx)) |>
      separate(idx, c("idx1", "idx2"), sep = ",") |>
      mutate(idx1 = as.integer(idx1), idx2 = as.integer(idx2)) |>
      # one nested data frame of matrix entries per posterior draw
      group_by(.chain, .iteration, .draw) |>
      nest() |>
      mutate(
        # Lrescor is the lower-triangular Cholesky factor L of the residual
        # correlation matrix (a Stan cholesky_factor_corr), so R = L %*% t(L).
        # L itself is not a correlation matrix and must not be mirrored and
        # passed to cov2cor() directly.
        cor_mat = map(data, function(x) {
          chol_l <- matrix(0, nrow = max(x$idx1), ncol = max(x$idx2))
          chol_l[cbind(x$idx1, x$idx2)] <- x$r
          chol_l[upper.tri(chol_l)] <- 0
          # cov2cor() only snaps the diagonal of L L' back to exactly 1
          r_mat <- cov2cor(tcrossprod(chol_l))
          dimnames(r_mat) <- list(var_names_i, var_names_i)
          r_mat
        }),
        pcor_mat = map(cor_mat, function(x) {
          p_mat <- corpcor::cor2pcor(x)
          dimnames(p_mat) <- list(var_names_i, var_names_i)
          p_mat
        }),
        cor_rankeegdissim_rankjacc = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankjacc"]),
        cor_rankeegdissim_rankot = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankot"]),
        cor_rankeegdissim_rankcompdist = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankcompdist"]),
        cor_rankeegdissim_rankfreqdist = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankfreqdist"]),
        cor_rankeegdissim_rankphondist = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankphondist"]),
        cor_rankeegdissim_ranknamephondist = map_dbl(cor_mat, function(x) x["rankeegdissim", "ranknamephondist"]),
        pcor_rankeegdissim_rankjacc = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankjacc"]),
        pcor_rankeegdissim_rankot = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankot"]),
        pcor_rankeegdissim_rankcompdist = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankcompdist"]),
        pcor_rankeegdissim_rankfreqdist = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankfreqdist"]),
        pcor_rankeegdissim_rankphondist = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankphondist"]),
        pcor_rankeegdissim_ranknamephondist = map_dbl(pcor_mat, function(x) x["rankeegdissim", "ranknamephondist"]),
        time = t_i
      ) |>
      dplyr::select(-data, -cor_mat, -pcor_mat)
  })
  saveRDS(tc, file.path("estimates", "controls_timecourse.rds"))