123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- library(readr)
- library(dplyr)
- library(purrr)
- library(brms)
- library(tidybayes)
- library(parallel)
# Number of cores to use for parallelisation (logical cores, i.e. including
# hyper-threads).
# detectCores() is documented to return NA when the count cannot be
# determined; fall back to a single core so that downstream `cores =` /
# `workers =` arguments always receive a usable value.
n_cores <- parallel::detectCores(all.tests = FALSE, logical = TRUE)
if (is.na(n_cores)) {
  n_cores <- 1L
}
# n_cores <- 14
# Population standard deviation (divides by n rather than n - 1).
#
# x: numeric vector.
# Returns a single numeric value:
#   - NA for an empty vector (there is nothing to measure),
#   - 0 for a single non-missing observation (sd() alone would give NA),
#   - otherwise the sample SD rescaled by sqrt((n - 1) / n).
# Missing values are not removed: any NA in x yields NA, as with sd().
pop_sd <- function(x) {
  n <- length(x)
  if (n == 0L) {
    return(NA_real_)
  }
  if (n == 1L) {
    # a single value has no spread; keep NA if that value is itself missing
    return(if (is.na(x)) NA_real_ else 0)
  }
  sqrt((n - 1) / n) * sd(x)
}
# Full paths of every CSV file found in stim_sim/ot_geom; each file is later
# read as one geometry model's dissimilarities
geom_paths <- list.files(
  file.path("stim_sim", "ot_geom"),
  pattern = "^.*\\.csv$",
  full.names = TRUE,
  include.dirs = TRUE
)
# Neural RDM for the period of interest: read and stack every CSV in
# rdm_data/period_of_interest (subj_id forced to character), then rank the EEG
# dissimilarities separately within each subject
rdm_poi <- list.files(
  file.path("rdm_data", "period_of_interest"),
  pattern = ".*\\.csv",
  full.names = TRUE
) |>
  map(read_csv, col_types = c("subj_id" = "c")) |>
  bind_rows() |>
  group_by(subj_id) |>
  mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
  ungroup()
# Neural RDM for the exploratory P1 period: read and stack every CSV in
# rdm_data/p1_period (subj_id forced to character), then rank the EEG
# dissimilarities separately within each subject
rdm_p1 <- list.files(
  file.path("rdm_data", "p1_period"),
  pattern = ".*\\.csv",
  full.names = TRUE
) |>
  map(read_csv, col_types = c("subj_id" = "c")) |>
  bind_rows() |>
  group_by(subj_id) |>
  mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
  ungroup()
# Period-level analysis ----
# For each time window ("p1" and "poi") and each geometry file, fit a
# bivariate Gaussian model to the centred EEG and geometry ranks and keep the
# posterior draws of their residual correlation (rho). One RDS file of draws
# is written per window.

# Make sure the output directory exists before any (slow) model fitting starts,
# rather than failing at saveRDS() once the fits are done
dir.create("estimates", showWarnings = FALSE, recursive = TRUE)

