# 06_jacc_geom.R

  library(readr)
  library(dplyr)
  library(purrr)
  library(brms)
  library(tidybayes)
  library(parallel)
  library(future)
  # N cores for parallelisation (logical cores).
  # detectCores() is documented to return NA when the count cannot be
  # determined; fall back to a single core in that case so that NA is never
  # passed on as a core / worker count.
  n_cores <- parallel::detectCores(all.tests = FALSE, logical = TRUE)
  if (is.na(n_cores) || n_cores < 1L) {
    n_cores <- 1L
  }
  # n_cores <- 13
  # Population standard deviation (denominator n rather than n - 1).
  #
  # x: numeric vector.
  # Returns sd(x) rescaled by sqrt((n - 1) / n). A single non-missing
  # observation has no spread, so its population SD is 0 (sd() on its own
  # would give NA). An empty vector gives NA.
  pop_sd <- function(x) {
    n <- length(x)
    if (n == 0L) {
      return(NA_real_)
    }
    if (n == 1L) {
      return(if (is.na(x)) NA_real_ else 0)
    }
    sqrt((n - 1) / n) * sd(x)
  }
  # paths to the CSV file of each geometric model (one file per model)
  geom_dir <- file.path("stim_sim", "jacc_geom")
  geom_paths <- list.files(
    geom_dir,
    pattern = "^.*\\.csv$",
    full.names = TRUE,
    include.dirs = TRUE
  )
  # Read every RDM CSV found in `rdm_dir`, stack the files into one table
  # (subj_id kept as character), and add the rank of eeg_dissim computed
  # separately within each subject.
  read_ranked_rdm <- function(rdm_dir) {
    csv_files <- list.files(rdm_dir, pattern = ".*\\.csv", full.names = TRUE)
    stacked <- map_df(csv_files, read_csv, col_types = c("subj_id" = "c"))
    by_subject <- group_by(stacked, subj_id)
    ranked <- mutate(by_subject, rank_eeg_dissim = rank(eeg_dissim))
    ungroup(ranked)
  }

  # neural RDM for period of interest
  rdm_poi <- read_ranked_rdm(file.path("rdm_data", "period_of_interest"))

  # neural RDM for exploratory P1 period
  rdm_p1 <- read_ranked_rdm(file.path("rdm_data", "p1_period"))
  # Whole-window analysis: for each period ("p1", "poi") and each geometric
  # model, estimate the correlation between the ranked neural and ranked
  # geometric dissimilarities as the residual correlation of a bivariate
  # Gaussian brms model, then save the posterior draws of that correlation
  # to estimates/jacc_geom_<period>_draws.rds.
  rdm_by_period <- list(p1 = rdm_p1, poi = rdm_poi)

  # make sure the output folder exists before any (slow) model is fitted
  dir.create("estimates", showWarnings = FALSE, recursive = TRUE)

  for (period in names(rdm_by_period)) {
    rdm_period <- rdm_by_period[[period]]
    # stimulus IDs in the neural RDM (identical for every geometric model)
    rdm_period_ids <- unique(c(rdm_period$char1, rdm_period$char2))

    cor_res_period <- map_df(geom_paths, function(path) {
      geom_lab <- tools::file_path_sans_ext(basename(path))

      # The file name is split on "_"; in each field the non-digit characters
      # become the name and the digits become a logical flag (0 -> FALSE).
      lab_parts <- strsplit(geom_lab, "_", fixed = TRUE)[[1]]
      geom_lab_vars <- as.logical(as.numeric(gsub("\\D", "", lab_parts)))
      names(geom_lab_vars) <- gsub("\\d", "", lab_parts)

      # fail before fitting, not after, if the labels needed below are absent
      missing_flags <- setdiff(c("T", "S", "R"), names(geom_lab_vars))
      if (length(missing_flags) > 0) {
        stop(sprintf(
          "Cannot find field(s) %s in file name '%s'",
          paste(missing_flags, collapse = ", "), geom_lab
        ))
      }

      message(sprintf("%s: %s", period, geom_lab))

      geom_sim <- read_csv(
        path,
        col_types = cols(char1 = col_character(), char2 = col_character())
      ) |>
        rename(geom_dissim = jacc) |>
        mutate(rank_geom_dissim = rank(geom_dissim)) |>
        select(char1, char2, rank_geom_dissim) # only used variables

      # setequal() compares the ID sets directly; an element-wise `!=` on the
      # sorted vectors recycles on unequal lengths, is vacuously FALSE when
      # one side is empty, and yields NA (an error in `if`) for missing IDs
      geom_ids <- unique(c(geom_sim$char1, geom_sim$char2))
      if (!setequal(rdm_period_ids, geom_ids)) {
        stop("ID Mismatch")
      }

      rdm_period_geom <- left_join(rdm_period, geom_sim, by = c("char1", "char2"))

      # a neural pair with no geometric counterpart would otherwise only show
      # up as an obscure error when the priors are built from max(NA)
      if (anyNA(rdm_period_geom$rank_geom_dissim)) {
        stop(sprintf("Neural RDM pairs without a match in '%s'", geom_lab))
      }

      # priors: LKJ on the residual correlation; each sigma is fixed to the
      # population SD of the ranks 1..max rank of that response
      lkj_prior <- 1.5
      max_rank_eeg <- max(rdm_period_geom$rank_eeg_dissim)
      max_rank_geom <- max(rdm_period_geom$rank_geom_dissim)
      m_rho_prior_full <- c(
        set_prior(sprintf("lkj(%s)", lkj_prior), class = "rescor"),
        set_prior(
          sprintf("constant(%s)", pop_sd(1:max_rank_eeg)),
          class = "sigma", resp = "rankeegdissim", ub = max_rank_eeg
        ),
        set_prior(
          sprintf("constant(%s)", pop_sd(1:max_rank_geom)),
          class = "sigma", resp = "rankgeomdissim", ub = max_rank_geom
        )
      )

      # fit model
      m_rho_full_period <- brm(
        bf(
          mvbind(rank_eeg_dissim, rank_geom_dissim) ~ 0
        ) +
          set_rescor(rescor = TRUE),
        family = brmsfamily("gaussian"),
        iter = 10000,
        warmup = 5000,
        chains = 8,
        cores = n_cores,
        seed = 1,
        # centre each dimension since we don't model intercept
        data = mutate(
          rdm_period_geom,
          across(starts_with("rank"), function(x) x - mean(x))
        ),
        prior = m_rho_prior_full,
        save_pars = save_pars(all = TRUE),
        sample_prior = FALSE,
        control = list(adapt_delta = 0.99)
      )

      # keep only the draws of the residual correlation (renamed to rho),
      # labelled with this model's flags
      m_rho_full_period |>
        as_draws_df("rescor__rankeegdissim__rankgeomdissim") |>
        select(-starts_with(".")) |>
        rename(rho = 1) |>
        mutate(
          translate = geom_lab_vars[["T"]],
          scale = geom_lab_vars[["S"]],
          rotate = geom_lab_vars[["R"]]
        )
    })

    saveRDS(
      cor_res_period,
      file.path("estimates", sprintf("jacc_geom_%s_draws.rds", period))
    )
  }
  # Register a forked-worker plan for the futures used by the time-resolved
  # fits below. (The future package is attached with the other libraries at
  # the top of the script so that a missing installation fails early.)
  plan(multicore, workers = n_cores)

  # import neural rdm: stack all time-resolved CSVs and rank eeg_dissim
  # separately within each subject and time point
  rdm <- file.path("rdm_data", "time_resolved") |>
    list.files(pattern = ".*\\.csv", full.names = TRUE) |>
    map_df(read_csv, col_types = c("subj_id" = "c")) |>
    group_by(subj_id, time) |>
    mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
    ungroup() |>
    arrange(time)

  # sorted unique time points
  times <- sort(unique(rdm$time))
  # For each geometric model, get the time-resolved posterior median and 89%
  # HDI of the correlation between ranked neural and ranked geometric
  # dissimilarities.
  # - done via independent models for each time point and geometric model

  # make sure the output folder exists before any (slow) model is fitted
  dir.create("estimates", showWarnings = FALSE, recursive = TRUE)

  # stimulus IDs in the neural RDM (identical for every geometric model)
  rdm_ids <- unique(c(rdm$char1, rdm$char2))

  cor_res <- map_df(geom_paths, function(path) {
    geom_lab <- tools::file_path_sans_ext(basename(path))

    # The file name is split on "_"; in each field the non-digit characters
    # become the name and the digits become a logical flag (0 -> FALSE).
    lab_parts <- strsplit(geom_lab, "_", fixed = TRUE)[[1]]
    geom_lab_vars <- as.logical(as.numeric(gsub("\\D", "", lab_parts)))
    names(geom_lab_vars) <- gsub("\\d", "", lab_parts)

    # fail before fitting, not after, if the labels needed below are absent
    missing_flags <- setdiff(c("T", "S", "R"), names(geom_lab_vars))
    if (length(missing_flags) > 0) {
      stop(sprintf(
        "Cannot find field(s) %s in file name '%s'",
        paste(missing_flags, collapse = ", "), geom_lab
      ))
    }

    message(sprintf("Time-Resolved: %s", geom_lab))

    geom <- read_csv(
      path,
      col_types = cols(char1 = col_character(), char2 = col_character())
    ) |>
      rename(geom_dissim = jacc) |>
      mutate(rank_geom_dissim = rank(geom_dissim)) |>
      select(char1, char2, rank_geom_dissim) # only used variables

    # setequal() compares the ID sets directly; an element-wise `!=` on the
    # sorted vectors recycles on unequal lengths, is vacuously FALSE when one
    # side is empty, and yields NA (an error in `if`) for missing IDs
    geom_ids <- unique(c(geom$char1, geom$char2))
    if (!setequal(rdm_ids, geom_ids)) {
      stop("ID Mismatch")
    }

    rdm_geom <- left_join(geom, rdm, by = c("char1", "char2"))

    # a geometric pair with no neural counterpart would otherwise only show
    # up as an obscure error when the priors are built from max(NA)
    if (anyNA(rdm_geom$rank_eeg_dissim)) {
      stop(sprintf("Pairs in '%s' without a match in the neural RDM", geom_lab))
    }

    # correlation prior; each sigma is fixed to the population SD of the
    # ranks 1..max rank of that response
    lkj_prior <- 1
    max_rank_eeg <- max(rdm_geom$rank_eeg_dissim)
    max_rank_geom <- max(rdm_geom$rank_geom_dissim)
    m_rho_prior_full <- c(
      set_prior(sprintf("lkj(%s)", lkj_prior), class = "rescor"),
      set_prior(
        sprintf("constant(%s)", pop_sd(1:max_rank_eeg)),
        class = "sigma", resp = "rankeegdissim", ub = max_rank_eeg
      ),
      set_prior(
        sprintf("constant(%s)", pop_sd(1:max_rank_geom)),
        class = "sigma", resp = "rankgeomdissim", ub = max_rank_geom
      )
    )

    # centre each dimension within time point since we don't model intercept,
    # then split by time points into list
    data_by_time <- rdm_geom |>
      group_by(time) |>
      mutate(across(starts_with("rank"), function(x) x - mean(x))) |>
      ungroup() |>
      group_split(time)

    # the fits are matched to `times` by position below, so the two must
    # have the same length
    if (length(data_by_time) != length(times)) {
      stop(sprintf("Time points in joined data for '%s' do not match `times`", geom_lab))
    }

    # fit the models in parallel (one model per time point)
    m_rho <- brm_multiple(
      bf(
        mvbind(rank_eeg_dissim, rank_geom_dissim) ~ 0
      ) +
        set_rescor(rescor = TRUE),
      family = brmsfamily("gaussian"),
      iter = 5000,
      warmup = 2500,
      chains = 4,
      cores = 1,
      seed = 1,
      data = data_by_time,
      combine = FALSE,
      prior = m_rho_prior_full,
      control = list(adapt_delta = 0.9),
      silent = TRUE,
      refresh = 0
    )

    # collect the residual-correlation draws of every time point and
    # summarise them as median and 89% HDI per time point
    ests <- map_df(seq_along(times), function(t) {
      m_t <- m_rho[[t]]
      time_t <- times[[t]]
      m_t |>
        as_draws_df("rescor__rankeegdissim__rankgeomdissim") |>
        rename(rho = 1) |>
        mutate(time = time_t)
    }) |>
      group_by(time) |>
      median_hdi(rho, .width = 0.89)

    ests |>
      mutate(
        translate = geom_lab_vars[["T"]],
        scale = geom_lab_vars[["S"]],
        rotate = geom_lab_vars[["R"]]
      ) |>
      dplyr::select(translate, scale, rotate, time, rho, .lower, .upper) |>
      mutate(
        time = factor(time, levels = times)
      )
  })

  saveRDS(cor_res, file.path("estimates", "jacc_geom_time_resolved.rds"))