10_sensitivity_analysis.R

# This script runs a sensitivity analysis: we vary the LKJ eta parameter used
# as the prior on the residual correlations and extract the posteriors under
# each setting

library(readr)
library(dplyr)
library(purrr)
library(tidyr)
library(brms)
library(tidybayes)
library(parallel)

n_cores <- parallel::detectCores()  # total number of cores
n_cores_per_model <- 4  # works best if n_cores is divisible by this
n_parallel_models <- floor(n_cores / n_cores_per_model)  # number of models to be fit in parallel at one time

set.seed(1)
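
# worked example (assuming a hypothetical 32-core machine): detectCores()
# returns 32, so with 4 cores per model we fit floor(32 / 4) = 8 models
# concurrently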

stim_sim <- left_join(
  read_csv(file.path("stim_sim", "preregistered", "jacc.csv")),
  read_csv(file.path("stim_sim", "preregistered", "ot.csv")),
  by = c("char1", "char2")
) |>
  mutate(
    rank_jacc = rank(jacc),
    rank_ot = rank(ot)
  ) |>
  rowwise() |>
  mutate(pair_id = paste(sort(c(char1, char2)), collapse = "_")) |>
  ungroup()
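
# stim_sim should now hold one row per character pair, with the raw distances
# (jacc, ot), their ranks across all pairs (rank_jacc, rank_ot), and an
# order-invariant pair label (pair_id)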

rdm_poi <- file.path("rdm_data", "period_of_interest") |>
  list.files(pattern = ".*\\.csv", full.names = TRUE) |>
  map_df(read_csv, col_types = c("subj_id" = "c")) |>
  left_join(stim_sim, by = c("char1", "char2")) |>
  group_by(subj_id) |>
  mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
  ungroup()
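
# note that eeg_dissim is ranked within each subject, whereas rank_jacc and
# rank_ot were ranked once across the full set of pairs in stim_sim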

# short function for calculating the population SD
pop_sd <- function(x) sqrt((length(x) - 1) / length(x)) * sd(x)
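
# illustrative check (not run): pop_sd(1:4) equals
# sqrt(mean((1:4 - mean(1:4))^2)), i.e. the SD with denominator n rather than
# the n - 1 used by sd()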

# function to fit the multivariate model for a given LKJ eta value
fit_mod <- function(lkj_prior, sample_prior = "no") {
  m_rho_prior_full <- c(
    # LKJ prior on the residual correlations, with the eta value under test
    set_prior(sprintf("lkj(%s)", lkj_prior), class = "rescor"),
    # fix each residual SD to the population SD of a full set of ranks
    set_prior(sprintf("constant(%s)", pop_sd(1:max(rdm_poi$rank_eeg_dissim))), class = "sigma", resp = "rankeegdissim", ub = max(rdm_poi$rank_eeg_dissim)),
    set_prior(sprintf("constant(%s)", pop_sd(1:max(rdm_poi$rank_jacc))), class = "sigma", resp = "rankjacc", ub = max(rdm_poi$rank_jacc)),
    set_prior(sprintf("constant(%s)", pop_sd(1:max(rdm_poi$rank_ot))), class = "sigma", resp = "rankot", ub = max(rdm_poi$rank_ot))
  )
  m_rho_full_poi <- brm(
    bf(
      mvbind(rank_eeg_dissim, rank_ot, rank_jacc) ~ 0
    ) +
      set_rescor(rescor = TRUE),
    family = brmsfamily("gaussian"),
    iter = 24000,
    warmup = 6000,
    chains = 8,
    cores = n_cores_per_model,
    seed = 1,
    # centre each dimension since we don't model an intercept
    data = mutate(rdm_poi, across(starts_with("rank"), function(x) x - mean(x))),
    prior = m_rho_prior_full,
    save_pars = save_pars(all = TRUE),
    sample_prior = sample_prior,
    control = list(adapt_delta = 0.99)
  )
  m_rho_full_poi
}
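
# example usage (not run here, since every call refits the full model): eta = 1
# gives a uniform prior over correlation matrices, so something like
#   m_flat <- fit_mod(1)
# would give a single fit under that prior, and passing sample_prior = "yes"
# would additionally store draws from the prior for comparison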

# function to fit the model and return the posterior median and 89% HDI for the
# Wasserstein and Jaccard (partial) correlations
sample_and_summarise_posteriors <- function(lkj_prior) {
  m_rho_full_poi <- fit_mod(lkj_prior, sample_prior = "no")
  # get a vector of response variable names (will match the order of the correlation matrix)
  var_names <- colnames(get_y(m_rho_full_poi))
  # summarise the posteriors
  summ <- as_draws_df(m_rho_full_poi, "^Lrescor", regex = TRUE) |>
    # convert the Lrescor columns into a correlation and partial correlation matrix for each draw
    pivot_longer(cols = starts_with("Lrescor"), names_to = "idx", values_to = "r") |>
    mutate(idx = sub("Lrescor", "", idx, fixed = TRUE)) |>
    mutate(idx = gsub("(\\[)|(\\])", "", idx)) |>
    separate(idx, c("idx1", "idx2"), sep = ",") |>
    mutate(idx1 = as.integer(idx1), idx2 = as.integer(idx2)) |>
    group_by(.chain, .iteration, .draw) |>
    nest() |>
    # get the correlation and partial correlation matrix for each draw
    mutate(
      cor_mat = map(
        data, function(x) {
          # Lrescor holds the Cholesky factor of the residual correlation
          # matrix, so rebuild the lower-triangular factor L and recover the
          # correlation matrix as L %*% t(L)
          l_mat <- matrix(0, nrow = max(x$idx1), ncol = max(x$idx2))
          l_mat[cbind(x$idx1, x$idx2)] <- x$r
          cor_mat <- cov2cor(tcrossprod(l_mat))
          colnames(cor_mat) <- var_names
          rownames(cor_mat) <- var_names
          cor_mat
        }
      ),
      pcor_mat = map(
        cor_mat, function(x) {
          pcor_mat <- corpcor::cor2pcor(x)
          colnames(pcor_mat) <- var_names
          rownames(pcor_mat) <- var_names
          pcor_mat
        }
      ),
      cor_rankeegdissim_rankjacc = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankjacc"]),
      cor_rankeegdissim_rankot = map_dbl(cor_mat, function(x) x["rankeegdissim", "rankot"]),
      pcor_rankeegdissim_rankjacc = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankjacc"]),
      pcor_rankeegdissim_rankot = map_dbl(pcor_mat, function(x) x["rankeegdissim", "rankot"])
    ) |>
    pivot_longer(c(starts_with("pcor_rankeegdissim"), starts_with("cor_rankeegdissim")), names_to = "cor_lab", values_to = "Rho") |>
    mutate(
      model = ifelse(grepl("jacc", cor_lab, fixed = TRUE), "Jaccard Distance", "Wasserstein Distance"),
      is_partial = grepl("^pcor", cor_lab)
    ) |>
    dplyr::select(-data, -cor_mat, -pcor_mat) |>
    group_by(model, is_partial) |>
    median_hdi(Rho, .width = .89) |>
    ungroup() |>
    mutate(eta = lkj_prior) |>
    select(eta, everything())
  summ
}
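
# example of the summary shape (hypothetical eta = 2; not run, since each call
# refits the full model):
#   sample_and_summarise_posteriors(2)
# should return four rows, one per combination of model (Jaccard / Wasserstein)
# and is_partial (correlation vs partial correlation), each with the posterior
# median of Rho and its 89% HDI, tagged with eta = 2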

# set up a vector of eta values to use as the LKJ prior
n_models <- 50
lkj_vals <- 10^seq(-3, 3, length.out = n_models)
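
# this gives 50 log-spaced eta values from 0.001 to 1000: eta < 1 favours
# extreme correlations, eta = 1 is uniform over correlation matrices, and
# eta > 1 increasingly concentrates mass around zero correlations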

# make a cluster with as many workers as necessary
cl <- makeCluster(n_parallel_models)
cl_pkg <- clusterEvalQ(cl, {
  library(dplyr)
  library(purrr)
  library(tidyr)
  library(brms)
  library(tidybayes)
})
cl_obj <- clusterExport(cl, list(
  "pop_sd", "rdm_poi",
  "fit_mod",
  "n_cores_per_model"
))

# fit and summarise one model per eta value in parallel, then combine the summaries
posterior_res <- parLapply(cl, lkj_vals, sample_and_summarise_posteriors) |>
  reduce(bind_rows)
stopCluster(cl)

saveRDS(posterior_res, file.path("estimates", "sensitivity_lkj_prior.rds"))
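
# downstream scripts can then load the summaries with, for example:
#   posterior_res <- readRDS(file.path("estimates", "sensitivity_lkj_prior.rds"))
# (assuming the "estimates" directory already exists when this script is run)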