for (period in c("p1", "poi")) {
  # neural RDM for the window being analysed
  rdm_period <- if (period == "p1") rdm_p1 else rdm_poi

  cor_res_period <- map_df(geom_paths, function(path) {
    geom_lab <- tools::file_path_sans_ext(basename(path))

    # Transformation flags for this geometry. "ot_pgw" has all three set;
    # any other label is split on "_" and each piece is read as
    # <letters><digits>, the letters becoming the flag name and the digits its
    # logical value (assumes names such as "T1_S0_R1" -- TODO confirm).
    geom_lab_vars <- if (geom_lab == "ot_pgw") {
      c(T = TRUE, S = TRUE, R = TRUE)
    } else {
      lab_parts <- strsplit(geom_lab, "_", fixed = TRUE)[[1]]
      lab_flags <- as.logical(as.numeric(gsub("\\D", "", lab_parts)))
      names(lab_flags) <- gsub("\\d", "", lab_parts)
      lab_flags
    }

    message(sprintf("%s: %s", period, geom_lab))

    geom_sim <- read_csv(
      path,
      col_types = cols(char1 = col_character(), char2 = col_character())
    ) |>
      rename(geom_dissim = ot) |>
      mutate(rank_geom_dissim = rank(geom_dissim)) |>
      select(char1, char2, rank_geom_dissim) # only used variables

    # The neural and geometry RDMs must cover exactly the same stimulus IDs.
    # setequal() also catches the case where one side is empty, which an
    # element-wise `!=` comparison would let through.
    ids_neural <- unique(c(rdm_period$char1, rdm_period$char2))
    ids_geom <- unique(c(geom_sim$char1, geom_sim$char2))
    if (!setequal(ids_neural, ids_geom)) {
      stop("ID Mismatch")
    }

    rdm_period_geom <- left_join(
      rdm_period,
      geom_sim,
      by = c("char1", "char2")
    )

    # A neural pair without a geometry rank would otherwise only surface as an
    # obscure error when the priors are built from max() below
    if (anyNA(rdm_period_geom$rank_geom_dissim)) {
      stop(
        "Some (char1, char2) pairs have no geometry rank after joining ",
        basename(path),
        call. = FALSE
      )
    }

    # fit model
    lkj_prior <- 1.5
    max_rank_eeg <- max(rdm_period_geom$rank_eeg_dissim)
    max_rank_geom <- max(rdm_period_geom$rank_geom_dissim)

    # LKJ prior on the residual correlation; both residual SDs are fixed
    # (constant prior) to the population SD of the integers 1..max rank
    m_rho_prior_full <- c(
      set_prior(sprintf("lkj(%s)", lkj_prior), class = "rescor"),
      set_prior(
        sprintf("constant(%s)", pop_sd(1:max_rank_eeg)),
        class = "sigma",
        resp = "rankeegdissim",
        ub = max_rank_eeg
      ),
      set_prior(
        sprintf("constant(%s)", pop_sd(1:max_rank_geom)),
        class = "sigma",
        resp = "rankgeomdissim",
        ub = max_rank_geom
      )
    )

    m_rho_full_period <- brm(
      bf(
        mvbind(rank_eeg_dissim, rank_geom_dissim) ~ 0
      ) +
        set_rescor(rescor = TRUE),
      family = brmsfamily("gaussian"),
      iter = 10000,
      warmup = 5000,
      chains = 8,
      cores = n_cores,
      seed = 1,
      # centre each dimension since we don't model intercept
      data = mutate(
        rdm_period_geom,
        across(starts_with("rank"), function(x) x - mean(x))
      ),
      prior = m_rho_prior_full,
      save_pars = save_pars(all = TRUE),
      sample_prior = FALSE,
      control = list(adapt_delta = 0.99)
    )

    # posterior draws of the residual correlation, tagged with the flags of
    # the geometry they belong to
    m_rho_full_period |>
      as_draws_df("rescor__rankeegdissim__rankgeomdissim") |>
      select(-starts_with(".")) |>
      rename(rho = 1) |>
      mutate(
        translate = geom_lab_vars[["T"]],
        scale = geom_lab_vars[["S"]],
        rotate = geom_lab_vars[["R"]],
        gromov_wasserstein = geom_lab == "ot_pgw"
      )
  })

  saveRDS(
    cor_res_period,
    file.path("estimates", sprintf("ot_geom_%s_draws.rds", period))
  )
}
# Register a forked-process (multicore) future backend with one worker per
# available core for the code that follows
library(future)
plan(strategy = multicore, workers = n_cores)
# Time-resolved neural RDM: read and stack every CSV in
# rdm_data/time_resolved (subj_id forced to character), rank the EEG
# dissimilarities within each subject-by-time cell, and order rows by time
rdm <- list.files(
  file.path("rdm_data", "time_resolved"),
  pattern = ".*\\.csv",
  full.names = TRUE
) |>
  map(read_csv, col_types = c("subj_id" = "c")) |>
  bind_rows() |>
  group_by(subj_id, time) |>
  mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
  ungroup() |>
  arrange(time)

# distinct time points in ascending order
times <- sort(unique(rdm$time))
# Time-resolved analysis ----
# For each geometry file, estimate the correlation between EEG and geometry
# ranks at every time point and summarise it as a posterior median with an
# 89% HDI.
# - done via independent models for each time point and geometry file

# Make sure the output directory exists before any (slow) model fitting starts,
# rather than failing at saveRDS() once the fits are done
dir.create("estimates", showWarnings = FALSE, recursive = TRUE)

