05_ot_geom.R 6.7 KB


  1. library(readr)
  2. library(dplyr)
  3. library(purrr)
  4. library(brms)
  5. library(tidybayes)
  6. library(parallel)
  # Number of cores used for parallelisation.
  # parallel::detectCores() is documented to return NA when the core count
  # cannot be determined; fall back to a single core instead of propagating
  # NA into the sampler / worker settings that use n_cores.
  n_cores <- parallel::detectCores(all.tests = FALSE, logical = TRUE)
  if (is.na(n_cores)) {
    n_cores <- 1L
  }
  # n_cores <- 14  # manual override
  # Population standard deviation of `x` (divisor n rather than the n - 1
  # used by sd()), obtained by rescaling the sample SD by sqrt((n - 1) / n).
  #
  # Edge cases: an empty vector has no defined SD and gives NA; a single
  # non-missing value has population SD 0 (sd() alone would give NA).
  # Missing values propagate as in sd() (no na.rm).
  pop_sd <- function(x) {
    n <- length(x)
    if (n == 0L) {
      return(NA_real_)
    }
    if (n == 1L) {
      return(if (is.na(x)) NA_real_ else 0)
    }
    sqrt((n - 1) / n) * sd(x)
  }
  # Full paths of every geometry model's CSV file in stim_sim/ot_geom
  # (list.files() returns them in sorted order).
  geom_paths <- list.files(
    path = file.path("stim_sim", "ot_geom"),
    pattern = "^.*\\.csv$",
    full.names = TRUE,
    include.dirs = TRUE
  )
  # Read every CSV in `dir`, row-bind them (subj_id forced to character),
  # and add rank_eeg_dissim: the rank of eeg_dissim within each subj_id.
  # Returns an ungrouped tibble.
  read_period_rdm <- function(dir) {
    csv_files <- list.files(dir, pattern = ".*\\.csv", full.names = TRUE)
    combined <- map_df(csv_files, read_csv, col_types = c("subj_id" = "c"))
    by_subject <- group_by(combined, subj_id)
    ranked <- mutate(by_subject, rank_eeg_dissim = rank(eeg_dissim))
    ungroup(ranked)
  }
  # neural RDM for the period of interest
  rdm_poi <- read_period_rdm(file.path("rdm_data", "period_of_interest"))
  # neural RDM for the exploratory P1 period
  rdm_p1 <- read_period_rdm(file.path("rdm_data", "p1_period"))
  # For each period ("p1" exploratory window, "poi" period of interest) and
  # each geometry CSV in geom_paths, fit a bivariate Gaussian model on the
  # centred ranks of the EEG and geometry dissimilarities, and keep the
  # posterior draws of their residual correlation (rho) together with the
  # transformation flags parsed from the file name. Draws for each period
  # are saved to estimates/ot_geom_<period>_draws.rds.
  for (period in c("p1", "poi")) {
    rdm_period <- if (period == "p1") rdm_p1 else rdm_poi

    cor_res_period <- map_df(geom_paths, function(path) {
      geom_lab <- tools::file_path_sans_ext(basename(path))
      # "ot_pgw" has every flag enabled. Other labels are "_"-separated
      # tokens: the non-digit part of a token becomes the flag name and the
      # digit part its value (non-zero -> TRUE, zero -> FALSE, none -> NA).
      geom_lab_vars <- if (geom_lab == "ot_pgw") {
        c(T = TRUE, S = TRUE, R = TRUE)
      } else {
        tokens <- strsplit(geom_lab, "_", fixed = TRUE)[[1]]
        setNames(
          as.logical(as.numeric(gsub("\\D", "", tokens))),
          gsub("\\d", "", tokens)
        )
      }
      message(sprintf("%s: %s", period, geom_lab))

      geom_sim <- read_csv(
        path,
        col_types = cols(char1 = col_character(), char2 = col_character())
      ) |>
        rename(geom_dissim = ot) |>
        mutate(rank_geom_dissim = rank(geom_dissim)) |>
        select(char1, char2, rank_geom_dissim) # only used variables

      # Both RDMs must cover exactly the same set of character IDs.
      # setequal() neither recycles unequal-length vectors nor passes
      # vacuously when one side is empty.
      rdm_ids <- unique(c(rdm_period$char1, rdm_period$char2))
      geom_ids <- unique(c(geom_sim$char1, geom_sim$char2))
      if (!setequal(rdm_ids, geom_ids)) {
        stop("ID Mismatch")
      }
      rdm_period_geom <- left_join(rdm_period, geom_sim, by = c("char1", "char2"))

      # fit model
      lkj_prior <- 1.5
      # The SD of each response is fixed to the population SD of 1..max(rank),
      # i.e. the SD a full set of ranks would have (exact when there are no ties).
      n_eeg <- max(rdm_period_geom$rank_eeg_dissim)
      n_geom <- max(rdm_period_geom$rank_geom_dissim)
      m_rho_prior_full <- c(
        set_prior(sprintf("lkj(%s)", lkj_prior), class = "rescor"),
        set_prior(
          sprintf("constant(%s)", pop_sd(1:n_eeg)),
          class = "sigma", resp = "rankeegdissim", ub = n_eeg
        ),
        set_prior(
          sprintf("constant(%s)", pop_sd(1:n_geom)),
          class = "sigma", resp = "rankgeomdissim", ub = n_geom
        )
      )
      m_rho_full_period <- brm(
        bf(
          mvbind(rank_eeg_dissim, rank_geom_dissim) ~ 0
        ) +
          set_rescor(rescor = TRUE),
        family = brmsfamily("gaussian"),
        iter = 10000,
        warmup = 5000,
        chains = 8,
        cores = n_cores,
        seed = 1,
        # centre each dimension since we don't model intercept
        data = mutate(
          rdm_period_geom,
          across(starts_with("rank"), function(x) x - mean(x))
        ),
        prior = m_rho_prior_full,
        save_pars = save_pars(all = TRUE),
        sample_prior = FALSE,
        control = list(adapt_delta = 0.99)
      )

      m_rho_full_period |>
        as_draws_df("rescor__rankeegdissim__rankgeomdissim") |>
        select(-starts_with(".")) |>
        rename(rho = 1) |>
        mutate(
          translate = geom_lab_vars[["T"]],
          scale = geom_lab_vars[["S"]],
          rotate = geom_lab_vars[["R"]],
          gromov_wasserstein = geom_lab == "ot_pgw"
        )
    })

    saveRDS(
      cor_res_period,
      file.path("estimates", sprintf("ot_geom_%s_draws.rds", period))
    )
  }
  # Parallel back-end: forked ("multicore") futures with n_cores workers,
  # used when the time-resolved models are fitted below.
  library(future)
  future::plan(future::multicore, workers = n_cores)
  # Time-resolved neural RDM: all CSVs in rdm_data/time_resolved row-bound,
  # with eeg_dissim ranked within each subj_id x time combination and the
  # rows ordered by time.
  rdm <- local({
    csv_files <- list.files(
      file.path("rdm_data", "time_resolved"),
      pattern = ".*\\.csv",
      full.names = TRUE
    )
    combined <- map_df(csv_files, read_csv, col_types = c("subj_id" = "c"))
    ranked <- combined |>
      group_by(subj_id, time) |>
      mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
      ungroup()
    arrange(ranked, time)
  })
  # distinct time points, in increasing order
  times <- rdm$time |> unique() |> sort()
  # For each geometry model, get the time-resolved median and 89% HDI of the
  # rank correlation between EEG and geometry dissimilarities.
  # - done via an independent bivariate model for each time point, fitted
  #   with brm_multiple() (parallelised through the future plan set above)
  # The combined table is saved to estimates/ot_geom_time_resolved.rds.
  cor_res <- map_df(geom_paths, function(path) {
    geom_lab <- tools::file_path_sans_ext(basename(path))
    # "ot_pgw" has every flag enabled. Other labels are "_"-separated
    # tokens: the non-digit part of a token becomes the flag name and the
    # digit part its value (non-zero -> TRUE, zero -> FALSE, none -> NA).
    geom_lab_vars <- if (geom_lab == "ot_pgw") {
      c(T = TRUE, S = TRUE, R = TRUE)
    } else {
      tokens <- strsplit(geom_lab, "_", fixed = TRUE)[[1]]
      setNames(
        as.logical(as.numeric(gsub("\\D", "", tokens))),
        gsub("\\d", "", tokens)
      )
    }
    message(sprintf("Time-Resolved: %s", geom_lab))

    geom <- read_csv(
      path,
      col_types = cols(char1 = col_character(), char2 = col_character())
    ) |>
      rename(geom_dissim = ot) |>
      mutate(rank_geom_dissim = rank(geom_dissim)) |>
      select(char1, char2, rank_geom_dissim) # only used variables

    # Validate against `rdm`, the time-resolved RDM that is joined below.
    # setequal() neither recycles unequal-length vectors nor passes
    # vacuously when one side is empty.
    rdm_ids <- unique(c(rdm$char1, rdm$char2))
    geom_ids <- unique(c(geom$char1, geom$char2))
    if (!setequal(rdm_ids, geom_ids)) {
      stop("ID Mismatch")
    }
    rdm_geom <- left_join(geom, rdm, by = c("char1", "char2"))

    # correlation prior; response SDs fixed to the population SD of
    # 1..max(rank) (exact for untied ranks)
    lkj_prior <- 1
    n_eeg <- max(rdm_geom$rank_eeg_dissim)
    n_geom <- max(rdm_geom$rank_geom_dissim)
    m_rho_prior_full <- c(
      set_prior(sprintf("lkj(%s)", lkj_prior), class = "rescor"),
      set_prior(
        sprintf("constant(%s)", pop_sd(1:n_eeg)),
        class = "sigma", resp = "rankeegdissim", ub = n_eeg
      ),
      set_prior(
        sprintf("constant(%s)", pop_sd(1:n_geom)),
        class = "sigma", resp = "rankgeomdissim", ub = n_geom
      )
    )

    # centre each dimension within each time point since we don't model
    # intercept, then split into one data set per time point
    data_by_time <- rdm_geom |>
      group_by(time) |>
      mutate(across(starts_with("rank"), function(x) x - mean(x))) |>
      ungroup() |>
      group_split(time)
    # group_split() orders the splits by time, so fit i must belong to
    # times[[i]]; guard that correspondence before indexing fits by position.
    if (length(data_by_time) != length(times)) {
      stop("Number of per-time data sets does not match number of time points")
    }

    # fit the models in parallel
    m_rho <- brm_multiple(
      bf(
        mvbind(rank_eeg_dissim, rank_geom_dissim) ~ 0
      ) +
        set_rescor(rescor = TRUE),
      family = brmsfamily("gaussian"),
      iter = 10000,
      warmup = 5000,
      chains = 8,
      cores = 1,
      seed = 1,
      data = data_by_time,
      combine = FALSE,
      prior = m_rho_prior_full,
      control = list(adapt_delta = 0.99),
      silent = TRUE,
      refresh = 0
    )

    ests <- map_df(seq_along(times), function(t) {
      m_rho[[t]] |>
        as_draws_df("rescor__rankeegdissim__rankgeomdissim") |>
        rename(rho = 1) |>
        mutate(time = times[[t]])
    }) |>
      group_by(time) |>
      median_hdi(rho, .width = 0.89)

    ests |>
      mutate(
        translate = geom_lab_vars[["T"]],
        scale = geom_lab_vars[["S"]],
        rotate = geom_lab_vars[["R"]],
        gromov_wasserstein = geom_lab == "ot_pgw"
      ) |>
      dplyr::select(translate, scale, rotate, gromov_wasserstein, time, rho, .lower, .upper) |>
      mutate(
        time = factor(time, levels = times)
      )
  })
  saveRDS(cor_res, file.path("estimates", "ot_geom_time_resolved.rds"))