# 04_rsa.R

  1. library(readr)
  2. library(dplyr)
  3. library(purrr)
  4. library(tidyr)
  5. library(brms)
  6. library(ggplot2)
  7. library(ggdist)
  8. # library(fda)
  9. library(patchwork)
  10. library(parallel)
  # Number of CPU cores available for sampling. detectCores() returns NA when
  # the count cannot be determined, so fall back to a single core instead of
  # passing NA on to brm(cores = ) and future::plan(workers = ) below.
  n_cores <- max(
    1L,
    parallel::detectCores(all.tests = FALSE, logical = TRUE),
    na.rm = TRUE
  )
  # n_cores <- 14
  # options(loo.cores = n_cores)
  # Fix the global RNG state for reproducibility.
  set.seed(1)
  # Stimulus-level similarity table: the preregistered Jaccard (jacc) and
  # optimal-transport (ot) measures joined on the character pair, each
  # converted to ranks, plus an identifier that is the same for (A, B) and
  # (B, A) because the two characters are sorted before pasting.
  stim_sim <- read_csv(file.path("stim_sim", "preregistered", "jacc.csv")) |>
    left_join(
      read_csv(file.path("stim_sim", "preregistered", "ot.csv")),
      by = c("char1", "char2")
    ) |>
    mutate(
      rank_jacc = rank(jacc),
      rank_ot = rank(ot),
      pair_id = map2_chr(
        char1, char2,
        \(first, second) paste(sort(c(first, second)), collapse = "_")
      )
    )
  # Rank correlation between the two stimulus models.
  with(stim_sim, cor(rank_ot, rank_jacc))
  # Read every per-subject RDM CSV under rdm_data/<subdir> (subj_id kept as
  # character), row-bind them, and attach the stimulus similarity columns.
  # The file pattern is anchored ("\\.csv$"): the previous ".*\\.csv" also
  # matched names such as "x.csv~" or "x.csv.bak", which would have been
  # silently read in as duplicate data. An empty directory now fails with a
  # clear message instead of a confusing join error on missing columns.
  read_rdm_dir <- function(subdir) {
    rdm_dir <- file.path("rdm_data", subdir)
    csv_files <- list.files(rdm_dir, pattern = "\\.csv$", full.names = TRUE)
    if (length(csv_files) == 0L) {
      stop("No RDM .csv files found in ", rdm_dir, call. = FALSE)
    }
    csv_files |>
      map_df(read_csv, col_types = c("subj_id" = "c")) |>
      left_join(stim_sim, by = c("char1", "char2"))
  }

  # Time-resolved RDMs: EEG dissimilarity ranked within subject x time point.
  rdm <- read_rdm_dir("time_resolved") |>
    group_by(subj_id, time) |>
    mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
    ungroup() |>
    arrange(time)
  times <- sort(unique(rdm$time))

  # Period-of-interest RDMs: EEG dissimilarity ranked within subject.
  rdm_poi <- read_rdm_dir("period_of_interest") |>
    group_by(subj_id) |>
    mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
    ungroup()
  44. # fit full model -------------------------------------------
  45. # this is a multivariate model that estimates correlations between all items
  46. # partial corr model
  47. # quick approximation
  48. # ppcor::pcor(dplyr::select(rdm_poi, starts_with("rank")))
  49. # check that cor2pcor works
  50. # cor_mat <- cor(dplyr::select(rdm_poi, rank_jacc, rank_ot, rank_eeg_dissim))
  51. # corpcor::cor2pcor(cor_mat)
  # Population standard deviation (divisor n rather than n - 1): the sample
  # SD from sd() rescaled by sqrt((n - 1) / n).
  pop_sd <- function(x) {
    n <- length(x)
    sqrt((n - 1) / n) * sd(x)
  }
  # LKJ concentration (eta) for the residual correlation prior; inspect the
  # implied prior with the commented-out simulation below.
  lkj_prior <- 1.5
  # lkj_prior <- 10000
  # r <- rethinking::rlkjcorr(n=1e5, K=3, eta=lkj_prior)
  # plot(density(r[, 1, 2]))
  # plot(density(r[, 1, 3]))
  # plot(density(r[, 1, 2] - r[, 2, 3]))
  m_rho_prior_full <- local({
    # Fix one response's sigma at the population SD of the ranks 1..max(rank),
    # with that same maximum rank as the upper bound.
    fixed_sigma_prior <- function(ranks, resp) {
      top_rank <- max(ranks)
      set_prior(
        sprintf("constant(%s)", pop_sd(1:top_rank)),
        class = "sigma",
        resp = resp,
        ub = top_rank
      )
    }
    c(
      # eta > 1 pulls off-diagonal correlations towards 0, since lower
      # correlations are more likely in RSA of EEG data
      set_prior(sprintf("lkj(%s)", lkj_prior), class = "rescor"),
      fixed_sigma_prior(rdm_poi$rank_eeg_dissim, "rankeegdissim"),
      fixed_sigma_prior(rdm_poi$rank_jacc, "rankjacc"),
      fixed_sigma_prior(rdm_poi$rank_ot, "rankot")
    )
  })
  # Period-of-interest model: multivariate Gaussian over the three rank
  # variables with no predictors, so the residual correlation matrix carries
  # the EEG-vs-stimulus-model correlations of interest.
  m_rho_full_poi <- brm(
    formula = bf(mvbind(rank_eeg_dissim, rank_ot, rank_jacc) ~ 0) +
      set_rescor(rescor = TRUE),
    # no intercept is modelled, so every rank column is centred first
    data = mutate(rdm_poi, across(starts_with("rank"), function(x) x - mean(x))),
    family = brmsfamily("gaussian"),
    prior = m_rho_prior_full,
    sample_prior = FALSE,
    save_pars = save_pars(all = TRUE),
    chains = 8,
    iter = 24000,
    warmup = 6000,
    cores = n_cores,
    seed = 1,
    control = list(adapt_delta = 0.99)
  )
  # full_samps <- as_draws_df(m_rho_full_poi, variable="^rescor_.*", regex=TRUE)
  # Response variable names; these match the row/column order of the residual
  # correlation matrix.
  var_names <- colnames(get_y(m_rho_full_poi))
  # Per-draw correlation and partial correlation matrices.
  #
  # Lrescor is the Cholesky factor L of the residual correlation matrix (brms
  # declares it as cholesky_factor_corr and forms Rescor = L %*% t(L)); it is
  # not itself a covariance or correlation matrix. Symmetrising L and calling
  # cov2cor() yields wrong values for every correlation outside the first
  # column (e.g. cor(eeg, jacc) would become L[3,1] / sqrt(L[3,3])), so the
  # correlation matrix is rebuilt as tcrossprod(L).
  cor_samps <- as_draws_df(m_rho_full_poi, "^Lrescor", regex = TRUE) |>
    # one row per draw x matrix element, with integer row/column indices
    pivot_longer(cols = starts_with("Lrescor"), names_to = "idx", values_to = "r") |>
    mutate(idx = sub("Lrescor", "", idx, fixed = TRUE)) |>
    mutate(idx = gsub("(\\[)|(\\])", "", idx)) |>
    separate(idx, c("idx1", "idx2"), sep = ",") |>
    mutate(idx1 = as.integer(idx1), idx2 = as.integer(idx2)) |>
    group_by(.chain, .iteration, .draw) |>
    nest() |>
    # map() below already works draw by draw; keeping one group per draw would
    # only add per-group overhead to mutate()
    ungroup() |>
    mutate(
      cor_mat = map(data, function(x) {
        chol_mat <- matrix(0, nrow = max(x$idx1), ncol = max(x$idx2))
        chol_mat[cbind(x$idx1, x$idx2)] <- x$r
        # cov2cor() here only removes floating-point drift from the unit diagonal
        cor_mat <- cov2cor(tcrossprod(chol_mat))
        dimnames(cor_mat) <- list(var_names, var_names)
        cor_mat
      }),
      pcor_mat = map(cor_mat, function(x) {
        pcor_mat <- corpcor::cor2pcor(x)
        dimnames(pcor_mat) <- list(var_names, var_names)
        pcor_mat
      }),
      cor_rankeegdissim_rankjacc = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankjacc"]),
      cor_rankeegdissim_rankot = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankot"]),
      pcor_rankeegdissim_rankjacc = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankjacc"]),
      pcor_rankeegdissim_rankot = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankot"])
    )
  # Evidence ratios (ratio of posterior draws for vs against) per hypothesis:
  #   bf_1  - EEG-vs-OT correlation is positive
  #   bf_2a - EEG-vs-OT correlation exceeds EEG-vs-Jaccard correlation
  #   bf_2b - the same comparison for the partial correlations
  bf_1 <- with(
    cor_samps,
    sum(cor_rankeegdissim_rankot > 0) / sum(cor_rankeegdissim_rankot < 0)
  )
  bf_2a <- with(cor_samps, {
    ot_higher <- cor_rankeegdissim_rankot > cor_rankeegdissim_rankjacc
    sum(ot_higher) / sum(!ot_higher)
  })
  bf_2b <- with(cor_samps, {
    ot_higher <- pcor_rankeegdissim_rankot > pcor_rankeegdissim_rankjacc
    sum(ot_higher) / sum(!ot_higher)
  })
  # Long format for saving: one row per draw and EEG-vs-model coefficient,
  # covering both the correlations and the partial correlations.
  cor_samps_long <- cor_samps |>
    pivot_longer(
      c(starts_with("pcor_rankeegdissim"), starts_with("cor_rankeegdissim")),
      names_to = "cor_lab",
      values_to = "Rho"
    ) |>
    mutate(
      model = if_else(
        grepl("jacc", cor_lab, fixed = TRUE),
        "Jaccard Distance",
        "Wasserstein Distance"
      ),
      is_partial = startsWith(cor_lab, "pcor")
    ) |>
    dplyr::select(-c(data, cor_mat, pcor_mat))
  # Posterior median and 89% HDI per stimulus model and correlation type.
  median_hdi(group_by(cor_samps_long, model, is_partial), Rho, .width = .89)
  saveRDS(cor_samps_long, file.path("estimates", "prereg_cor_samps_long.rds"))
  # Drop period-of-interest objects to reduce RAM before the time-resolved fits.
  rm(cor_samps, cor_samps_long, m_rho_full_poi, rdm_poi, stim_sim)
  gc()
  # fit time-resolved models -----------------------------------------------
  # One model per time point, each fitted to all subjects' data (ranks are
  # computed within subject). Note: this gives roughly equivalent peak
  # posteriors to averaging over per-subject Rho estimates.
  # A multicore future plan is set so the per-time-point fits can run in
  # parallel; each fit runs its own chains on a single core (cores = 1).
  library(future)
  plan(multicore, workers = n_cores)
  message("Fitting time-resolved models in parallel")
  m_rho <- brm_multiple(
    formula = bf(mvbind(rank_eeg_dissim, rank_ot, rank_jacc) ~ 0) +
      set_rescor(rescor = TRUE),
    # centre every rank column within its time point (no intercept is
    # modelled), then split into a list with one data set per time point
    data = rdm |>
      group_by(time) |>
      mutate(across(starts_with("rank"), function(x) x - mean(x))) |>
      ungroup() |>
      group_split(time),
    family = brmsfamily("gaussian"),
    prior = m_rho_prior_full,
    combine = FALSE,
    save_pars = save_pars(all = TRUE),
    chains = 8,
    iter = 10000,
    warmup = 5000,
    cores = 1,
    seed = 1,
    control = list(adapt_delta = 0.99)
  )
  # Posterior EEG-vs-model correlations for every time point.
  #
  # For each fitted model, every draw's residual correlation matrix is rebuilt
  # from its Cholesky factor Lrescor as tcrossprod(L) = L %*% t(L). Lrescor is
  # not itself a covariance or correlation matrix; symmetrising it and calling
  # cov2cor() gives wrong values outside its first column. Partial
  # correlations then come from corpcor::cor2pcor().
  # m_rho[[i]] is assumed to belong to times[[i]] (both in ascending `time`
  # order); at minimum their lengths must agree, otherwise labels are wrong.
  stopifnot(length(m_rho) == length(times))
  tc <- map_df(seq_along(m_rho), function(i) {
    t_i <- times[[i]]
    m_i <- m_rho[[i]]
    var_names_i <- colnames(get_y(m_i))
    as_draws_df(m_i, "^Lrescor", regex = TRUE) |>
      # one row per draw x matrix element, with integer row/column indices
      pivot_longer(cols = starts_with("Lrescor"), names_to = "idx", values_to = "r") |>
      mutate(idx = sub("Lrescor", "", idx, fixed = TRUE)) |>
      mutate(idx = gsub("(\\[)|(\\])", "", idx)) |>
      separate(idx, c("idx1", "idx2"), sep = ",") |>
      mutate(idx1 = as.integer(idx1), idx2 = as.integer(idx2)) |>
      group_by(.chain, .iteration, .draw) |>
      nest() |>
      # map() below already works draw by draw; per-draw groups only add overhead
      ungroup() |>
      mutate(
        cor_mat = map(data, function(x) {
          chol_mat <- matrix(0, nrow = max(x$idx1), ncol = max(x$idx2))
          chol_mat[cbind(x$idx1, x$idx2)] <- x$r
          # cov2cor() here only removes floating-point drift from the diagonal
          cor_mat <- cov2cor(tcrossprod(chol_mat))
          dimnames(cor_mat) <- list(var_names_i, var_names_i)
          cor_mat
        }),
        pcor_mat = map(cor_mat, function(x) {
          pcor_mat <- corpcor::cor2pcor(x)
          dimnames(pcor_mat) <- list(var_names_i, var_names_i)
          pcor_mat
        }),
        cor_rankeegdissim_rankjacc = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankjacc"]),
        cor_rankeegdissim_rankot = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankot"]),
        pcor_rankeegdissim_rankjacc = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankjacc"]),
        pcor_rankeegdissim_rankot = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankot"]),
        time = t_i
      ) |>
      dplyr::select(-data, -cor_mat, -pcor_mat)
  })
  saveRDS(tc, file.path("estimates", "planned_timecourse.rds"))