cor_res <- map_df(geom_paths, function(path) {
  geom_lab <- tools::file_path_sans_ext(basename(path))

  # Transformation flags for this geometry. "ot_pgw" has all three set; any
  # other label is split on "_" and each piece is read as <letters><digits>,
  # the letters becoming the flag name and the digits its logical value
  # (assumes names such as "T1_S0_R1" -- TODO confirm).
  geom_lab_vars <- if (geom_lab == "ot_pgw") {
    c(T = TRUE, S = TRUE, R = TRUE)
  } else {
    lab_parts <- strsplit(geom_lab, "_", fixed = TRUE)[[1]]
    lab_flags <- as.logical(as.numeric(gsub("\\D", "", lab_parts)))
    names(lab_flags) <- gsub("\\d", "", lab_parts)
    lab_flags
  }

  message(sprintf("Time-Resolved: %s", geom_lab))

  geom <- read_csv(
    path,
    col_types = cols(char1 = col_character(), char2 = col_character())
  ) |>
    rename(geom_dissim = ot) |>
    mutate(rank_geom_dissim = rank(geom_dissim)) |>
    select(char1, char2, rank_geom_dissim) # only used variables

  # The geometry RDM must cover exactly the same stimulus IDs as the
  # time-resolved neural RDM that it is joined to below. setequal() also
  # catches the case where one side is empty.
  ids_neural <- unique(c(rdm$char1, rdm$char2))
  ids_geom <- unique(c(geom$char1, geom$char2))
  if (!setequal(ids_neural, ids_geom)) {
    stop("ID Mismatch")
  }

  rdm_geom <- left_join(geom, rdm, by = c("char1", "char2"))

  # A geometry pair without neural data would otherwise only surface as an
  # obscure error when the priors are built from max() below
  if (anyNA(rdm_geom$rank_eeg_dissim)) {
    stop(
      "Some (char1, char2) pairs in ", basename(path),
      " have no match in the time-resolved neural RDM",
      call. = FALSE
    )
  }

  # correlation prior
  lkj_prior <- 1
  max_rank_eeg <- max(rdm_geom$rank_eeg_dissim)
  max_rank_geom <- max(rdm_geom$rank_geom_dissim)

  # LKJ prior on the residual correlation; both residual SDs are fixed
  # (constant prior) to the population SD of the integers 1..max rank
  m_rho_prior_full <- c(
    set_prior(sprintf("lkj(%s)", lkj_prior), class = "rescor"),
    set_prior(
      sprintf("constant(%s)", pop_sd(1:max_rank_eeg)),
      class = "sigma",
      resp = "rankeegdissim",
      ub = max_rank_eeg
    ),
    set_prior(
      sprintf("constant(%s)", pop_sd(1:max_rank_geom)),
      class = "sigma",
      resp = "rankgeomdissim",
      ub = max_rank_geom
    )
  )

  # centre each dimension since we don't model intercept, then split by time
  # points into list (one data set, and hence one fit, per time point)
  data_by_time <- rdm_geom |>
    group_by(time) |>
    mutate(across(starts_with("rank"), function(x) x - mean(x))) |>
    ungroup() |>
    group_split(time)

  # fit the models in parallel
  # NOTE(review): cores = 1 keeps the chains of each fit sequential;
  # parallelism across fits presumably comes from the registered future plan
  # -- confirm.
  m_rho <- brm_multiple(
    bf(
      mvbind(rank_eeg_dissim, rank_geom_dissim) ~ 0
    ) +
      set_rescor(rescor = TRUE),
    family = brmsfamily("gaussian"),
    iter = 10000,
    warmup = 5000,
    chains = 8,
    cores = 1,
    seed = 1,
    data = data_by_time,
    combine = FALSE,
    prior = m_rho_prior_full,
    control = list(adapt_delta = 0.99),
    silent = TRUE,
    refresh = 0
  )

  # fits are matched to time points by position, so the counts must agree
  stopifnot(length(m_rho) == length(times))

  ests <- map_df(seq_along(times), function(t) {
    m_rho[[t]] |>
      as_draws_df("rescor__rankeegdissim__rankgeomdissim") |>
      rename(rho = 1) |>
      mutate(time = times[[t]])
  }) |>
    group_by(time) |>
    median_hdi(rho, .width = 0.89)

  ests |>
    mutate(
      translate = geom_lab_vars[["T"]],
      scale = geom_lab_vars[["S"]],
      rotate = geom_lab_vars[["R"]],
      gromov_wasserstein = geom_lab == "ot_pgw"
    ) |>
    dplyr::select(
      translate, scale, rotate, gromov_wasserstein,
      time, rho, .lower, .upper
    ) |>
    mutate(
      time = factor(time, levels = times)
    )
})

saveRDS(cor_res, file.path("estimates", "ot_geom_time_resolved.rds"))
|