99_make_all_ggplots.R 47 KB


  1. library(readr)
  2. library(dplyr)
  3. library(tidyr)
  4. library(purrr)
  5. library(forcats)
  6. library(ggplot2)
  7. library(scales)
  8. library(ggstance)
  9. library(ggdist)
  10. library(cowplot)
  11. library(patchwork)
  12. source(file.path("fig_code", "ggplot2_theme.R"), local=TRUE)
  # Characters spanning the RDM axes: a-z followed by ä, ö, ü and ß.
  all_chars <- append(letters, c("ä", "ö", "ü", "ß"))

  # Container that collects every plot object built by this script.
  pl <- list() # will contain all plots

  # Analysis windows as c(start, end); they are pasted into "<start>-<end> ms"
  # panel titles further down.
  poi_window <- c(150, 225)
  p1_window <- c(80, 130)

  # get noise ceiling
  # Each table is used below through its `lwr` / `upr` columns (the time-resolved
  # one also through `time`).
  noise_ceiling_poi <- file.path("noise_ceiling", "noise_ceiling_poi.csv") |>
    read_csv()
  noise_ceiling_p1 <- file.path("noise_ceiling", "noise_ceiling_p1.csv") |>
    read_csv()
  noise_ceiling_time <- file.path("noise_ceiling", "noise_ceiling_time.csv") |>
    read_csv()
  noise_ceiling_poi_all_chs <- file.path("noise_ceiling", "noise_ceiling_poi_all_chs.csv") |>
    read_csv()
  noise_ceiling_time_all_chs <- file.path("noise_ceiling", "noise_ceiling_time_all_chs.csv") |>
    read_csv()

  # Shared y-range for correlation axes: fixed lower bound, upper bound at the
  # highest noise-ceiling upper limit of the three tables.
  rho_limits <- c(
    -0.1,
    max(noise_ceiling_time$upr, noise_ceiling_poi$upr, noise_ceiling_p1$upr)
  )
  # preregistered analysis --------------------------------------------------
  # plots of RDMs

  # Model dissimilarities per character pair: the Jaccard, OT and complexity
  # tables are joined on (char1, char2), each measure is ranked across all pairs,
  # and every row gets an order-independent pair identifier ("a_b").
  stim_sim <- file.path("stim_sim", "preregistered", "jacc.csv") |>
    read_csv() |>
    left_join(
      read_csv(file.path("stim_sim", "preregistered", "ot.csv")),
      by = c("char1", "char2")
    ) |>
    left_join(
      read_csv(file.path("stim_sim", "complexity", "complexity.csv")),
      by = c("char1", "char2")
    ) |>
    mutate(
      rank_jacc = rank(jacc),
      rank_ot = rank(ot),
      rank_comp_dist = rank(comp_dist),
      # sort the two characters of each row so (a, b) and (b, a) share one id
      pair_id = map2_chr(
        char1, char2,
        \(first, second) paste(sort(c(first, second)), collapse = "_")
      )
    ) |>
    select(-starts_with("."))
  # Neural dissimilarities in the period of interest. Every CSV in
  # rdm_data/period_of_interest is read (subj_id forced to character) and the
  # tables are stacked. Model dissimilarities are joined on per character pair,
  # the EEG dissimilarities are ranked within each subject, and the character
  # columns become ordered factors following `all_chars`.
  rdm_poi <- file.path("rdm_data", "period_of_interest") |>
    # anchored pattern: only names that END in ".csv" (the previous ".*\\.csv"
    # also matched e.g. "x.csv.bak" or editor backups such as "x.csv~")
    list.files(pattern = "\\.csv$", full.names = TRUE) |>
    map(read_csv, col_types = c("subj_id" = "c")) |>
    bind_rows() |>
    select(-starts_with(".")) |>
    left_join(stim_sim, by = c("char1", "char2")) |>
    group_by(subj_id) |>
    mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
    ungroup() |>
    mutate(
      char1 = factor(char1, levels = all_chars, ordered = TRUE),
      char2 = factor(char2, levels = all_chars, ordered = TRUE)
    )
  # One heatmap of ranked EEG dissimilarities per subject. Each subject's rows
  # are stacked with a copy in which char1 and char2 are exchanged, so that both
  # triangles of the matrix are drawn. The panels are laid out in three rows.
  pl$rdm <- unique(rdm_poi$subj_id) |>
    map(\(subj) {
      pairs <- filter(rdm_poi, subj_id == subj)
      flipped <- mutate(pairs, char1 = pairs$char2, char2 = pairs$char1)
      bind_rows(pairs, flipped) |>
        ggplot(aes(char1, char2, fill = rank_eeg_dissim)) +
        geom_tile() +
        scale_fill_viridis_c() +
        coord_fixed() +
        scale_x_discrete(expand = c(0, 0)) +
        # reversed y axis so the first character sits at the top
        scale_y_discrete(limits = rev, expand = c(0, 0)) +
        labs(title = subj, x = NULL, y = NULL, fill = NULL) +
        theme(
          plot.title = element_text(hjust = 0.5),
          axis.line = element_blank(),
          legend.position = "none"
        )
    }) |>
    wrap_plots(nrow = 3)
  ggsave(
    file.path("fig", "neural_rdm.png"),
    plot = pl$rdm,
    width = 8, height = 5,
    device = "png", type = "cairo"
  )
  # Draw one rank-dissimilarity matrix as a square heatmap.
  #
  # pairs:    data frame with one row per character pair (`char1`, `char2`) and
  #           the column named by `fill_col`.
  # fill_col: name (string) of the column mapped to the tile fill.
  #
  # The rows are stacked with a copy in which char1 and char2 are exchanged so
  # that both triangles are filled. Only the first and last character are
  # labelled; secondary axes close the frame, with their text and ticks blanked.
  # The result is wrapped in cowplot::ggdraw() with one "..." label centred under
  # the x axis and one beside the y axis.
  plot_rank_rdm <- function(pairs, fill_col) {
    flipped <- mutate(pairs, char1 = pairs$char2, char2 = pairs$char1)
    heatmap <- bind_rows(pairs, flipped) |>
      ggplot(aes(char1, char2, fill = .data[[fill_col]])) +
      geom_tile(linewidth = 0.01) +
      scale_fill_viridis_c() +
      coord_fixed() +
      scale_x_discrete(
        expand = c(0, 0),
        breaks = all_chars[c(1, 30)]
      ) +
      scale_y_discrete(
        limits = rev, expand = c(0, 0),
        breaks = all_chars[c(1, 30)]
      ) +
      guides(x.sec = "axis", y.sec = "axis") +
      labs(x = NULL, y = NULL, fill = NULL) +
      theme(
        plot.title = element_text(hjust = 0.5),
        axis.line = element_line(linewidth = axis_linewidth),
        axis.text.x.top = element_blank(),
        axis.ticks.x.top = element_blank(),
        axis.text.y.right = element_blank(),
        axis.ticks.y.right = element_blank(),
        legend.position = "none"
      )
    ggdraw(heatmap) +
      draw_label("...", x = 0.5, y = 0.02, size = 10, hjust = 0.5, vjust = 0) +
      draw_label("...", x = 0.02, y = 0.5, size = 10, angle = 90, hjust = 0.5, vjust = 0)
  }

  # Group-level neural RDM: mean of the within-subject ranks per character pair.
  pl$avg_rdm <- rdm_poi |>
    group_by(char1, char2) |>
    # summarise(avg_eeg_dissim = mean(eeg_dissim)) |>
    # mutate(avg_rank_eeg_dissim = rank(avg_eeg_dissim)) |>
    summarise(avg_rank_eeg_dissim = mean(rank_eeg_dissim), .groups = "drop") |>
    plot_rank_rdm("avg_rank_eeg_dissim")
  ggsave(file.path("fig", "avg_neural_rdm.png"), pl$avg_rdm, width = 2, height = 2, device = "png", type = "cairo")

  # plot model RDMs for planned analysis
  # stim_sim is rebuilt here, this time with char1/char2 as ordered factors (and
  # without a pair_id column), so the model matrices use the same axis order.
  stim_sim <- left_join(
    read_csv(file.path("stim_sim", "preregistered", "jacc.csv")),
    read_csv(file.path("stim_sim", "preregistered", "ot.csv")),
    by = c("char1", "char2")
  ) |>
    left_join(
      read_csv(file.path("stim_sim", "complexity", "complexity.csv")),
      by = c("char1", "char2")
    ) |>
    mutate(
      rank_jacc = rank(jacc),
      rank_ot = rank(ot),
      rank_comp_dist = rank(comp_dist)
    ) |>
    mutate(
      char1 = factor(char1, levels = all_chars, ordered = TRUE),
      char2 = factor(char2, levels = all_chars, ordered = TRUE)
    )

  pl$ot_rdm <- plot_rank_rdm(stim_sim, "rank_ot")
  ggsave(file.path("fig", "model_rdm_ot.png"), pl$ot_rdm, width = 2, height = 2, device = "png", type = "cairo")

  pl$jacc_rdm <- plot_rank_rdm(stim_sim, "rank_jacc")
  ggsave(file.path("fig", "model_rdm_jacc.png"), pl$jacc_rdm, width = 2, height = 2, device = "png", type = "cairo")
  # plots of the planned comparison
  # Posterior draws in long format. `summ_y_pos` is the (negative) y position at
  # which each correlation's interval summary is drawn beneath the density curves.
  prereg_cor_samps_long <- mutate(
    ungroup(readRDS(file.path("estimates", "prereg_cor_samps_long.rds"))),
    summ_y_pos = case_when(
      cor_lab == "pcor_rankeegdissim_rankjacc" ~ -10,
      cor_lab == "pcor_rankeegdissim_rankot" ~ -12,
      cor_lab == "cor_rankeegdissim_rankjacc" ~ -5,
      cor_lab == "cor_rankeegdissim_rankot" ~ -7
    )
  )

  # Ratios of draw counts satisfying vs. not satisfying each comparison:
  #   bf1  - OT correlation above zero
  #   bf2a - OT correlation above the Jaccard correlation
  #   bf2b - same comparison for the partial correlations
  prereg_bf <- prereg_cor_samps_long |>
    pivot_wider(
      id_cols = starts_with("."),
      names_from = cor_lab,
      values_from = Rho
    ) |>
    ungroup() |>
    select(-starts_with(".")) |>
    summarise(
      bf1 = sum(0 < cor_rankeegdissim_rankot) /
        sum(0 >= cor_rankeegdissim_rankot),
      bf2a = sum(cor_rankeegdissim_rankjacc < cor_rankeegdissim_rankot) /
        sum(cor_rankeegdissim_rankjacc >= cor_rankeegdissim_rankot),
      bf2b = sum(pcor_rankeegdissim_rankjacc < pcor_rankeegdissim_rankot) /
        sum(pcor_rankeegdissim_rankjacc >= pcor_rankeegdissim_rankot)
    )
  # Posterior densities of the correlations, coloured by model and with the line
  # type separating partial from plain correlations. The grey band spans the
  # noise ceiling and the interval summaries sit at `summ_y_pos` below zero.
  pl$posterior_plot_part <- prereg_cor_samps_long |>
    mutate(corr_lab = ifelse(is_partial, "Partial Correlation", "Correlation")) |>
    ggplot(aes(Rho, colour = model)) +
    annotate(
      geom = "rect",
      xmin = noise_ceiling_poi$lwr, xmax = noise_ceiling_poi$upr,
      ymin = -Inf, ymax = Inf,
      colour = NA, fill = "lightgrey"
    ) +
    # geom_density(aes(Rho, colour=model, linetype="Partial Correlation"), linewidth=0.5, trim=TRUE, filter(cor_samps_long, is_partial), key_glyph="path") +
    # geom_density(aes(Rho, colour=model, linetype="Correlation"), linewidth=0.5, trim=TRUE, filter(cor_samps_long, !is_partial), key_glyph="path") +
    geom_density(aes(linetype = corr_lab), linewidth = 0.5, trim = TRUE, key_glyph = "path") +
    # one interval layer for the partial correlations, then one for the plain ones
    lapply(c(TRUE, FALSE), \(partial_flag) {
      stat_pointinterval(
        aes(y = summ_y_pos),
        point_interval = median_hdi,
        .width = c(.5, .89),
        point_size = 0,
        data = filter(prereg_cor_samps_long, is_partial == partial_flag),
        show.legend = FALSE
      )
    }) +
    geom_vline(xintercept = 0, linetype = "dashed", linewidth = axis_linewidth) +
    # scale_fill_manual(values=measure_cols_light, guide="none") +
    scale_colour_manual(values = measure_cols) +
    scale_linetype_manual(values = c("solid", "dashed")) +
    scale_y_continuous(limits = c(-13, NA), breaks = seq(0, 30, 10)) +
    labs(
      x = "ρ", y = "Posterior Density",
      colour = NULL, fill = NULL,
      linewidth = NULL, linetype = NULL
    ) +
    guides(
      colour = guide_legend(order = 1),
      linetype = guide_legend(order = 2),
      linewidth = "none"
    ) +
    theme(
      legend.position = "right",
      legend.direction = "vertical",
      legend.spacing.y = unit(3, "pt"),
      legend.margin = margin(),
      legend.box.spacing = unit(0, "pt"),
      legend.box.margin = margin()
    )
  ggsave(
    file.path("fig", "posterior_poi_partial.png"),
    plot = pl$posterior_plot_part,
    width = 5, height = 1.55, device = "png", type = "cairo"
  )
  ggsave(
    file.path("fig", "posterior_poi_partial.pdf"),
    plot = pl$posterior_plot_part,
    width = 5, height = 1.55, device = cairo_pdf
  )
  # save again (seems to address bug in legend spacing?)
  ggsave(
    file.path("fig", "posterior_poi_partial.pdf"),
    plot = pl$posterior_plot_part,
    width = 5, height = 1.55, device = cairo_pdf
  )
  ggsave(
    file.path("fig", "posterior_poi_partial.svg"),
    plot = pl$posterior_plot_part,
    width = 4.75, height = 1.2
  )
  # timecourse of the planned comparison
  # Time-resolved estimates, ungrouped and with all dot-prefixed columns removed.
  tc <- select(
    ungroup(readRDS(file.path("estimates", "planned_timecourse.rds"))),
    -starts_with(".")
  )
  267. # calculate time-resolved BFs
  # Return `x` with every missing entry replaced by 0; all other entries are
  # left as they are.
  na2zero <- function(x) {
    replace(x, is.na(x), 0)
  }
  # Time-resolved Bayes factors. Per time point each BF is the ratio of draw
  # counts satisfying vs. not satisfying the comparison (bf1: OT correlation > 0;
  # bf2a: OT > Jaccard correlation; bf2b: same for the partial correlations).
  # A BF is Inf or 0 when all draws fall on one side; consecutive runs of such
  # time points are numbered per hypothesis in `inf_section_nr` (0 outside a run).
  tc_bf <- tc |>
    group_by(time) |>
    summarise(
      bf1 = sum(cor_rankeegdissim_rankot > 0) /
        sum(cor_rankeegdissim_rankot <= 0),
      bf2a = sum(cor_rankeegdissim_rankot > cor_rankeegdissim_rankjacc) /
        sum(cor_rankeegdissim_rankot <= cor_rankeegdissim_rankjacc),
      bf2b = sum(pcor_rankeegdissim_rankot > pcor_rankeegdissim_rankjacc) /
        sum(pcor_rankeegdissim_rankot <= pcor_rankeegdissim_rankjacc),
      n_samps = n()
    ) |>
    pivot_longer(cols = c(bf1, bf2a, bf2b), names_to = "hypothesis", values_to = "bf", names_prefix = "bf") |>
    group_by(hypothesis) |>
    mutate(
      time_ms = time * 1000,
      infinite_BF = (is.infinite(bf) | bf == 0),
      # `default = FALSE` treats the non-existent neighbours beyond either end of
      # the series as "not infinite". Without it, lag()/lead() return NA there;
      # a run that starts at the very first time point then gets an NA start flag,
      # which turns the whole cumsum() below into NA and shifts `is_section` by -1.
      infinite_section_start = infinite_BF & !lag(infinite_BF, default = FALSE),
      infinite_section_end = infinite_BF & !lead(infinite_BF, default = FALSE),
      # +1 at the first point of a run, -1 at the first point after it:
      # the running sum is 1 inside a run and 0 outside
      is_section = cumsum(na2zero(infinite_section_start - lag(infinite_section_end, default = FALSE))),
      inf_section_nr = ifelse(is_section == 0, 0, cumsum(infinite_section_start))
    ) |>
    ungroup() |>
    # bilinear scale: BF = 1 maps to 0, BF > 1 to BF - 1, BF < 1 to -(1/BF - 1)
    mutate(bf_bilinear = ifelse(bf < 1, -1 / bf + 1, bf - 1))
  # get credible intervals for Rho
  # Median and 89% HDI per correlation, partial flag and time point. The
  # "cor_" / "pcor_" prefix is moved into the logical `partial` column and
  # `measure` holds the display name of the model behind each correlation.
  tc_cr_i <- tc |>
    pivot_longer(
      cols = c(starts_with("cor"), starts_with("pcor")),
      names_to = "cor_lab",
      values_to = "Rho"
    ) |>
    mutate(
      partial = grepl("^pcor_", cor_lab),
      cor_lab = sub("(^cor_)|(^pcor_)", "", cor_lab)
    ) |>
    group_by(cor_lab, partial, time) |>
    median_hdi(.width = .89) |>
    mutate(
      measure = case_match(
        cor_lab,
        "rankeegdissim_rankot" ~ "Wasserstein Distance",
        "rankeegdissim_rankjacc" ~ "Jaccard Distance",
        .default = cor_lab
      )
    )

  # plot timecourse
  # y-limits: the extremes of the intervals and the noise ceiling, widened by
  # 0.005 and rounded to two decimals.
  ylims <- round(
    c(
      min(c(tc_cr_i$.lower, noise_ceiling_time$lwr)) - 0.005,
      max(c(tc_cr_i$.upper, noise_ceiling_time$upr)) + 0.005
    ),
    2
  )
  # Two time-course panels: [[1]] plain correlations, [[2]] partial correlations.
  # Each shows the noise ceiling (grey band), the 89% HDI ribbon and the median
  # line per measure. x-axis title, ticks and tick labels are blank on both
  # panels and the legend sits inside the panel at the top-right corner.
  tc_rho_pl <- map(c(FALSE, TRUE), \(show_partial) {
    ceiling_band <- mutate(noise_ceiling_time, time_ms = time * 1000, Rho = NA)
    tc_cr_i |>
      filter(partial == show_partial) |>
      arrange(desc(.width)) |>
      mutate(time_ms = time * 1000) |>
      ggplot() +
      geom_ribbon(
        aes(x = time_ms, ymin = lwr, ymax = upr),
        data = ceiling_band,
        fill = "lightgrey"
      ) +
      geom_hline(yintercept = 0, linewidth = axis_linewidth) +
      geom_vline(xintercept = 0, linewidth = axis_linewidth) +
      geom_ribbon(
        aes(time_ms, Rho, fill = measure, group = measure, ymin = .lower, ymax = .upper),
        alpha = 0.4, colour = NA
      ) +
      geom_line(
        aes(time_ms, Rho, colour = measure, group = measure),
        linetype = "solid"
      ) +
      scale_fill_manual(values = measure_cols) +
      scale_colour_manual(values = measure_cols) +
      scale_x_continuous(
        expand = c(0, 0),
        limits = c(-200, 1000),
        breaks = seq(-200, 1000, 100)
      ) +
      scale_y_continuous(expand = c(0, 0), limits = ylims) +
      labs(
        x = NULL,
        y = if (show_partial) "Partial ρ" else "ρ",
        colour = NULL, fill = NULL, tag = NULL
      ) +
      theme(
        legend.position = "inside",
        legend.position.inside = c(1, 1),
        legend.direction = "horizontal",
        legend.justification = c(1, 0),
        legend.background = element_blank(),
        plot.margin = margin(0, 10, 0, 0, unit = "pt"),
        legend.key.height = unit(10, units = "pt"),
        legend.margin = margin(0, 0, 0, 0),
        axis.ticks.x = element_blank(),
        axis.text.x = element_blank(),
        axis.line = element_blank(),
        plot.background = element_blank()
      )
  })
  # Stack the two time-course panels. The lower (partial) panel gets the x-axis
  # title, ticks and tick labels back and drops its legend; the panels are
  # tagged "a", "b".
  tc_pl_joined <- wrap_plots(
    tc_rho_pl[[1]],
    tc_rho_pl[[2]] +
      labs(x = "Time (ms)") +
      theme(
        legend.position = "none",
        axis.ticks.x = element_line(),
        axis.text.x = element_text(),
        axis.line.x = element_blank()
      ),
    ncol = 1,
    heights = c(2.5, 2.5)
  ) +
    plot_annotation(tag_levels = "a")
  ggsave(
    file.path("fig", "tc_joined_CrI.pdf"),
    plot = tc_pl_joined,
    width = 4.5, height = 3, device = cairo_pdf
  )
  ggsave(
    file.path("fig", "tc_joined_CrI.png"),
    plot = tc_pl_joined,
    width = 4.5, height = 3, device = "png", type = "cairo"
  )
  # tc_pl_joined_fewer_ticks <- list(
  # tc_rho_pl[[1]],
  # tc_rho_pl[[2]] +
  # labs(x = "Time (ms)") +
  # scale_x_continuous(expand=c(0,0), limits=c(-200, 1000), breaks=seq(-200, 1000, 100), labels=ifelse(seq(-200, 1000, 100)%%200==0, seq(-200, 1000, 100), "")) +
  # theme(
  # legend.position="none",
  # axis.ticks.x = element_line(),
  # axis.text.x = element_text(),
  # axis.line.x = element_blank()
  # )
  # ) |>
  # wrap_plots(ncol=1, heights=c(2.5, 2.5)) +
  # plot_annotation(tag_levels = "a")
  ggsave(
    file.path("fig", "tc_joined_CrI.svg"),
    plot = tc_pl_joined,
    width = 5, height = 3.5, device = "svg"
  )
  # Largest finite Bayes factor in either direction (BF or 1/BF) over all
  # hypotheses; used as the upper y limit of the BF panel.
  bf_lims <- max(c(tc_bf$bf[!is.infinite(tc_bf$bf)], 1/tc_bf$bf[!is.infinite(1/tc_bf$bf)]))
  bf_ticks <- c(10000, 100, 1, 1/100, 1/10000)
  bf_labs <- c("10000", "100", "1", "1/100", "1/10000")

  # Hypothesis 2a (OT vs. Jaccard correlation) on the bilinear BF scale, which is
  # the scale the y-axis labels below are written for: 0 <-> BF 1,
  # 9999 <-> BF 10000, -9999 <-> BF 1/10000.
  #   finite BF >= 1 : BF - 1
  #   finite BF <  1 : -(1/BF - 1)
  #   BF == Inf      : n_samps - 1     (all draws favour OT)
  #   BF == 0        : -(n_samps - 1)  (all draws favour Jaccard)
  # Previously the raw BF was plotted against these labels and BF == 0 was sent
  # to +(n_samps - 1), i.e. displayed as maximal evidence for the wrong model.
  tc_bf_h2a <- filter(tc_bf, hypothesis == "2a") |>
    mutate(
      bf_bilinear = case_when(
        is.infinite(bf) ~ n_samps - 1,
        bf == 0 ~ -(n_samps - 1),
        .default = ifelse(bf < 1, -1 / bf + 1, bf - 1)
      ),
      # red marker for runs without a finite BF: top edge of the panel for
      # BF == Inf, bottom edge (lower axis limit) for BF == 0
      inf_marker_y = ifelse(bf == 0, -9999, bf_lims)
    )

  bf_pl_bilinear <- tc_bf_h2a |>
    ggplot(aes(time_ms, bf_bilinear)) +
    # shading: above BF = 1 (0 on this scale) favours the Wasserstein distance,
    # below it the Jaccard distance
    annotate("rect", xmin = -Inf, xmax = Inf, ymin = 0, ymax = Inf, fill = measure_cols[["Wasserstein Distance"]], alpha = 0.4) +
    annotate("rect", xmin = -Inf, xmax = Inf, ymin = -Inf, ymax = 0, fill = measure_cols[["Jaccard Distance"]], alpha = 0.4) +
    geom_line() +
    geom_line(aes(y = inf_marker_y, group = inf_section_nr), colour = "red", data = filter(tc_bf_h2a, infinite_BF)) +
    # reference line at BF = 1
    geom_hline(yintercept = 0, linewidth = axis_linewidth) +
    geom_vline(xintercept = 0, linewidth = axis_linewidth) +
    scale_x_continuous(expand = c(0, 0), limits = c(-200, 1000), breaks = seq(-200, 1000, 100)) +
    scale_y_continuous(limits = c(-9999, bf_lims), breaks = c(-9999, 0, 9999, 19999, 29999, 39999), labels = c("1/10000", "1", "10000", "20000", "30000", "40000")) +
    labs(x = "Time (ms)", y = "BF") +
    theme(
      legend.position = "none",
      legend.margin = margin(),
      legend.box.margin = margin(),
      plot.margin = margin(0, 10, 2.5, 2.5, unit = "pt"),
      axis.line.x = element_blank(),
      axis.line.y = element_blank(),
      plot.background = element_blank()
    )

  # Correlation, partial-correlation and BF panels in one column; the spacers
  # with negative heights pull neighbouring panels closer together.
  tc_pl_joined_bf <- list(
    tc_rho_pl[[1]],
    plot_spacer(),
    tc_rho_pl[[2]] +
      theme(legend.position = "none"),
    plot_spacer(),
    bf_pl_bilinear
  ) |>
    wrap_plots(ncol = 1, heights = c(2.5, -0.3, 2.5, -0.3, 1.75)) +
    plot_annotation(tag_levels = "a")
  ggsave(file.path("fig", "tc_joined_CrI_bf.pdf"), tc_pl_joined_bf, width = 3.9, height = 3, device = cairo_pdf)
  # geom results ---------------------------------------------------------
  # Colours for the transformation variants: nine plasma colours (ending at
  # 0.85 of the palette) for the OT variants, the first eight for Jaccard.
  geom_palette <- viridisLite::plasma
  geom_palette_1_9 <- geom_palette(n = 9, end = 0.85)
  geom_palette_1_8 <- head(geom_palette_1_9, 8)

  # Line widths between which the interval summaries are scaled.
  interval_size_range_geom <- c(0.75, 2.5)

  # Transformation codes (T = translate, S = scale, R = rotate, "-" = absent,
  # "GW" for the gromov_wasserstein variant) and, in the same order, their
  # axis labels.
  # geom_fct_levels <- c("---", "--R", "-S-", "-SR", "T--", "T-R", "TS-", "TSR", "GW")
  # geom_fct_labels <- c("-", "R", "S", "SR", "T", "TR", "TS", "TSR", "G-W")
  geom_fct_levels <- c(
    "---", "--R", "-S-", "T--",
    "-SR", "T-R", "TS-", "TSR",
    "GW"
  )
  geom_fct_labels <- c(
    "-", "R", "S", "T",
    "RS", "RT", "ST", "RST",
    "G-W"
  )
  # Load one table of transformation-variant estimates from "estimates/".
  #
  # file:          file name inside "estimates".
  # has_gw:        TRUE for tables with a logical `gromov_wasserstein` column
  #                (the OT tables); such rows are labelled "GW" and all nine
  #                factor levels are used. FALSE uses the first eight levels.
  # time_resolved: TRUE additionally sorts by `time` and adds `time_ms`
  #                (time * 1000).
  #
  # Returns the ungrouped table without dot-prefixed columns, sorted by the
  # transformation flags, with a `geom_label` factor built from `translate`,
  # `scale` and `rotate` ("T"/"S"/"R", "-" where the flag is FALSE) and
  # relabelled according to geom_fct_levels / geom_fct_labels.
  read_geom_estimates <- function(file, has_gw, time_resolved = FALSE) {
    sort_cols <- c(
      if (time_resolved) "time",
      "translate", "scale", "rotate",
      if (has_gw) "gromov_wasserstein"
    )
    used_levels <- seq_len(if (has_gw) 9 else 8)

    est <- readRDS(file.path("estimates", file)) |>
      ungroup() |>
      select(-starts_with(".")) |>
      arrange(pick(all_of(sort_cols)))
    if (time_resolved) {
      est <- mutate(est, time_ms = as.numeric(as.character(time)) * 1000)
    }

    code <- paste0(
      ifelse(est$translate, "T", "-"),
      ifelse(est$scale, "S", "-"),
      ifelse(est$rotate, "R", "-")
    )
    if (has_gw) {
      code <- ifelse(est$gromov_wasserstein, "GW", code)
    }
    mutate(
      est,
      geom_label = factor(
        code,
        levels = geom_fct_levels[used_levels],
        labels = geom_fct_labels[used_levels]
      )
    )
  }

  ot_geom_time_resolved <- read_geom_estimates("ot_geom_time_resolved.rds", has_gw = TRUE, time_resolved = TRUE)
  jacc_geom_time_resolved <- read_geom_estimates("jacc_geom_time_resolved.rds", has_gw = FALSE, time_resolved = TRUE)
  ot_geom_poi_res <- read_geom_estimates("ot_geom_poi_draws.rds", has_gw = TRUE)
  jacc_geom_poi_res <- read_geom_estimates("jacc_geom_poi_draws.rds", has_gw = FALSE)
  ot_geom_p1_res <- read_geom_estimates("ot_geom_p1_draws.rds", has_gw = TRUE)
  jacc_geom_p1_res <- read_geom_estimates("jacc_geom_p1_draws.rds", has_gw = FALSE)
  # Time course of rho, one line per transformation variant, on top of the
  # time-resolved noise ceiling (grey band).
  #
  # est:     table with `time_ms`, `rho` and `geom_label`.
  # palette: colours for the levels of `geom_label`.
  plot_geom_timecourse <- function(est, palette) {
    ggplot(est) +
      geom_ribbon(
        aes(x = time_ms, ymin = lwr, ymax = upr),
        data = mutate(noise_ceiling_time, time_ms = time * 1000, Rho = NA),
        fill = "lightgrey"
      ) +
      geom_line(aes(time_ms, rho, colour = geom_label)) +
      geom_vline(xintercept = 0, linewidth = axis_linewidth) +
      geom_hline(yintercept = 0, linewidth = axis_linewidth) +
      scale_x_continuous(expand = c(0, 0), limits = c(-200, 1000), breaks = seq(-200, 1000, 200)) +
      scale_y_continuous(limits = rho_limits, expand = c(0, 0)) +
      # a single colour scale: the OT plot used to add scale_colour_viridis_d()
      # first, which the manual scale replaced straight away (with a message)
      scale_colour_manual(values = palette) +
      labs(
        x = "Time (ms)",
        y = "ρ"
      ) +
      theme(
        axis.line = element_blank(),
        legend.position = "none"
      )
  }

  # Median and 50% / 89% HDI of rho per transformation variant within one time
  # window, on top of that window's noise ceiling (grey band).
  #
  # est:           table with `rho` and `geom_label`.
  # noise_ceiling: table whose `lwr` / `upr` give the band.
  # palette:       colours for the levels of `geom_label`.
  # gw_divider:    add a dashed vertical line at x = 8.5, i.e. between the
  #                eighth and ninth variant.
  # hide_y_axis:   blank the y-axis line, title, text and ticks (for panels
  #                placed to the right of another panel).
  plot_geom_intervals <- function(est, noise_ceiling, palette,
                                  gw_divider = FALSE, hide_y_axis = FALSE) {
    p <- ggplot(est, aes(geom_label, rho, colour = geom_label)) +
      annotate(
        geom = "rect",
        ymin = noise_ceiling$lwr, ymax = noise_ceiling$upr,
        xmin = -Inf, xmax = Inf,
        colour = NA, fill = "lightgrey"
      ) +
      stat_pointinterval(
        point_interval = "median_hdi",
        .width = c(.5, .89),
        shape = NA,
        interval_size_range = interval_size_range_geom
      ) +
      geom_hline(yintercept = 0, linewidth = axis_linewidth)
    if (gw_divider) {
      p <- p + geom_vline(xintercept = 8.5, linewidth = axis_linewidth, linetype = "dashed")
    }
    p <- p +
      scale_colour_manual(values = palette) +
      scale_y_continuous(limits = rho_limits, expand = c(0, 0)) +
      scale_x_discrete() +
      labs(
        x = "Transformations",
        y = "ρ"
      ) +
      theme(
        legend.position = "none",
        strip.text = element_blank(),
        axis.title.x = element_blank(),
        axis.text.x = element_text(hjust = 1, vjust = 0.5, angle = 90),
        axis.line.x = element_blank()
      )
    if (hide_y_axis) {
      p <- p +
        theme(
          axis.line.y = element_blank(),
          axis.title.y = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.y = element_blank()
        )
    }
    p
  }

  pl$ot_geom_time_res <- plot_geom_timecourse(ot_geom_time_resolved, geom_palette_1_9)
  pl$ot_geom_poi_res <- plot_geom_intervals(
    ot_geom_poi_res, noise_ceiling_poi, geom_palette_1_9,
    gw_divider = TRUE, hide_y_axis = TRUE
  )
  pl$ot_geom_p1_res <- plot_geom_intervals(
    ot_geom_p1_res, noise_ceiling_p1, geom_palette_1_9,
    gw_divider = TRUE
  )
  pl$jacc_geom_time_res <- plot_geom_timecourse(jacc_geom_time_resolved, geom_palette_1_8)
  pl$jacc_geom_poi_res <- plot_geom_intervals(
    jacc_geom_poi_res, noise_ceiling_poi, geom_palette_1_8,
    hide_y_axis = TRUE
  )
  pl$jacc_geom_p1_res <- plot_geom_intervals(
    jacc_geom_p1_res, noise_ceiling_p1, geom_palette_1_8
  )
  # Joined geom results -----------------------------------------------------
  # POI and time course
  # Text-only panels that act as row labels in the composite figure.
  pl$wass_lab <- ggplot() +
    theme_void() +
    annotate("text", x = 0, y = 0, size = 8/.pt, label = "Wasserstein\nDistance", family = "Helvetica")
  pl$jacc_lab <- ggplot() +
    theme_void() +
    annotate("text", x = 0, y = 0, size = 8/.pt, label = "Jaccard\nDistance", family = "Helvetica")

  # Six panels with three column widths (label | window summary | time course):
  # Wasserstein panels first, Jaccard panels after.
  pl$jacc_ot_geom_poi_joined <- local({
    no_margin <- theme(plot.margin = margin(0, 0, 0, 0, "pt"))
    wrap_plots(
      pl$wass_lab + no_margin,
      pl$ot_geom_poi_res + labs(tag = "a1") + no_margin,
      pl$ot_geom_time_res + labs(tag = "a2") + theme(plot.margin = margin(0, 10, 0, 0, "pt")),
      pl$jacc_lab + no_margin,
      pl$jacc_geom_poi_res + labs(tag = "b1") + no_margin,
      pl$jacc_geom_time_res + labs(tag = "b2") + no_margin,
      widths = c(0.425, 0.75, 2)
    ) +
      theme(plot.background = element_blank())
  })
  ggsave(
    file.path("fig", "jacc_ot_geom_poi_joined.png"),
    plot = pl$jacc_ot_geom_poi_joined,
    width = 6, height = 3.5, device = "png", type = "cairo"
  )
  ggsave(
    file.path("fig", "jacc_ot_geom_poi_joined.pdf"),
    plot = pl$jacc_ot_geom_poi_joined,
    width = 6, height = 3.5, device = cairo_pdf
  )
  # POI, P1, Timecourse
  # noise_ceiling_lines_df <- tibble(
  # x=rep(c(p1_window, poi_window), each=2),
  # y=rep(c(0.325, max(rho_limits)), 4)
  # )
  # Top row: window summaries (P1 then POI) for the Wasserstein and the Jaccard
  # distance. Bottom row: the two time courses. Both rows get the same height.
  pl$jacc_ot_geom_poi_p1_joined <- local({
    window_title <- theme(plot.title = element_text(hjust = 0, size = 8, vjust = 0))
    timecourse_style <- theme(
      plot.margin = margin(10, 10, 0, 0, "pt"),
      plot.title = element_text(hjust = 0, size = 8)
    )
    p1_span <- paste(p1_window, collapse = "-")
    poi_span <- paste(poi_window, collapse = "-")

    ot_p1 <- pl$ot_geom_p1_res +
      labs(title = sprintf("Wasserstein Distance\n\n%s ms", p1_span), tag = "a") +
      window_title
    ot_poi <- pl$ot_geom_poi_res +
      labs(title = sprintf("%s ms", poi_span)) +
      window_title
    jacc_p1 <- pl$jacc_geom_p1_res +
      labs(title = sprintf("Jaccard Distance\n\n%s ms", p1_span), tag = "b") +
      window_title
    jacc_poi <- pl$jacc_geom_poi_res +
      labs(title = sprintf("%s ms", poi_span)) +
      window_title

    ot_time <- pl$ot_geom_time_res + labs(title = "Time-Resolved") + timecourse_style
    jacc_time <- pl$jacc_geom_time_res + labs(title = "Time-Resolved") + timecourse_style

    (ot_p1 | ot_poi | jacc_p1 | jacc_poi) /
      (ot_time | jacc_time) +
      plot_layout(heights = c(3, 3)) +
      theme(plot.background = element_blank())
  })
  ggsave(
    file.path("fig", "jacc_ot_geom_poi_p1_joined.png"),
    plot = pl$jacc_ot_geom_poi_p1_joined,
    width = 6, height = 3, device = "png", type = "cairo"
  )
  ggsave(
    file.path("fig", "jacc_ot_geom_poi_p1_joined.pdf"),
    plot = pl$jacc_ot_geom_poi_p1_joined,
    width = 6, height = 3, device = cairo_pdf
  )
  # ANN results -------------------------------------------------------------
  ann_palette <- viridisLite::plasma

  # Read one ANN estimates file from "estimates/", ungrouped and without
  # dot-prefixed columns.
  read_ann_estimates <- function(file) {
    readRDS(file.path("estimates", file)) |>
      ungroup() |>
      select(-starts_with("."))
  }

  # Add display columns derived from `model` and `layer`:
  #   training_label - training regime read from the suffix of `model`
  #                    (the combined "_imagenet_letters" suffix is tested first
  #                    because it also ends in "_letters")
  #   model_label    - architecture name
  #   layer_label    - `layer` without a leading "layer"
  #   layer_level    - numeric layer order: V1/V2/V4/IT map to 1-4, any other
  #                    label is converted with as.numeric() (labels that are not
  #                    numbers become NA with a coercion warning)
  add_ann_labels <- function(est) {
    mutate(
      est,
      training_label = factor(case_when(
        grepl("_imagenet_letters$", model) ~ "Imagenet\n+ Letters",
        grepl("_imagenet$", model) ~ "Imagenet",
        grepl("_letters$", model) ~ "Letters"
      ), levels = c("Letters", "Imagenet", "Imagenet\n+ Letters")),
      model_label = factor(case_when(
        grepl("resnet50", model) ~ "ResNet-50",
        grepl("cornet-z", model) ~ "CORnet-Z"
      ), levels = c("ResNet-50", "CORnet-Z")),
      layer_label = sub("^layer", "", layer),
      layer_level = case_match(
        layer_label,
        "V1" ~ 1,
        "V2" ~ 2,
        "V4" ~ 3,
        "IT" ~ 4,
        .default = as.numeric(layer_label)
      )
    )
  }

  # Time-resolved estimates additionally get `time_ms` (time * 1000).
  ann_time_cor_res <- read_ann_estimates("ANNs_time_resolved.rds") |>
    mutate(time_ms = as.numeric(as.character(time)) * 1000) |>
    add_ann_labels()
  ann_poi_res <- read_ann_estimates("ANNs_poi_draws.rds") |>
    add_ann_labels()
  ann_p1_res <- read_ann_estimates("ANNs_p1_draws.rds") |>
    add_ann_labels()
  # Build one point-interval plot per model architecture for a single analysis
  # window.
  #
  # ann_draws:     table with columns model_label, layer, layer_label,
  #                layer_level, training_label and rho (e.g. ann_p1_res).
  # noise_ceiling: table providing `lwr` and `upr`, drawn as a grey band.
  # show_y_axis:   if FALSE, the y axis line, title, text and ticks are
  #                removed.
  #
  # Returns a list of ggplot objects, one per level of
  # ann_time_cor_res$model_label, in level order.
  #
  # The y scale uses oob = scales::oob_keep. With the default out-of-bounds
  # handling, values outside `rho_limits` are turned into NA before the stat
  # runs, so medians and HDIs would be computed from a truncated set of draws.
  # Keeping them leaves the axis range unchanged and lets the summaries use
  # all draws.
  make_ann_window_plots <- function(ann_draws, noise_ceiling,
                                    show_y_axis = TRUE) {
    y_axis_theme <- if (show_y_axis) {
      theme()
    } else {
      theme(
        axis.line.y = element_blank(),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank()
      )
    }

    lapply(levels(ann_time_cor_res$model_label), function(m_i) {
      d_i <- filter(ann_draws, model_label == m_i)
      n_layers <- length(unique(d_i$layer))
      ann_palette_i <- ann_palette(n = n_layers, end = 0.85)

      if (n_layers > 4) {
        # With many layers, only labels ending in ".0" are printed and the
        # intervals are drawn thinner.
        xscale <- scale_x_discrete(
          labels = ~ifelse(grepl("\\.0$", .x), .x, "")
        )
        interval_size_range_ann <- c(0.75, 1.5)
      } else {
        xscale <- scale_x_discrete()
        interval_size_range_ann <- c(0.75, 2.5)
      }

      d_i |>
        mutate(layer_level = as.factor(layer_level)) |>
        arrange(layer_level) |>
        # Order the x axis by layer_level rather than alphabetically.
        mutate(layer_label = factor(layer_label, levels = unique(layer_label))) |>
        ggplot(aes(x = layer_label, y = rho, colour = layer_label)) +
        annotate(
          geom = "rect",
          ymin = noise_ceiling$lwr, ymax = noise_ceiling$upr,
          xmin = -Inf, xmax = Inf,
          colour = NA, fill = "lightgrey"
        ) +
        stat_pointinterval(
          point_interval = "median_hdi",
          .width = c(.5, .89),
          shape = NA,
          interval_size_range = interval_size_range_ann
        ) +
        geom_hline(yintercept = 0, linewidth = axis_linewidth) +
        scale_colour_manual(values = ann_palette_i) +
        scale_y_continuous(
          limits = rho_limits,
          expand = c(0, 0),
          oob = scales::oob_keep
        ) +
        xscale +
        labs(
          x = "Layer",
          y = "ρ"
        ) +
        theme(
          legend.position = "none",
          strip.text = element_blank(),
          axis.line.x = element_blank()
        ) +
        y_axis_theme +
        facet_grid(rows = vars(training_label))
    })
  }

  # p1 window plots keep their y axis; the poi window plots drop it.
  ann_p1_pl_list <- make_ann_window_plots(ann_p1_res, noise_ceiling_p1)

  ann_poi_pl_list <- make_ann_window_plots(
    ann_poi_res,
    noise_ceiling_poi,
    show_y_axis = FALSE
  )
  # One time-course plot per model architecture: a line per layer, faceted by
  # training regime, drawn over the time-resolved noise-ceiling band.
  ann_time_res_pl_list <- lapply(
    levels(ann_time_cor_res$model_label),
    function(m_i) {
      model_rows <- filter(ann_time_cor_res, model_label == m_i)
      layer_count <- length(unique(model_rows$layer))
      layer_colours <- ann_palette(n = layer_count, end = 0.85)

      # Thinner lines when more than four layers share a panel.
      line_width <- if (layer_count > 4) 0.25 else 0.5

      ceiling_band <- mutate(
        noise_ceiling_time,
        time_ms = time * 1000,
        Rho = NA
      )

      plot_rows <- model_rows |>
        mutate(layer_level = as.factor(layer_level)) |>
        arrange(layer_level) |>
        mutate(layer_label = factor(layer_label, levels = unique(layer_label)))

      ggplot(plot_rows) +
        geom_ribbon(
          aes(x = time_ms, ymin = lwr, ymax = upr),
          data = ceiling_band,
          fill = "lightgrey"
        ) +
        geom_line(
          aes(time_ms, rho, colour = layer_level),
          linewidth = line_width
        ) +
        geom_vline(xintercept = 0, linewidth = axis_linewidth) +
        geom_hline(yintercept = 0, linewidth = axis_linewidth) +
        scale_x_continuous(
          expand = c(0, 0),
          limits = c(-200, 1000),
          breaks = seq(-200, 1000, 200)
        ) +
        scale_y_continuous(limits = rho_limits, expand = c(0, 0)) +
        scale_colour_manual(values = layer_colours) +
        facet_grid(rows = vars(training_label)) +
        labs(
          x = "Time (ms)",
          y = "ρ"
        ) +
        theme(
          axis.line = element_blank(),
          legend.position = "none",
          plot.margin = margin(0, 0, 0, 10, "pt"),
          axis.line.y = element_blank(),
          axis.title.y = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.y = element_blank()
        )
    }
  )
  # Assemble the ANN figure as a 2 x 3 grid: one row per architecture, with
  # columns for the p1 window, the poi window and the time-resolved plot
  # (the last column is twice as wide). Then save it as PNG and PDF.
  ann_title_theme <- theme(plot.title = element_text(hjust = 0, size = 8))
  ann_p1_window_label <- paste(p1_window, collapse = "-")
  ann_poi_window_label <- paste(poi_window, collapse = "-")

  pl$ann_joined <- wrap_plots(
    list(
      # Row 1: first model level
      ann_p1_pl_list[[1]] +
        labs(
          tag = "a",
          title = sprintf("ResNet-50\n\n%s ms", ann_p1_window_label)
        ) +
        ann_title_theme,
      ann_poi_pl_list[[1]] +
        labs(title = sprintf("\n\n%s ms", ann_poi_window_label)) +
        ann_title_theme,
      ann_time_res_pl_list[[1]] +
        labs(title = "\n\nTime-Resolved") +
        ann_title_theme,
      # Row 2: second model level
      ann_p1_pl_list[[2]] +
        labs(
          tag = "b",
          title = sprintf("CORnet-Z\n\n%s ms", ann_p1_window_label)
        ) +
        ann_title_theme,
      ann_poi_pl_list[[2]] +
        labs(title = sprintf("\n\n%s ms", ann_poi_window_label)) +
        ann_title_theme,
      ann_time_res_pl_list[[2]] +
        labs(title = "\n\nTime-Resolved") +
        ann_title_theme
    ),
    nrow = 2,
    ncol = 3,
    widths = c(1, 1, 2)
  )

  ggsave(
    filename = file.path("fig", "ANN_joined.png"),
    plot = pl$ann_joined,
    width = 5.5,
    height = 6,
    device = "png",
    type = "cairo"
  )
  ggsave(
    filename = file.path("fig", "ANN_joined.pdf"),
    plot = pl$ann_joined,
    width = 5.5,
    height = 6,
    device = cairo_pdf
  )
  # controls analysis -------------------------------------------------------

  # Display labels for the six control models, in plotting order. Shared by
  # all four controls tables so their `model` factors have identical levels.
  controls_model_levels <- c(
    "Jaccard Distance",
    "Wasserstein Distance",
    "Visual Size Distance",
    "Letter Frequency\nDistance",
    "Dominant Phoneme\nPhonological Distance",
    "Letter Name\nPhonological Distance"
  )

  # Read a long table of correlation samples from "estimates/" and prepare it
  # for plotting.
  #
  # file_name: name of an .rds file inside "estimates/"; the table must
  #            contain the columns `model` and `is_partial`.
  # Returns the table ungrouped, without columns starting with ".", with
  # `model` recoded to the display labels (values that match no known name
  # become NA) and a two-level `partialness` factor added.
  read_controls_poi <- function(file_name) {
    readRDS(file.path("estimates", file_name)) |>
      ungroup() |>
      select(-starts_with(".")) |>
      mutate(
        model = factor(
          case_match(
            model,
            "Jaccard Distance" ~ "Jaccard Distance",
            "Wasserstein Distance" ~ "Wasserstein Distance",
            "Complexity Distance" ~ "Visual Size Distance",
            "Frequency Distance" ~ "Letter Frequency\nDistance",
            "Phonological Distance" ~ "Dominant Phoneme\nPhonological Distance",
            "Letter Name Phonological Distance" ~ "Letter Name\nPhonological Distance"
          ),
          levels = controls_model_levels
        ),
        partialness = factor(
          ifelse(is_partial, "Partial\nCorrelations", "Correlations"),
          levels = c("Correlations", "Partial\nCorrelations")
        )
      )
  }

  # Read a wide time-course table from "estimates/" and summarise it.
  #
  # file_name: name of an .rds file inside "estimates/"; the table must
  #            contain `time` plus columns whose names start with "cor_" or
  #            "pcor".
  # The cor_/pcor columns are stacked into `Rho`. The column name decides
  # `partialness` (prefix "pcor_") and `model` (suffix). The result is the
  # median and 89% HDI of Rho per time, model and partialness.
  read_controls_timecourse <- function(file_name) {
    readRDS(file.path("estimates", file_name)) |>
      ungroup() |>
      select(-starts_with(".")) |>
      pivot_longer(
        cols = c(starts_with("cor_"), starts_with("pcor")),
        names_to = "cor_par",
        values_to = "Rho"
      ) |>
      mutate(
        partialness = factor(
          ifelse(grepl("^pcor_", cor_par), "Partial Correlations", "Correlations"),
          levels = c("Correlations", "Partial Correlations")
        ),
        model = factor(
          case_when(
            grepl("rankjacc$", cor_par) ~ "Jaccard Distance",
            grepl("rankot$", cor_par) ~ "Wasserstein Distance",
            grepl("rankcompdist$", cor_par) ~ "Visual Size Distance",
            grepl("rankfreqdist$", cor_par) ~ "Letter Frequency\nDistance",
            grepl("rankphondist$", cor_par) ~ "Dominant Phoneme\nPhonological Distance",
            grepl("ranknamephondist$", cor_par) ~ "Letter Name\nPhonological Distance"
          ),
          levels = controls_model_levels
        )
      ) |>
      select(-cor_par) |>
      group_by(time, model, partialness) |>
      median_hdi(Rho, .width = .89)
  }

  # Each table type is read twice: once from the default file and once from
  # its "all_chs" counterpart.
  controls_poi_res <- read_controls_poi("controls_cor_samps_long.rds")
  controls_time_res <- read_controls_timecourse("controls_timecourse.rds")
  controls_all_chs_poi_res <- read_controls_poi(
    "controls_all_chs_cor_samps_long.rds"
  )
  controls_all_chs_time_res <- read_controls_timecourse(
    "controls_all_chs_timecourse.rds"
  )
  # Colour-blind friendly palette: the colours already defined in
  # measure_cols plus four Okabe-Ito colours for the additional control models.
  controls_colours <- c(
    measure_cols,
    "Visual Size Distance" = "#009E73",
    "Letter Frequency\nDistance" = "#F0E442",
    "Dominant Phoneme\nPhonological Distance" = "#CC79A7",
    "Letter Name\nPhonological Distance" = "#56B4E9"
  )

  # Shared rho axis range for the controls figures: fixed lower bound, upper
  # bound taken from the highest noise-ceiling value across all four tables.
  rho_limits_controls <- c(
    -0.15,
    max(c(
      noise_ceiling_time$upr,
      noise_ceiling_poi$upr,
      noise_ceiling_time_all_chs$upr,
      noise_ceiling_poi_all_chs$upr
    ))
  )

  # One row per distinct combination of the labelling columns, with Rho fixed
  # at 0.
  dummy_controls_df <- controls_poi_res |>
    select(cor_lab, model, is_partial, partialness) |>
    distinct() |>
    mutate(Rho = 0)

  # Noise ceilings for both channel groups, labelled with the same facet
  # titles that the controls point-interval figure uses.
  nc_ch_grps <- mutate(
    bind_rows(
      mutate(noise_ceiling_poi, chs_grp = "post"),
      mutate(noise_ceiling_poi_all_chs, chs_grp = "all")
    ),
    chs_grp = factor(
      chs_grp,
      levels = c("post", "all"),
      labels = c("150-225 ms\n\nPosterior Channels", "\n\nAll Channels")
    )
  )

  # Time-resolved noise ceilings for both channel groups, duplicated so that
  # every partialness facet row receives the same band.
  nc_time_ch_grps <- local({
    nc_time_both <- mutate(
      bind_rows(
        mutate(noise_ceiling_time, chs_grp = "post"),
        mutate(noise_ceiling_time_all_chs, chs_grp = "all")
      ),
      chs_grp = factor(
        chs_grp,
        levels = c("post", "all"),
        labels = c("Time-Resolved\n\nPosterior Channels", "\n\nAll Channels")
      ),
      partialness = "Correlations"
    )

    nc_time_stacked <- bind_rows(
      nc_time_both,
      mutate(nc_time_both, partialness = "Partial Correlations")
    )

    mutate(
      nc_time_stacked,
      partialness = factor(
        partialness,
        levels = c("Correlations", "Partial Correlations")
      )
    )
  })
  # Posterior densities and point intervals of rho for every control model,
  # with one facet column per channel group and line type distinguishing
  # correlations from partial correlations.
  pl$controls_poi <- bind_rows(
    mutate(controls_poi_res, chs_grp = "post"),
    mutate(controls_all_chs_poi_res, chs_grp = "all")
  ) |>
    mutate(
      chs_grp = factor(
        chs_grp,
        levels = c("post", "all"),
        labels = c("150-225 ms\n\nPosterior Channels", "\n\nAll Channels")
      ),
      # Numeric y position for the intervals: the model's position on the
      # reversed discrete axis, shifted down by 0.15.
      interval_yloc = as.numeric(forcats::fct_rev(model)) - 0.15
    ) |>
    ggplot() +
    # Noise-ceiling band for each channel group
    geom_rect(
      aes(xmin = lwr, xmax = upr, ymin = -Inf, ymax = Inf),
      colour = NA, fill = "lightgrey", data = nc_ch_grps
    ) +
    geom_vline(xintercept = 0, linewidth = axis_linewidth) +
    stat_slab(
      aes(Rho, model, colour = model, group = partialness, linetype = partialness),
      fill = NA, height = 0.7, show.legend = FALSE,
      linewidth = axis_linewidth * 1.25
    ) +
    # Fully transparent layer whose only role is to supply a linetype legend
    # drawn with path glyphs (alpha is restored in the legend below).
    geom_vline(
      aes(xintercept = 0, linetype = partialness),
      key_glyph = "path", data = dummy_controls_df, alpha = 0
    ) +
    # y is mapped once, to interval_yloc. The original call also passed
    # `model` positionally; because `y` was supplied by name, that argument
    # became an unnamed aesthetic that no geom or stat uses, so it is dropped.
    stat_pointinterval(
      aes(x = Rho, y = interval_yloc, colour = model, group = partialness),
      point_interval = "median_hdi", .width = c(.5, .89), shape = NA,
      position = position_dodgev(height = 0.25)
    ) +
    scale_colour_manual(values = controls_colours, guide = "none") +
    scale_linetype(guide = guide_legend(override.aes = list(alpha = 1))) +
    labs(
      x = "ρ",
      y = "Model",
      linetype = NULL
    ) +
    scale_x_continuous(expand = c(0, 0)) +
    scale_y_discrete(limits = rev, expand = c(0, 0)) +
    facet_grid(cols = vars(chs_grp)) +
    theme(
      legend.position = "inside",
      legend.position.inside = c(1, 1),
      legend.justification = c(1, 1),
      # legend.background = element_blank(),
      legend.margin = margin(2, 2, 2, 2, "pt"),
      legend.key = element_blank(),
      strip.text = element_text(hjust = 0)
    )
  # Time courses of the summarised rho for every control model, faceted by
  # channel group (columns) and partialness (rows), over the noise-ceiling
  # band.
  controls_timecourse_df <- bind_rows(
    mutate(controls_time_res, chs_grp = "post"),
    mutate(controls_all_chs_time_res, chs_grp = "all")
  ) |>
    mutate(
      chs_grp = factor(
        chs_grp,
        levels = c("post", "all"),
        labels = c("Time-Resolved\n\nPosterior Channels", "\n\nAll Channels")
      ),
      time_ms = time * 1000
    )

  # The band layer inherits the plot-level aesthetics, so the data it uses
  # must also carry Rho and model columns.
  controls_nc_band_df <- mutate(
    nc_time_ch_grps,
    time_ms = time * 1000,
    Rho = 1,
    model = NA
  )

  pl$controls_timecourse <- ggplot(
    controls_timecourse_df,
    aes(time_ms, Rho, colour = model)
  ) +
    geom_ribbon(
      aes(x = time_ms, ymin = lwr, ymax = upr),
      colour = NA,
      data = controls_nc_band_df,
      fill = "lightgrey"
    ) +
    geom_hline(yintercept = 0, linewidth = axis_linewidth) +
    geom_vline(xintercept = 0, linewidth = axis_linewidth) +
    # geom_ribbon(aes(time_ms, Rho, fill=model, group=model, ymin=.lower, ymax=.upper), alpha=0.4, colour=NA) +
    geom_line() +
    scale_colour_manual(values = controls_colours, guide = "none") +
    # scale_fill_manual(values=controls_colours, guide="none") +
    scale_x_continuous(
      expand = c(0, 0),
      limits = c(-200, 1000),
      breaks = seq(-200, 1000, 200)
    ) +
    scale_y_continuous(limits = rho_limits_controls, expand = c(0, 0)) +
    labs(
      x = "Time (ms)",
      y = "ρ"
    ) +
    facet_grid(cols = vars(chs_grp), rows = vars(partialness)) +
    theme(
      axis.line.x = element_blank(),
      axis.line.y = element_blank(),
      strip.text.x.top = element_text(hjust = 0),
      panel.spacing.x = unit(20, "pt")
    )
  # Stack the two controls panels (tags "a" and "b") at equal heights and
  # save the combined figure as PNG and PDF.
  pl$controls <- plot_grid(
    pl$controls_poi + labs(tag = "a"),
    pl$controls_timecourse + labs(tag = "b"),
    nrow = 2,
    rel_heights = c(1, 1)
  )

  ggsave(
    filename = file.path("fig", "controls.png"),
    plot = pl$controls,
    width = 6,
    height = 6.5,
    device = "png",
    type = "cairo"
  )
  ggsave(
    filename = file.path("fig", "controls.pdf"),
    plot = pl$controls,
    width = 6,
    height = 6.5,
    device = cairo_pdf
  )
  # sensitivity analysis ----------------------------------------------------

  # Estimates per LKJ eta value; `partialness` labels the two facet columns.
  sens_res <- mutate(
    readRDS(file.path("estimates", "sensitivity_lkj_prior.rds")),
    partialness = factor(
      ifelse(is_partial, "Partial Correlations", "Correlations"),
      levels = c("Correlations", "Partial Correlations")
    )
  )

  # Decade breaks from 1e-3 to 1e3 for the log-scaled eta axis.
  sens_xbreaks <- 10^seq(-3, 3, by = 1)

  # rho (line) with its interval (ribbon) as a function of eta, one colour per
  # model; the dashed vertical line marks eta = 1.5.
  pl$sens_post <- ggplot(
    sens_res,
    aes(eta, Rho, colour = model, fill = model)
  ) +
    geom_ribbon(aes(ymin = .lower, ymax = .upper), colour = NA, alpha = 0.4) +
    geom_line() +
    geom_vline(xintercept = 1.5, linetype = "dashed") +
    facet_grid(cols = vars(partialness)) +
    scale_colour_manual(values = measure_cols) +
    scale_fill_manual(values = measure_cols) +
    scale_x_log10(
      breaks = sens_xbreaks,
      limits = range(sens_xbreaks),
      expand = c(0, 0)
    ) +
    scale_y_continuous(
      breaks = seq(-0.06, 0.06, 0.02),
      limits = c(-0.06, NA)
    ) +
    theme(
      legend.position = "bottom",
      # NOTE(review): legend.position.inside is presumably unused while
      # legend.position is "bottom" - confirm before removing.
      legend.position.inside = c(0.3, 0.1),
      legend.key.height = unit(10, units = "pt"),
      legend.margin = margin(0, 0, 0, 0),
      panel.spacing.x = unit(25, "pt"),
      plot.margin = margin(0, 12.5, 0, 0, unit = "pt")
    ) +
    labs(
      x = "LKJ η Prior",
      y = "ρ",
      colour = NULL,
      fill = NULL,
      tag = "b"
    )
  # Marginal LKJ prior densities (K = 3) for eta = 1e-2 ... 1e2, one facet per
  # eta value, ordered by increasing eta.
  sens_prior_etas <- 10^seq(-2, 2, by = 1)

  sens_prior_df <- tibble(
    eta = sens_prior_etas,
    prior_string = sprintf("lkjcorr(%g)", eta),
    prior_label = sprintf("η=%s", scales::scientific(eta, digits = 2))
  ) |>
    arrange(eta) |>
    # Keep facets in row order instead of alphabetical label order.
    mutate(prior_label = fct_inorder(prior_label)) |>
    parse_dist(prior_string) |>
    marginalize_lkjcorr(K = 3)

  pl$sens_priors <- ggplot(sens_prior_df, aes(xdist = .dist_obj)) +
    stat_slabinterval(
      point_interval = "median_hdi",
      .width = c(0.5, 0.89),
      justification = -0.05,
      shape = "|",
      slab_colour = "black",
      slab_linewidth = 0.5
    ) +
    scale_x_continuous(limits = c(-1, 1)) +
    facet_grid(cols = vars(prior_label)) +
    labs(
      x = "ρ",
      y = "Density",
      tag = "a"
    ) +
    theme(
      axis.ticks.y = element_blank(),
      axis.text.y = element_blank(),
      panel.spacing.x = unit(12, "pt")
    )

  # Priors on top, posterior sensitivity below; saved as PDF and PNG.
  pl$sens <- pl$sens_priors / pl$sens_post

  ggsave(
    filename = file.path("fig", "sensitivity_analysis.pdf"),
    plot = pl$sens,
    width = 6.4,
    height = 4,
    device = cairo_pdf
  )
  ggsave(
    filename = file.path("fig", "sensitivity_analysis.png"),
    plot = pl$sens,
    width = 6.4,
    height = 4,
    device = "png",
    type = "cairo"
  )
  1187. # tibble(prior_string = "lkjcorr(1.5)") |>
  1188. # parse_dist(prior_string) |>
  1189. # marginalize_lkjcorr(K = 3) |>
  1190. # ggplot(aes(xdist = .dist_obj)) +
  1191. # stat_slabinterval(point_interval="median_hdi", .width=c(0.5, 0.89), justification=-0.05, shape="|", slab_colour="black", slab_linewidth=0.5) +
  1192. # scale_x_continuous(limits=c(-1, 1)) +
  1193. # labs(
  1194. # x = "ρ",
  1195. # y = "Density"
  1196. # ) +
  1197. # theme(
  1198. # axis.ticks.y = element_blank(),
  1199. # axis.text.y = element_blank(),
  1200. # panel.spacing.x = unit(12, "pt")
  1201. # )