|
- library(readr)
- library(dplyr)
- library(tidyr)
- library(purrr)
- library(forcats)
- library(ggplot2)
- library(scales)
- library(ggstance)
- library(ggdist)
- library(cowplot)
- library(patchwork)
- # Project-wide ggplot2 theme, sourced into the current environment. Presumably
- # this is where `axis_linewidth` and `measure_cols` (used further down) are
- # defined -- TODO confirm against fig_code/ggplot2_theme.R.
- source(file.path("fig_code", "ggplot2_theme.R"), local = TRUE)
-
- # Stimulus alphabet: the 26 lowercase letters followed by ä, ö, ü and ß.
- all_chars <- c(letters, "ä", "ö", "ü", "ß")
-
- # Noise-ceiling tables. The POI, P1 and time tables are accessed below through
- # their `lwr` / `upr` columns.
- noise_ceiling_poi <- read_csv(file.path("noise_ceiling", "noise_ceiling_poi.csv"))
- noise_ceiling_p1 <- read_csv(file.path("noise_ceiling", "noise_ceiling_p1.csv"))
- noise_ceiling_time <- read_csv(file.path("noise_ceiling", "noise_ceiling_time.csv"))
- noise_ceiling_poi_all_chs <- read_csv(file.path("noise_ceiling", "noise_ceiling_poi_all_chs.csv"))
- noise_ceiling_time_all_chs <- read_csv(file.path("noise_ceiling", "noise_ceiling_time_all_chs.csv"))
-
- # Collector for every plot object built in this script.
- pl <- list()
-
- # Analysis windows; the panel titles built from them append "ms".
- poi_window <- c(150, 225)
- p1_window <- c(80, 130)
-
- # Shared rho axis: fixed lower bound, upper bound at the largest ceiling value.
- rho_limits <- c(
-   -0.1,
-   max(noise_ceiling_time$upr, noise_ceiling_poi$upr, noise_ceiling_p1$upr)
- )
- # preregistered analysis --------------------------------------------------
- # plots of RDMs
- # Stimulus dissimilarity table: Jaccard, OT and complexity measures joined on
- # the character pair, each with a rank-transformed copy, plus an
- # order-independent pair identifier ("a_b" regardless of which is char1).
- stim_sim <- read_csv(file.path("stim_sim", "preregistered", "jacc.csv")) |>
-   left_join(
-     read_csv(file.path("stim_sim", "preregistered", "ot.csv")),
-     by = c("char1", "char2")
-   ) |>
-   left_join(
-     read_csv(file.path("stim_sim", "complexity", "complexity.csv")),
-     by = c("char1", "char2")
-   ) |>
-   mutate(
-     rank_jacc = rank(jacc),
-     rank_ot = rank(ot),
-     rank_comp_dist = rank(comp_dist)
-   ) |>
-   # Row-wise so that sort() sees exactly one pair at a time.
-   rowwise() |>
-   mutate(pair_id = paste(sort(c(char1, char2)), collapse = "_")) |>
-   ungroup() |>
-   # Drop any dot-prefixed columns.
-   select(!starts_with("."))
- # Per-subject neural RDMs for the period of interest: one CSV per file in
- # rdm_data/period_of_interest, stacked, joined with the stimulus measures, and
- # rank-transformed within subject.
- rdm_poi <- file.path("rdm_data", "period_of_interest") |>
-   # Anchor the pattern so only names ending in ".csv" are picked up (the
-   # unanchored form would also match e.g. "x.csv.bak").
-   list.files(pattern = "\\.csv$", full.names = TRUE) |>
-   # subj_id is forced to character so the per-file tables share a column type.
-   map(\(path) read_csv(path, col_types = c("subj_id" = "c"))) |>
-   bind_rows() |>
-   select(-starts_with(".")) |>
-   left_join(stim_sim, by = c("char1", "char2")) |>
-   group_by(subj_id) |>
-   mutate(rank_eeg_dissim = rank(eeg_dissim)) |>
-   ungroup() |>
-   # Ordered factors fix the axis order of the RDM plots to `all_chars`.
-   mutate(
-     char1 = factor(char1, levels = all_chars, ordered = TRUE),
-     char2 = factor(char2, levels = all_chars, ordered = TRUE)
-   )
- # One tile plot of the rank-transformed neural RDM per subject, arranged in
- # three rows. The stored pairs are mirrored (char1 <-> char2) and appended so
- # that both triangles of each matrix are drawn.
- pl$rdm <- unique(rdm_poi$subj_id) |>
-   lapply(\(s) {
-     half <- filter(rdm_poi, subj_id == s)
-     flipped <- mutate(half, swap = char1, char1 = char2, char2 = swap, swap = NULL)
-
-     bind_rows(half, flipped) |>
-       ggplot(aes(char1, char2, fill = rank_eeg_dissim)) +
-       geom_tile() +
-       scale_fill_viridis_c() +
-       coord_fixed() +
-       scale_x_discrete(expand = c(0, 0)) +
-       scale_y_discrete(limits = rev, expand = c(0, 0)) +
-       labs(title = s, x = NULL, y = NULL, fill = NULL) +
-       theme(
-         plot.title = element_text(hjust = 0.5),
-         axis.line = element_blank(),
-         legend.position = "none"
-       )
-   }) |>
-   wrap_plots(nrow = 3)
-
- ggsave(
-   file.path("fig", "neural_rdm.png"), pl$rdm,
-   width = 8, height = 5, device = "png", type = "cairo"
- )
- # Group-level neural RDM: the within-subject ranks are averaged per character
- # pair, mirrored to fill both triangles, and drawn as a tile plot with only the
- # first and last characters labelled.
- pl$avg_rdm <- rdm_poi |>
-   group_by(char1, char2) |>
-   # summarise(avg_eeg_dissim = mean(eeg_dissim)) |>
-   # mutate(avg_rank_eeg_dissim = rank(avg_eeg_dissim)) |>
-   summarise(avg_rank_eeg_dissim = mean(rank_eeg_dissim)) |>
-   ungroup() |>
-   (\(half) bind_rows(
-     half,
-     mutate(half, swap = char1, char1 = char2, char2 = swap, swap = NULL)
-   ))() |>
-   ggplot(aes(char1, char2, fill = avg_rank_eeg_dissim)) +
-   geom_tile(linewidth = 0.01) +
-   scale_fill_viridis_c() +
-   coord_fixed() +
-   scale_x_discrete(expand = c(0, 0), breaks = all_chars[c(1, 30)]) +
-   scale_y_discrete(limits = rev, expand = c(0, 0), breaks = all_chars[c(1, 30)]) +
-   # Secondary axes close the box; their text and ticks are blanked below.
-   guides(x.sec = "axis", y.sec = "axis") +
-   labs(x = NULL, y = NULL, fill = NULL) +
-   theme(
-     plot.title = element_text(hjust = 0.5),
-     axis.line = element_line(linewidth = axis_linewidth),
-     axis.text.x.top = element_blank(),
-     axis.ticks.x.top = element_blank(),
-     axis.text.y.right = element_blank(),
-     axis.ticks.y.right = element_blank(),
-     legend.position = "none"
-   )
-
- # Overlay "..." on each axis to stand in for the unlabelled characters.
- pl$avg_rdm <- ggdraw(pl$avg_rdm) +
-   draw_label("...", x = 0.5, y = 0.02, size = 10, hjust = 0.5, vjust = 0) +
-   draw_label("...", x = 0.02, y = 0.5, size = 10, angle = 90, hjust = 0.5, vjust = 0)
-
- ggsave(
-   file.path("fig", "avg_neural_rdm.png"), pl$avg_rdm,
-   width = 2, height = 2, device = "png", type = "cairo"
- )
- # plot model RDMs for planned analysis
- # Rebuild the stimulus table from disk, this time without `pair_id` and with
- # the characters coded as ordered factors so the model RDM axes follow
- # `all_chars`.
- stim_sim <- read_csv(file.path("stim_sim", "preregistered", "jacc.csv")) |>
-   left_join(
-     read_csv(file.path("stim_sim", "preregistered", "ot.csv")),
-     by = c("char1", "char2")
-   ) |>
-   left_join(
-     read_csv(file.path("stim_sim", "complexity", "complexity.csv")),
-     by = c("char1", "char2")
-   ) |>
-   mutate(
-     rank_jacc = rank(jacc),
-     rank_ot = rank(ot),
-     rank_comp_dist = rank(comp_dist),
-     char1 = factor(char1, levels = all_chars, ordered = TRUE),
-     char2 = factor(char2, levels = all_chars, ordered = TRUE)
-   )
- # Draw one model RDM as a square tile plot.
- #
- # sim:      pair table with `char1`, `char2` and the column named by
- #           `fill_var`.
- # fill_var: bare column name mapped to the tile fill.
- #
- # Returns a cowplot drawing: the tile plot with "..." overlaid on each axis to
- # stand in for the characters between the two labelled breaks.
- plot_model_rdm <- function(sim, fill_var) {
-   # Append the transposed pairs so that both triangles of the matrix are
-   # filled.
-   mirrored <- sim |>
-     mutate(char1_tmp = char2, char2 = char1, char1 = char1_tmp) |>
-     dplyr::select(-char1_tmp)
-
-   rdm <- bind_rows(sim, mirrored) |>
-     ggplot(aes(char1, char2, fill = {{ fill_var }})) +
-     geom_tile(linewidth = 0.01) +
-     scale_fill_viridis_c() +
-     coord_fixed() +
-     scale_x_discrete(
-       expand = c(0, 0),
-       breaks = all_chars[c(1, 30)]
-     ) +
-     scale_y_discrete(
-       limits = rev, expand = c(0, 0),
-       breaks = all_chars[c(1, 30)]
-     ) +
-     # Secondary axes close the box; their text and ticks are blanked below.
-     guides(x.sec = "axis", y.sec = "axis") +
-     labs(x = NULL, y = NULL, fill = NULL) +
-     theme(
-       plot.title = element_text(hjust = 0.5),
-       axis.line = element_line(linewidth = axis_linewidth),
-       axis.text.x.top = element_blank(),
-       axis.ticks.x.top = element_blank(),
-       axis.text.y.right = element_blank(),
-       axis.ticks.y.right = element_blank(),
-       legend.position = "none"
-     )
-
-   ggdraw(rdm) +
-     draw_label("...", x = 0.5, y = 0.02, size = 10, hjust = 0.5, vjust = 0) +
-     draw_label("...", x = 0.02, y = 0.5, size = 10, angle = 90, hjust = 0.5, vjust = 0)
- }
-
- # Model RDM filled by the ranked `ot` measure.
- pl$ot_rdm <- plot_model_rdm(stim_sim, rank_ot)
- ggsave(file.path("fig", "model_rdm_ot.png"), pl$ot_rdm, width = 2, height = 2, device = "png", type = "cairo")
-
- # Model RDM filled by the ranked `jacc` measure.
- pl$jacc_rdm <- plot_model_rdm(stim_sim, rank_jacc)
- ggsave(file.path("fig", "model_rdm_jacc.png"), pl$jacc_rdm, width = 2, height = 2, device = "png", type = "cairo")
- # plots of the planned comparison
- # Posterior draws for the planned comparison, with a fixed y position per
- # correlation label for the interval summaries drawn under the densities.
- prereg_cor_samps_long <- readRDS(file.path("estimates", "prereg_cor_samps_long.rds")) |>
-   ungroup() |>
-   mutate(
-     summ_y_pos = case_match(
-       cor_lab,
-       "pcor_rankeegdissim_rankjacc" ~ -10,
-       "pcor_rankeegdissim_rankot" ~ -12,
-       "cor_rankeegdissim_rankjacc" ~ -5,
-       "cor_rankeegdissim_rankot" ~ -7
-     )
-   )
-
- # Bayes factors computed as ratios of draw counts: draws where the left-hand
- # quantity exceeds the right-hand one, over draws where it does not.
- prereg_bf <- local({
-   draw_odds <- \(lhs, rhs) sum(lhs > rhs) / sum(lhs <= rhs)
-
-   prereg_cor_samps_long |>
-     pivot_wider(id_cols = starts_with("."), names_from = cor_lab, values_from = Rho) |>
-     ungroup() |>
-     select(-starts_with(".")) |>
-     summarise(
-       bf1 = draw_odds(cor_rankeegdissim_rankot, 0),
-       bf2a = draw_odds(cor_rankeegdissim_rankot, cor_rankeegdissim_rankjacc),
-       bf2b = draw_odds(pcor_rankeegdissim_rankot, pcor_rankeegdissim_rankjacc)
-     )
- })
- # Posterior densities of rho per model (colour) and correlation type
- # (linetype), over a grey band spanning the POI noise ceiling, with
- # median/HDI interval summaries placed below the densities at `summ_y_pos`.
- pl$posterior_plot_part <- prereg_cor_samps_long |>
-   mutate(corr_lab = ifelse(is_partial, "Partial Correlation", "Correlation")) |>
-   ggplot(aes(Rho, colour = model)) +
-   annotate(
-     geom = "rect",
-     xmin = noise_ceiling_poi$lwr, xmax = noise_ceiling_poi$upr,
-     ymin = -Inf, ymax = Inf,
-     colour = NA, fill = "lightgrey"
-   ) +
-   # x and colour are inherited from the plot-level mapping.
-   geom_density(aes(linetype = corr_lab), linewidth = 0.5, trim = TRUE, key_glyph = "path") +
-   # Interval summaries are added separately for partial and plain correlations.
-   stat_pointinterval(
-     aes(y = summ_y_pos),
-     point_interval = median_hdi, .width = c(.5, .89), point_size = 0,
-     data = filter(prereg_cor_samps_long, is_partial),
-     show.legend = FALSE
-   ) +
-   stat_pointinterval(
-     aes(y = summ_y_pos),
-     point_interval = median_hdi, .width = c(.5, .89), point_size = 0,
-     data = filter(prereg_cor_samps_long, !is_partial),
-     show.legend = FALSE
-   ) +
-   geom_vline(xintercept = 0, linetype = "dashed", linewidth = axis_linewidth) +
-   scale_colour_manual(values = measure_cols) +
-   scale_linetype_manual(values = c("solid", "dashed")) +
-   # The lower limit leaves room for the interval summaries below zero.
-   scale_y_continuous(limits = c(-13, NA), breaks = seq(0, 30, 10)) +
-   labs(
-     x = "ρ", y = "Posterior Density",
-     colour = NULL, fill = NULL,
-     linewidth = NULL, linetype = NULL
-   ) +
-   guides(
-     linetype = guide_legend(order = 2),
-     colour = guide_legend(order = 1),
-     linewidth = "none"
-   ) +
-   theme(
-     legend.position = "right",
-     legend.direction = "vertical",
-     legend.spacing.y = unit(3, "pt"),
-     legend.margin = margin(),
-     legend.box.spacing = unit(0, "pt"),
-     legend.box.margin = margin()
-   )
-
- ggsave(file.path("fig", "posterior_poi_partial.png"), pl$posterior_plot_part, width = 5, height = 1.55, device = "png", type = "cairo")
- ggsave(file.path("fig", "posterior_poi_partial.pdf"), pl$posterior_plot_part, width = 5, height = 1.55, device = cairo_pdf)
- # Deliberate second save of the same PDF (seems to address bug in legend spacing?)
- ggsave(file.path("fig", "posterior_poi_partial.pdf"), pl$posterior_plot_part, width = 5, height = 1.55, device = cairo_pdf)
- ggsave(file.path("fig", "posterior_poi_partial.svg"), pl$posterior_plot_part, width = 4.75, height = 1.2)
- # timecourse of the planned comparison
- # Time-resolved estimates, ungrouped and with dot-prefixed columns removed.
- tc <- select(
-   ungroup(readRDS(file.path("estimates", "planned_timecourse.rds"))),
-   !starts_with(".")
- )
- # calculate time-resolved BFs
- # Return `x` with every NA entry replaced by 0.
- na2zero <- function(x) {
-   replace(x, is.na(x), 0)
- }
- # Bayes factors per time point, as ratios of draw counts, reshaped to one row
- # per (time, hypothesis). Within each hypothesis, runs of time points whose BF
- # is infinite or exactly zero are flagged and numbered so they can be drawn as
- # separate segments.
- tc_bf <- tc |>
-   group_by(time) |>
-   summarise(
-     bf1 = sum(cor_rankeegdissim_rankot > 0) /
-       sum(cor_rankeegdissim_rankot <= 0),
-     bf2a = sum(cor_rankeegdissim_rankot > cor_rankeegdissim_rankjacc) /
-       sum(cor_rankeegdissim_rankot <= cor_rankeegdissim_rankjacc),
-     bf2b = sum(pcor_rankeegdissim_rankot > pcor_rankeegdissim_rankjacc) /
-       sum(pcor_rankeegdissim_rankot <= pcor_rankeegdissim_rankjacc),
-     n_samps = n()
-   ) |>
-   # names_prefix strips "bf", leaving hypothesis values "1", "2a", "2b".
-   pivot_longer(
-     cols = starts_with("bf"),
-     names_to = "hypothesis", values_to = "bf", names_prefix = "bf"
-   ) |>
-   mutate(
-     time_ms = time * 1000,
-     infinite_BF = (is.infinite(bf) | bf == 0),
-     # A run starts where the previous point was not flagged, and ends where
-     # the next one is not.
-     infinite_section_start = infinite_BF & !lag(infinite_BF),
-     infinite_section_end = infinite_BF & !lead(infinite_BF),
-     # Running count of starts minus (lagged) ends; non-zero inside a run.
-     is_section = cumsum(na2zero(infinite_section_start - lag(infinite_section_end))),
-     inf_section_nr = ifelse(is_section == 0, 0, cumsum(infinite_section_start)),
-     .by = hypothesis
-   ) |>
-   # Bilinear scale: BF = 1 maps to 0, BF > 1 to BF - 1, BF < 1 to 1 - 1/BF.
-   mutate(bf_bilinear = ifelse(bf < 1, 1 - 1 / bf, bf - 1))
- # get credible intervals for Rho
- # Median and 89% HDI of rho per correlation label, partial flag and time
- # point. The "cor_" / "pcor_" prefix is moved into the `partial` flag.
- tc_cr_i <- tc |>
-   pivot_longer(
-     cols = c(starts_with("cor"), starts_with("pcor")),
-     names_to = "cor_lab", values_to = "Rho"
-   ) |>
-   mutate(
-     partial = grepl("^pcor_", cor_lab),
-     cor_lab = sub("(^cor_)|(^pcor_)", "", cor_lab)
-   ) |>
-   group_by(cor_lab, partial, time) |>
-   median_hdi(.width = .89) |>
-   # Display names; labels without a mapping are kept as they are.
-   mutate(measure = case_match(
-     cor_lab,
-     "rankeegdissim_rankot" ~ "Wasserstein Distance",
-     "rankeegdissim_rankjacc" ~ "Jaccard Distance",
-     .default = cor_lab
-   ))
-
- # plot timecourse
- # y limits spanning the intervals and the noise ceiling, padded by 0.005 and
- # rounded to two decimals.
- ylims <- round(
-   c(
-     min(tc_cr_i$.lower, noise_ceiling_time$lwr) - 0.005,
-     max(tc_cr_i$.upper, noise_ceiling_time$upr) + 0.005
-   ),
-   2
- )
- # Two time-course panels, in this order: [[1]] plain correlations, [[2]]
- # partial correlations. Each shows the noise-ceiling band, the 89% interval
- # ribbon and the median line per measure. x-axis text, ticks and title are
- # switched off here and re-enabled by the caller where needed.
- tc_rho_pl <- lapply(c(FALSE, TRUE), \(show_partial) {
-   panel_data <- tc_cr_i |>
-     filter(partial == show_partial) |>
-     arrange(desc(.width)) |>
-     mutate(time_ms = time * 1000)
-   ceiling_band <- mutate(noise_ceiling_time, time_ms = time * 1000, Rho = NA)
-
-   ggplot(panel_data) +
-     geom_ribbon(aes(x = time_ms, ymin = lwr, ymax = upr), data = ceiling_band, fill = "lightgrey") +
-     geom_hline(yintercept = 0, linewidth = axis_linewidth) +
-     geom_vline(xintercept = 0, linewidth = axis_linewidth) +
-     geom_ribbon(
-       aes(time_ms, Rho, fill = measure, group = measure, ymin = .lower, ymax = .upper),
-       alpha = 0.4, colour = NA
-     ) +
-     geom_line(aes(time_ms, Rho, colour = measure, group = measure), linetype = "solid") +
-     scale_fill_manual(values = measure_cols) +
-     scale_colour_manual(values = measure_cols) +
-     scale_x_continuous(expand = c(0, 0), limits = c(-200, 1000), breaks = seq(-200, 1000, 100)) +
-     scale_y_continuous(expand = c(0, 0), limits = ylims) +
-     labs(
-       x = NULL,
-       y = if (show_partial) "Partial ρ" else "ρ",
-       colour = NULL, fill = NULL, tag = NULL
-     ) +
-     theme(
-       # Legend anchored by its bottom-right corner at the panel's top-right.
-       legend.position = "inside",
-       legend.position.inside = c(1, 1),
-       legend.direction = "horizontal",
-       legend.justification = c(1, 0),
-       legend.background = element_blank(),
-       plot.margin = margin(0, 10, 0, 0, unit = "pt"),
-       legend.key.height = unit(10, units = "pt"),
-       legend.margin = margin(0, 0, 0, 0),
-       axis.ticks.x = element_blank(),
-       axis.text.x = element_blank(),
-       axis.line = element_blank(),
-       plot.background = element_blank()
-     )
- })
- # Stack the two time-course panels. Only the lower (partial) panel gets an
- # x-axis title, ticks and labels, and its legend is dropped.
- tc_pl_joined <- wrap_plots(
-   tc_rho_pl[[1]],
-   tc_rho_pl[[2]] +
-     labs(x = "Time (ms)") +
-     theme(
-       legend.position = "none",
-       axis.ticks.x = element_line(),
-       axis.text.x = element_text(),
-       axis.line.x = element_blank()
-     ),
-   ncol = 1, heights = c(2.5, 2.5)
- ) +
-   plot_annotation(tag_levels = "a")
-
- ggsave(file.path("fig", "tc_joined_CrI.pdf"), tc_pl_joined, width = 4.5, height = 3, device = cairo_pdf)
- ggsave(file.path("fig", "tc_joined_CrI.png"), tc_pl_joined, width = 4.5, height = 3, device = "png", type = "cairo")
- # tc_pl_joined_fewer_ticks <- list(
- # tc_rho_pl[[1]],
- # tc_rho_pl[[2]] +
- # labs(x = "Time (ms)") +
- # scale_x_continuous(expand=c(0,0), limits=c(-200, 1000), breaks=seq(-200, 1000, 100), labels=ifelse(seq(-200, 1000, 100)%%200==0, seq(-200, 1000, 100), "")) +
- # theme(
- # legend.position="none",
- # axis.ticks.x = element_line(),
- # axis.text.x = element_text(),
- # axis.line.x = element_blank()
- # )
- # ) |>
- # wrap_plots(ncol=1, heights=c(2.5, 2.5)) +
- # plot_annotation(tag_levels = "a")
- ggsave(file.path("fig", "tc_joined_CrI.svg"), tc_pl_joined, width = 5, height = 3.5, device = "svg")
- # Largest non-infinite Bayes factor in either direction (BF or 1/BF).
- bf_lims <- local({
-   bf <- tc_bf$bf
-   inv_bf <- 1 / bf
-   max(c(bf[!is.infinite(bf)], inv_bf[!is.infinite(inv_bf)]))
- })
- bf_ticks <- c(10000, 100, 1, 1/100, 1/10000)
- bf_labs <- c("10000", "100", "1", "1/100", "1/10000")
-
- # Hypothesis 2a only. Time points flagged as infinite/zero BF get
- # `n_samps - 1`; every other point gets the raw `bf`.
- # NOTE(review): this overwrites the `bf_bilinear` column computed when `tc_bf`
- # was built, so finite values are plotted as raw BF although the axis labels
- # below follow the bilinear convention (0 labelled "1") -- confirm this is
- # intended.
- tc_bf_h2a <- tc_bf |>
-   filter(hypothesis == "2a") |>
-   mutate(bf_bilinear = ifelse(infinite_BF, n_samps - 1, bf))
-
- # BF time course, with shaded regions in the two measure colours and the
- # infinite-BF runs drawn in red at the top of the axis.
- bf_pl_bilinear <- ggplot(tc_bf_h2a, aes(time_ms, bf_bilinear)) +
-   annotate(
-     "rect", xmin = -Inf, xmax = Inf, ymin = 1, ymax = Inf,
-     fill = measure_cols[["Wasserstein Distance"]], alpha = 0.4
-   ) +
-   annotate(
-     "rect", xmin = -Inf, xmax = Inf, ymin = -Inf, ymax = 0,
-     fill = measure_cols[["Jaccard Distance"]], alpha = 0.4
-   ) +
-   geom_line() +
-   geom_line(
-     aes(group = inf_section_nr),
-     y = bf_lims, colour = "red",
-     data = filter(tc_bf_h2a, infinite_BF)
-   ) +
-   geom_hline(yintercept = 1, linewidth = axis_linewidth) +
-   geom_vline(xintercept = 0, linewidth = axis_linewidth) +
-   scale_x_continuous(expand = c(0, 0), limits = c(-200, 1000), breaks = seq(-200, 1000, 100)) +
-   scale_y_continuous(
-     limits = c(-9999, bf_lims),
-     breaks = c(-9999, 0, 9999, 19999, 29999, 39999),
-     labels = c("1/10000", "1", "10000", "20000", "30000", "40000")
-   ) +
-   labs(x = "Time (ms)", y = "BF") +
-   theme(
-     legend.position = "none",
-     legend.margin = margin(),
-     legend.box.margin = margin(),
-     plot.margin = margin(0, 10, 2.5, 2.5, unit = "pt"),
-     axis.line.x = element_blank(),
-     axis.line.y = element_blank(),
-     plot.background = element_blank()
-   )
-
- # Three stacked panels; the spacers with negative heights pull them together.
- tc_pl_joined_bf <- wrap_plots(
-   tc_rho_pl[[1]],
-   plot_spacer(),
-   tc_rho_pl[[2]] + theme(legend.position = "none"),
-   plot_spacer(),
-   bf_pl_bilinear,
-   ncol = 1, heights = c(2.5, -0.3, 2.5, -0.3, 1.75)
- ) +
-   plot_annotation(tag_levels = "a")
-
- ggsave(file.path("fig", "tc_joined_CrI_bf.pdf"), tc_pl_joined_bf, width = 3.9, height = 3, device = cairo_pdf)
- # geom results ---------------------------------------------------------
- geom_palette <- viridisLite::plasma
- geom_palette_1_9 <- geom_palette(n = 9, end = 0.85)[1:9]
- geom_palette_1_8 <- geom_palette(n = 9, end = 0.85)[1:8]
- interval_size_range_geom <- c(0.75, 2.5)
- # geom_fct_levels <- c("---", "--R", "-S-", "-SR", "T--", "T-R", "TS-", "TSR", "GW")
- # geom_fct_labels <- c("-", "R", "S", "SR", "T", "TR", "TS", "TSR", "G-W")
- geom_fct_levels <- c("---", "--R", "-S-", "T--", "-SR", "T-R", "TS-", "TSR", "GW")
- geom_fct_labels <- c("-", "R", "S", "T", "RS", "RT", "ST", "RST", "G-W")
-
- # Build the transformation-label factor shared by all geometry estimates.
- #
- # translate, scale, rotate: flags; each contributes its letter ("T", "S",
- #   "R") or "-" to a three-character code.
- # gromov_wasserstein: optional flag vector. Where TRUE the code is replaced
- #   by "GW". When omitted (NULL), only the first eight levels/labels are used.
- #
- # Returns a factor with levels `geom_fct_levels` relabelled by
- # `geom_fct_labels`.
- label_geom_variant <- function(translate, scale, rotate, gromov_wasserstein = NULL) {
-   code <- paste0(
-     ifelse(translate, "T", "-"),
-     ifelse(scale, "S", "-"),
-     ifelse(rotate, "R", "-")
-   )
-   if (is.null(gromov_wasserstein)) {
-     keep <- 1:8
-   } else {
-     code <- ifelse(gromov_wasserstein, "GW", code)
-     keep <- seq_along(geom_fct_levels)
-   }
-   factor(code, levels = geom_fct_levels[keep], labels = geom_fct_labels[keep])
- }
-
- # Time-resolved estimates. `time` is converted via character first (so that a
- # factor-coded `time` would yield its label, not its level index) and then to
- # milliseconds.
- ot_geom_time_resolved <- readRDS(file.path("estimates", "ot_geom_time_resolved.rds")) |>
-   ungroup() |>
-   select(-starts_with(".")) |>
-   arrange(time, translate, scale, rotate, gromov_wasserstein) |>
-   mutate(
-     time_ms = as.numeric(as.character(time)) * 1000,
-     geom_label = label_geom_variant(translate, scale, rotate, gromov_wasserstein)
-   )
- jacc_geom_time_resolved <- readRDS(file.path("estimates", "jacc_geom_time_resolved.rds")) |>
-   ungroup() |>
-   select(-starts_with(".")) |>
-   arrange(time, translate, scale, rotate) |>
-   mutate(
-     time_ms = as.numeric(as.character(time)) * 1000,
-     geom_label = label_geom_variant(translate, scale, rotate)
-   )
-
- # Draws for the period of interest.
- ot_geom_poi_res <- readRDS(file.path("estimates", "ot_geom_poi_draws.rds")) |>
-   ungroup() |>
-   select(-starts_with(".")) |>
-   arrange(translate, scale, rotate, gromov_wasserstein) |>
-   mutate(geom_label = label_geom_variant(translate, scale, rotate, gromov_wasserstein))
- jacc_geom_poi_res <- readRDS(file.path("estimates", "jacc_geom_poi_draws.rds")) |>
-   ungroup() |>
-   select(-starts_with(".")) |>
-   arrange(translate, scale, rotate) |>
-   mutate(geom_label = label_geom_variant(translate, scale, rotate))
-
- # Draws for the P1 window.
- ot_geom_p1_res <- readRDS(file.path("estimates", "ot_geom_p1_draws.rds")) |>
-   ungroup() |>
-   select(-starts_with(".")) |>
-   arrange(translate, scale, rotate, gromov_wasserstein) |>
-   mutate(geom_label = label_geom_variant(translate, scale, rotate, gromov_wasserstein))
- jacc_geom_p1_res <- readRDS(file.path("estimates", "jacc_geom_p1_draws.rds")) |>
-   ungroup() |>
-   select(-starts_with(".")) |>
-   arrange(translate, scale, rotate) |>
-   mutate(geom_label = label_geom_variant(translate, scale, rotate))
- # Time course of rho, one line per transformation label, over the
- # time-resolved noise-ceiling band.
- #
- # est:     estimates with `time_ms`, `rho` and `geom_label`.
- # palette: colours passed to the manual colour scale.
- plot_geom_timecourse <- function(est, palette) {
-   ceiling_band <- mutate(noise_ceiling_time, time_ms = time * 1000, Rho = NA)
-
-   ggplot(est) +
-     geom_ribbon(aes(x = time_ms, ymin = lwr, ymax = upr), data = ceiling_band, fill = "lightgrey") +
-     geom_line(aes(time_ms, rho, colour = geom_label)) +
-     geom_vline(xintercept = 0, linewidth = axis_linewidth) +
-     geom_hline(yintercept = 0, linewidth = axis_linewidth) +
-     scale_x_continuous(expand = c(0, 0), limits = c(-200, 1000), breaks = seq(-200, 1000, 200)) +
-     scale_y_continuous(limits = rho_limits, expand = c(0, 0)) +
-     # Only the manual scale is added: a preceding viridis colour scale would
-     # be replaced by it anyway and merely trigger a message.
-     scale_colour_manual(values = palette) +
-     labs(
-       x = "Time (ms)",
-       y = "ρ"
-     ) +
-     theme(
-       axis.line = element_blank(),
-       legend.position = "none"
-     )
- }
-
- # Median/HDI intervals of rho per transformation label for one time window.
- #
- # draws:         draws with `geom_label` and `rho`.
- # ceiling:       noise-ceiling table whose `lwr` / `upr` give the grey band.
- # palette:       colours passed to the manual colour scale.
- # gw_separator:  draw a dashed vertical line at x = 8.5, i.e. between the
- #                eighth and ninth label.
- # strip_y_axis:  blank the y axis (line, title, text, ticks), for panels
- #                placed to the right of another panel.
- plot_geom_intervals <- function(draws, ceiling, palette, gw_separator, strip_y_axis) {
-   p <- ggplot(draws, aes(geom_label, rho, colour = geom_label)) +
-     annotate(
-       geom = "rect",
-       ymin = ceiling$lwr, ymax = ceiling$upr, xmin = -Inf, xmax = Inf,
-       colour = NA, fill = "lightgrey"
-     ) +
-     stat_pointinterval(
-       point_interval = "median_hdi", .width = c(.5, .89), shape = NA,
-       interval_size_range = interval_size_range_geom
-     ) +
-     geom_hline(yintercept = 0, linewidth = axis_linewidth)
-
-   if (gw_separator) {
-     p <- p + geom_vline(xintercept = 8.5, linewidth = axis_linewidth, linetype = "dashed")
-   }
-
-   p <- p +
-     scale_colour_manual(values = palette) +
-     scale_y_continuous(limits = rho_limits, expand = c(0, 0)) +
-     scale_x_discrete() +
-     labs(
-       x = "Transformations",
-       y = "ρ"
-     ) +
-     theme(
-       legend.position = "none",
-       strip.text = element_blank(),
-       axis.title.x = element_blank(),
-       axis.text.x = element_text(hjust = 1, vjust = 0.5, angle = 90),
-       axis.line.x = element_blank()
-     )
-
-   if (strip_y_axis) {
-     p <- p +
-       theme(
-         axis.line.y = element_blank(),
-         axis.title.y = element_blank(),
-         axis.text.y = element_blank(),
-         axis.ticks.y = element_blank()
-       )
-   }
-   p
- }
-
- # The `ot_*` panels use the nine-colour palette and the separator before the
- # last label; the `jacc_*` panels use the eight-colour palette.
- pl$ot_geom_time_res <- plot_geom_timecourse(ot_geom_time_resolved, geom_palette_1_9)
- pl$ot_geom_poi_res <- plot_geom_intervals(
-   ot_geom_poi_res, noise_ceiling_poi, geom_palette_1_9,
-   gw_separator = TRUE, strip_y_axis = TRUE
- )
- pl$ot_geom_p1_res <- plot_geom_intervals(
-   ot_geom_p1_res, noise_ceiling_p1, geom_palette_1_9,
-   gw_separator = TRUE, strip_y_axis = FALSE
- )
-
- pl$jacc_geom_time_res <- plot_geom_timecourse(jacc_geom_time_resolved, geom_palette_1_8)
- pl$jacc_geom_poi_res <- plot_geom_intervals(
-   jacc_geom_poi_res, noise_ceiling_poi, geom_palette_1_8,
-   gw_separator = FALSE, strip_y_axis = TRUE
- )
- pl$jacc_geom_p1_res <- plot_geom_intervals(
-   jacc_geom_p1_res, noise_ceiling_p1, geom_palette_1_8,
-   gw_separator = FALSE, strip_y_axis = FALSE
- )
- # Joined geom results -----------------------------------------------------
- # POI and time course
- # Text-only panels used as row labels.
- pl$wass_lab <- ggplot() +
-   annotate("text", x = 0, y = 0, size = 8 / .pt, label = "Wasserstein\nDistance", family = "Helvetica") +
-   theme_void()
- pl$jacc_lab <- ggplot() +
-   annotate("text", x = 0, y = 0, size = 8 / .pt, label = "Jaccard\nDistance", family = "Helvetica") +
-   theme_void()
-
- # Six panels with three column widths: row label, POI intervals, time course.
- pl$jacc_ot_geom_poi_joined <- local({
-   no_margin <- theme(plot.margin = margin(0, 0, 0, 0, "pt"))
-
-   panels <- list(
-     pl$wass_lab + no_margin,
-     pl$ot_geom_poi_res + labs(tag = "a1") + no_margin,
-     pl$ot_geom_time_res + labs(tag = "a2") + theme(plot.margin = margin(0, 10, 0, 0, "pt")),
-     pl$jacc_lab + no_margin,
-     pl$jacc_geom_poi_res + labs(tag = "b1") + no_margin,
-     pl$jacc_geom_time_res + labs(tag = "b2") + no_margin
-   )
-
-   wrap_plots(panels, widths = c(0.425, 0.75, 2)) +
-     theme(plot.background = element_blank())
- })
-
- ggsave(file.path("fig", "jacc_ot_geom_poi_joined.png"), pl$jacc_ot_geom_poi_joined, width = 6, height = 3.5, device = "png", type = "cairo")
- ggsave(file.path("fig", "jacc_ot_geom_poi_joined.pdf"), pl$jacc_ot_geom_poi_joined, width = 6, height = 3.5, device = cairo_pdf)
- # POI, P1, Timecourse
- # noise_ceiling_lines_df <- tibble(
- # x=rep(c(p1_window, poi_window), each=2),
- # y=rep(c(0.325, max(rho_limits)), 4)
- # )
- # Top row: P1 and POI interval panels for each measure, titled with the
- # window bounds. Bottom row: the two time-resolved panels.
- pl$jacc_ot_geom_poi_p1_joined <- local({
-   window_title <- theme(plot.title = element_text(hjust = 0, size = 8, vjust = 0))
-   time_res_theme <- theme(
-     plot.margin = margin(10, 10, 0, 0, "pt"),
-     plot.title = element_text(hjust = 0, size = 8)
-   )
-   p1_range <- paste(p1_window, collapse = "-")
-   poi_range <- paste(poi_window, collapse = "-")
-
-   window_row <- (
-     pl$ot_geom_p1_res +
-       labs(title = sprintf("Wasserstein Distance\n\n%s ms", p1_range), tag = "a") +
-       window_title |
-     pl$ot_geom_poi_res +
-       labs(title = sprintf("%s ms", poi_range)) +
-       window_title |
-     pl$jacc_geom_p1_res +
-       labs(title = sprintf("Jaccard Distance\n\n%s ms", p1_range), tag = "b") +
-       window_title |
-     pl$jacc_geom_poi_res +
-       labs(title = sprintf("%s ms", poi_range)) +
-       window_title
-   )
-   time_row <- (
-     pl$ot_geom_time_res + labs(title = "Time-Resolved") + time_res_theme |
-     pl$jacc_geom_time_res + labs(title = "Time-Resolved") + time_res_theme
-   )
-
-   window_row / time_row +
-     plot_layout(heights = c(3, 3)) +
-     theme(plot.background = element_blank())
- })
-
- ggsave(file.path("fig", "jacc_ot_geom_poi_p1_joined.png"), pl$jacc_ot_geom_poi_p1_joined, width = 6, height = 3, device = "png", type = "cairo")
- ggsave(file.path("fig", "jacc_ot_geom_poi_p1_joined.pdf"), pl$jacc_ot_geom_poi_p1_joined, width = 6, height = 3, device = cairo_pdf)
- # ANN results -------------------------------------------------------------
- ann_palette <- viridisLite::plasma
-
- # Add the display columns shared by all ANN estimate tables.
- #
- # est: table with character-like `model` and `layer` columns.
- #
- # Returns `est` with four columns appended:
- #   training_label  factor derived from the suffix of `model`
- #   model_label     factor derived from the architecture named in `model`
- #   layer_label     `layer` without a leading "layer"
- #   layer_level     numeric order: V1/V2/V4/IT map to 1-4, any other label is
- #                   converted with as.numeric()
- label_ann_estimates <- function(est) {
-   mutate(
-     est,
-     # The combined suffix is tested first because it also ends in "_letters".
-     training_label = factor(case_when(
-       grepl("_imagenet_letters$", model) ~ "Imagenet\n+ Letters",
-       grepl("_imagenet$", model) ~ "Imagenet",
-       grepl("_letters$", model) ~ "Letters"
-     ), levels = c("Letters", "Imagenet", "Imagenet\n+ Letters")),
-     model_label = factor(case_when(
-       grepl("resnet50", model) ~ "ResNet-50",
-       grepl("cornet-z", model) ~ "CORnet-Z"
-     ), levels = c("ResNet-50", "CORnet-Z")),
-     layer_label = sub("^layer", "", layer),
-     layer_level = case_match(
-       layer_label,
-       "V1" ~ 1,
-       "V2" ~ 2,
-       "V4" ~ 3,
-       "IT" ~ 4,
-       .default = as.numeric(layer_label)
-     )
-   )
- }
-
- # Time-resolved estimates; `time` is converted to milliseconds first.
- ann_time_cor_res <- readRDS(file.path("estimates", "ANNs_time_resolved.rds")) |>
-   ungroup() |>
-   select(-starts_with(".")) |>
-   mutate(time_ms = as.numeric(as.character(time)) * 1000) |>
-   label_ann_estimates()
-
- # Draws for the period of interest.
- ann_poi_res <- readRDS(file.path("estimates", "ANNs_poi_draws.rds")) |>
-   ungroup() |>
-   select(-starts_with(".")) |>
-   label_ann_estimates()
-
- # Draws for the P1 window.
- ann_p1_res <- readRDS(file.path("estimates", "ANNs_p1_draws.rds")) |>
-   ungroup() |>
-   select(-starts_with(".")) |>
-   label_ann_estimates()
- # Per-architecture interval plots of rho by layer, faceted by training
- # regime.
- #
- # draws:         draws with `model_label`, `layer`, `layer_label`,
- #                `layer_level`, `rho` and `training_label`.
- # ceiling:       noise-ceiling table whose `lwr` / `upr` give the grey band.
- # strip_y_axis:  blank the y axis (line, title, text, ticks), for panels
- #                placed to the right of another panel.
- #
- # Returns a list with one plot per level of `ann_time_cor_res$model_label`,
- # in level order.
- plot_ann_intervals <- function(draws, ceiling, strip_y_axis) {
-   lapply(levels(ann_time_cor_res$model_label), function(m_i) {
-     d_i <- filter(draws, model_label == m_i)
-
-     n_layers <- length(unique(d_i$layer))
-     ann_palette_i <- ann_palette(n = n_layers, end = 0.85)
-
-     # With more than four layers, only labels ending in ".0" are printed and
-     # thinner intervals are used.
-     if (n_layers > 4) {
-       xscale <- scale_x_discrete(labels = ~ifelse(grepl("\\.0$", .x), .x, ""))
-       interval_size_range_ann <- c(0.75, 1.5)
-     } else {
-       xscale <- scale_x_discrete()
-       interval_size_range_ann <- c(0.75, 2.5)
-     }
-
-     p <- d_i |>
-       # Sort by layer order so the label factor's levels follow that order.
-       mutate(layer_level = as.factor(layer_level)) |>
-       arrange(layer_level) |>
-       mutate(layer_label = factor(layer_label, levels = unique(layer_label))) |>
-       ggplot(aes(x = layer_label, y = rho, colour = layer_label)) +
-       annotate(
-         geom = "rect",
-         ymin = ceiling$lwr, ymax = ceiling$upr, xmin = -Inf, xmax = Inf,
-         colour = NA, fill = "lightgrey"
-       ) +
-       stat_pointinterval(
-         point_interval = "median_hdi", .width = c(.5, .89), shape = NA,
-         interval_size_range = interval_size_range_ann
-       ) +
-       geom_hline(yintercept = 0, linewidth = axis_linewidth) +
-       scale_colour_manual(values = ann_palette_i) +
-       scale_y_continuous(limits = rho_limits, expand = c(0, 0)) +
-       xscale +
-       labs(
-         x = "Layer",
-         y = "ρ"
-       ) +
-       theme(
-         legend.position = "none",
-         strip.text = element_blank(),
-         axis.line.x = element_blank()
-       )
-
-     if (strip_y_axis) {
-       p <- p +
-         theme(
-           axis.line.y = element_blank(),
-           axis.title.y = element_blank(),
-           axis.text.y = element_blank(),
-           axis.ticks.y = element_blank()
-         )
-     }
-
-     p + facet_grid(rows = vars(training_label))
-   })
- }
-
- # P1 panels keep their y axis; POI panels have it blanked.
- ann_p1_pl_list <- plot_ann_intervals(ann_p1_res, noise_ceiling_p1, strip_y_axis = FALSE)
- ann_poi_pl_list <- plot_ann_intervals(ann_poi_res, noise_ceiling_poi, strip_y_axis = TRUE)
- # Time-resolved panels: one ggplot per level of ann_time_cor_res$model_label,
- # drawing one rho-over-time line per layer, with one facet row per training_label.
- ann_time_res_pl_list <- lapply(levels(ann_time_cor_res$model_label), function(m_i) {
-   model_df <- filter(ann_time_cor_res, model_label == m_i)
-   layer_count <- length(unique(model_df$layer))
-   layer_cols <- ann_palette(n=layer_count, end=0.85)
-   # thinner lines when more than 4 layers share a panel
-   line_size <- if (layer_count > 4) 0.25 else 0.5
-   # noise ceiling over time, with its time column converted from `time` (multiplied by 1000) to time_ms
-   ceiling_df <- mutate(noise_ceiling_time, time_ms=time*1000, Rho=NA)
-   # sort rows by layer_level, then freeze that order into the levels of layer_label
-   plot_df <- model_df |>
-     mutate(layer_level = as.factor(layer_level)) |>
-     arrange(layer_level) |>
-     mutate(layer_label = factor(layer_label, levels=unique(layer_label)))
-   ggplot(plot_df) +
-     geom_ribbon(aes(x=time_ms, ymin=lwr, ymax=upr), data=ceiling_df, fill="lightgrey") +
-     geom_line(aes(time_ms, rho, colour=layer_level), linewidth=line_size) +
-     # zero reference lines stand in for the axis lines, which the theme removes
-     geom_vline(xintercept=0, linewidth=axis_linewidth) +
-     geom_hline(yintercept=0, linewidth=axis_linewidth) +
-     scale_x_continuous(
-       expand=c(0,0),
-       limits=c(-200, 1000),
-       breaks=seq(-200, 1000, 200)
-     ) +
-     scale_y_continuous(limits=rho_limits, expand=c(0,0)) +
-     scale_colour_manual(values=layer_cols) +
-     facet_grid(rows=vars(training_label)) +
-     labs(x = "Time (ms)", y = "ρ") +
-     theme(
-       axis.line = element_blank(),
-       legend.position = "none",
-       plot.margin = margin(0, 0, 0, 10, "pt"),
-       axis.line.y = element_blank(),
-       axis.title.y = element_blank(),
-       axis.text.y = element_blank(),
-       axis.ticks.y = element_blank()
-     )
- })
- # Assemble the ANN figure as a 2 x 3 grid: row 1 takes element 1 of each plot list (titled "ResNet-50"),
- # row 2 takes element 2 (titled "CORnet-Z"); the columns are the P1 window, the period of interest
- # and the time course, the last one twice as wide.
- pl$ann_joined <- local({
-   # every panel gets the same left-aligned 8 pt title
-   title_theme <- theme(plot.title=element_text(hjust=0, size=8))
-   p1_range <- paste(p1_window, collapse="-")
-   poi_title <- sprintf("\n\n%s ms", paste(poi_window, collapse="-"))
-   time_title <- "\n\nTime-Resolved"
-   panels <- list(
-     ann_p1_pl_list[[1]] + labs(tag="a", title=sprintf("ResNet-50\n\n%s ms", p1_range)),
-     ann_poi_pl_list[[1]] + labs(title=poi_title),
-     ann_time_res_pl_list[[1]] + labs(title=time_title),
-     ann_p1_pl_list[[2]] + labs(tag="b", title=sprintf("CORnet-Z\n\n%s ms", p1_range)),
-     ann_poi_pl_list[[2]] + labs(title=poi_title),
-     ann_time_res_pl_list[[2]] + labs(title=time_title)
-   )
-   wrap_plots(
-     lapply(panels, function(panel) panel + title_theme),
-     nrow=2, ncol=3, widths=c(1, 1, 2)
-   )
- })
- # write the figure as PNG and as PDF
- ggsave(
-   filename=file.path("fig", "ANN_joined.png"), plot=pl$ann_joined,
-   width=5.5, height=6, device="png", type="cairo"
- )
- ggsave(
-   filename=file.path("fig", "ANN_joined.pdf"), plot=pl$ann_joined,
-   width=5.5, height=6, device=cairo_pdf
- )
- # controls analysis -------------------------------------------------------
- # Display labels of the control models, in the order used for factor levels.
- # Shared by every controls data set so the four results below cannot drift apart.
- controls_model_levels <- c(
-   "Jaccard Distance",
-   "Wasserstein Distance",
-   "Visual Size Distance",
-   "Letter Frequency\nDistance",
-   "Dominant Phoneme\nPhonological Distance",
-   "Letter Name\nPhonological Distance"
- )
- # Read correlation samples from an .rds file and tidy them for the interval plot.
- #   rds_path: path to an .rds file holding a data frame with `model` and `is_partial` columns.
- # Returns the data ungrouped and without dot-prefixed columns, with `model` recoded to its
- # display label (a factor with levels controls_model_levels) and a `partialness` factor added.
- # A `model` value not listed in the recoding becomes NA.
- read_controls_poi <- function(rds_path) {
-   readRDS(rds_path) |>
-     ungroup() |>
-     select(-starts_with(".")) |>
-     mutate(
-       model = case_match(
-         model,
-         "Jaccard Distance" ~ "Jaccard Distance",
-         "Wasserstein Distance" ~ "Wasserstein Distance",
-         "Complexity Distance" ~ "Visual Size Distance",
-         "Frequency Distance" ~ "Letter Frequency\nDistance",
-         "Phonological Distance" ~ "Dominant Phoneme\nPhonological Distance",
-         "Letter Name Phonological Distance" ~ "Letter Name\nPhonological Distance"
-       )
-     ) |>
-     mutate(
-       partialness = factor(
-         ifelse(is_partial, "Partial\nCorrelations", "Correlations"),
-         levels=c("Correlations", "Partial\nCorrelations")
-       ),
-       model = factor(model, levels=controls_model_levels)
-     )
- }
- # Read time-resolved correlation samples from an .rds file and summarise them per time point.
- #   rds_path: path to an .rds file holding a data frame with a `time` column and
- #             one column per correlation parameter (names starting with "cor_" or "pcor").
- # Returns one row per time x model x partialness with the median and 89% HDI of Rho.
- read_controls_timecourse <- function(rds_path) {
-   readRDS(rds_path) |>
-     ungroup() |>
-     select(-starts_with(".")) |>
-     pivot_longer(
-       cols = c(starts_with("cor_"), starts_with("pcor")),
-       names_to="cor_par", values_to="Rho"
-     ) |>
-     mutate(
-       # parameters named "pcor_..." are the partial correlations
-       partialness = factor(
-         ifelse(grepl("^pcor_", cor_par), "Partial Correlations", "Correlations"),
-         levels = c("Correlations", "Partial Correlations")
-       ),
-       # the suffix of the parameter name identifies the model
-       model = factor(case_when(
-         grepl("rankjacc$", cor_par) ~ "Jaccard Distance",
-         grepl("rankot$", cor_par) ~ "Wasserstein Distance",
-         grepl("rankcompdist$", cor_par) ~ "Visual Size Distance",
-         grepl("rankfreqdist$", cor_par) ~ "Letter Frequency\nDistance",
-         grepl("rankphondist$", cor_par) ~ "Dominant Phoneme\nPhonological Distance",
-         grepl("ranknamephondist$", cor_par) ~ "Letter Name\nPhonological Distance"
-       ), levels=controls_model_levels)
-     ) |>
-     select(-cor_par) |>
-     group_by(time, model, partialness) |>
-     median_hdi(Rho, .width=.89)
- }
- # the first pair of results is combined as chs_grp "post" further down, the second pair as "all"
- controls_poi_res <- read_controls_poi(file.path("estimates", "controls_cor_samps_long.rds"))
- controls_time_res <- read_controls_timecourse(file.path("estimates", "controls_timecourse.rds"))
- controls_all_chs_poi_res <- read_controls_poi(file.path("estimates", "controls_all_chs_cor_samps_long.rds"))
- controls_all_chs_time_res <- read_controls_timecourse(file.path("estimates", "controls_all_chs_timecourse.rds"))
- # Colour-blind-friendly palette: the existing measure colours extended with Okabe-Ito hues,
- # one per additional control model.
- controls_colours <- c(
-   measure_cols,
-   setNames(
-     c("#009E73", "#F0E442", "#CC79A7", "#56B4E9"),
-     c(
-       "Visual Size Distance",
-       "Letter Frequency\nDistance",
-       "Dominant Phoneme\nPhonological Distance",
-       "Letter Name\nPhonological Distance"
-     )
-   )
- )
- # y limits of the controls time course: fixed lower bound, upper bound at the highest noise-ceiling value
- rho_limits_controls <- c(
-   -0.15,
-   max(
-     noise_ceiling_time$upr,
-     noise_ceiling_poi$upr,
-     noise_ceiling_time_all_chs$upr,
-     noise_ceiling_poi_all_chs$upr
-   )
- )
- # one row per distinct label combination with Rho fixed at 0; used for an invisible layer
- # that only exists to produce the linetype legend
- dummy_controls_df <- mutate(
-   distinct(select(controls_poi_res, cor_lab, model, is_partial, partialness)),
-   Rho = 0
- )
- # Period-of-interest noise ceilings for both channel groups, labelled with the facet strip text.
- nc_ch_grps <- mutate(
-   bind_rows(
-     mutate(noise_ceiling_poi, chs_grp="post"),
-     mutate(noise_ceiling_poi_all_chs, chs_grp="all")
-   ),
-   chs_grp = factor(
-     chs_grp,
-     levels=c("post", "all"),
-     labels=c("150-225 ms\n\nPosterior Channels", "\n\nAll Channels")
-   )
- )
- # Time-resolved noise ceilings for both channel groups, repeated once per partialness level
- # so the same band can be drawn in every facet row.
- nc_time_ch_grps <- local({
-   ceiling_by_grp <- mutate(
-     bind_rows(
-       mutate(noise_ceiling_time, chs_grp="post"),
-       mutate(noise_ceiling_time_all_chs, chs_grp="all")
-     ),
-     chs_grp = factor(
-       chs_grp,
-       levels=c("post", "all"),
-       labels=c("Time-Resolved\n\nPosterior Channels", "\n\nAll Channels")
-     )
-   )
-   # stack one copy per partialness level: "Correlations" rows first, then "Partial Correlations"
-   stacked <- bind_rows(
-     mutate(ceiling_by_grp, partialness="Correlations"),
-     mutate(ceiling_by_grp, partialness="Partial Correlations")
-   )
-   mutate(stacked, partialness = factor(partialness, levels=c("Correlations", "Partial Correlations")))
- })
- # Period-of-interest controls panel: per model, the density outline and median/HDI interval of Rho
- # for correlations and partial correlations, with one facet column per channel group.
- pl$controls_poi <- local({
-   poi_samples <- bind_rows(
-     mutate(controls_poi_res, chs_grp="post"),
-     mutate(controls_all_chs_poi_res, chs_grp="all")
-   )
-   poi_samples <- mutate(
-     poi_samples,
-     chs_grp = factor(
-       chs_grp,
-       levels=c("post", "all"),
-       labels=c("150-225 ms\n\nPosterior Channels", "\n\nAll Channels")
-     ),
-     # numeric y position 0.15 below each model's slot on the reversed model axis
-     interval_yloc = as.numeric(forcats::fct_rev(model))-0.15
-   )
-   ggplot(poi_samples) +
-     # grey band spanning the noise ceiling of each channel group
-     geom_rect(
-       aes(xmin=lwr, xmax=upr, ymin=-Inf, ymax=Inf),
-       colour=NA, fill="lightgrey", data=nc_ch_grps
-     ) +
-     geom_vline(xintercept=0, linewidth=axis_linewidth) +
-     # unfilled density outlines; linetype separates correlations from partial correlations
-     stat_slab(
-       aes(Rho, model, colour=model, group=partialness, linetype=partialness),
-       fill=NA, height=0.7, show.legend=FALSE, linewidth=axis_linewidth*1.25
-     ) +
-     # fully transparent layer whose only job is to supply the linetype legend keys
-     geom_vline(
-       aes(xintercept=0, linetype=partialness),
-       key_glyph="path", data=dummy_controls_df, alpha=0
-     ) +
-     # NOTE(review): y=interval_yloc is matched by name, so the positional `model` is not the y
-     # aesthetic of this layer - confirm the extra positional argument is intentional
-     stat_pointinterval(
-       aes(Rho, model, colour=model, group=partialness, y=interval_yloc),
-       point_interval="median_hdi", .width=c(.5, .89), shape=NA,
-       position=position_dodgev(height=0.25)
-     ) +
-     scale_colour_manual(values=controls_colours, guide="none") +
-     # the legend keys come from the transparent layer, so alpha is forced back to 1 in the legend
-     scale_linetype(guide=guide_legend(override.aes = list(alpha=1))) +
-     scale_x_continuous(expand=c(0,0)) +
-     scale_y_discrete(limits=rev, expand=c(0,0)) +
-     facet_grid(cols=vars(chs_grp)) +
-     labs(x = "ρ", y = "Model", linetype = NULL) +
-     theme(
-       # legend anchored in the top-right corner
-       legend.position = "inside",
-       legend.position.inside = c(1, 1),
-       legend.justification = c(1, 1),
-       legend.margin = margin(2, 2, 2, 2, "pt"),
-       legend.key = element_blank(),
-       strip.text = element_text(hjust=0)
-     )
- })
- # Time-resolved controls panel: one Rho line per model over time,
- # with facet columns per channel group and facet rows per partialness.
- pl$controls_timecourse <- local({
-   time_summaries <- bind_rows(
-     mutate(controls_time_res, chs_grp="post"),
-     mutate(controls_all_chs_time_res, chs_grp="all")
-   )
-   time_summaries <- mutate(
-     time_summaries,
-     chs_grp = factor(
-       chs_grp,
-       levels=c("post", "all"),
-       labels=c("Time-Resolved\n\nPosterior Channels", "\n\nAll Channels")
-     )
-   )
-   time_summaries <- mutate(time_summaries, time_ms = time*1000)
-   # Rho and model are filled with placeholders because the ribbon layer inherits the plot-level aesthetics
-   ceiling_df <- mutate(nc_time_ch_grps, time_ms=time*1000, Rho=1, model=NA)
-   ggplot(time_summaries, aes(time_ms, Rho, colour=model)) +
-     geom_ribbon(
-       aes(x=time_ms, ymin=lwr, ymax=upr),
-       colour=NA, data=ceiling_df, fill="lightgrey"
-     ) +
-     # zero reference lines stand in for the axis lines, which the theme removes
-     geom_hline(yintercept=0, linewidth=axis_linewidth) +
-     geom_vline(xintercept=0, linewidth=axis_linewidth) +
-     geom_line() +
-     scale_colour_manual(values=controls_colours, guide="none") +
-     scale_x_continuous(
-       expand=c(0,0),
-       limits=c(-200, 1000),
-       breaks=seq(-200, 1000, 200)
-     ) +
-     scale_y_continuous(limits=rho_limits_controls, expand=c(0,0)) +
-     facet_grid(cols=vars(chs_grp), rows=vars(partialness)) +
-     labs(x = "Time (ms)", y = "ρ") +
-     theme(
-       axis.line.x = element_blank(),
-       axis.line.y = element_blank(),
-       strip.text.x.top = element_text(hjust=0),
-       panel.spacing.x = unit(20, "pt")
-     )
- })
- # Stack the two controls panels (tagged "a" and "b") in two equally tall rows.
- pl$controls <- plot_grid(
-   plotlist=list(
-     pl$controls_poi + labs(tag="a"),
-     pl$controls_timecourse + labs(tag="b")
-   ),
-   nrow=2,
-   rel_heights=c(1, 1)
- )
- # write the figure as PNG and as PDF
- ggsave(
-   filename=file.path("fig", "controls.png"), plot=pl$controls,
-   width=6, height=6.5, device="png", type="cairo"
- )
- ggsave(
-   filename=file.path("fig", "controls.pdf"), plot=pl$controls,
-   width=6, height=6.5, device=cairo_pdf
- )
- # sensitivity analysis ----------------------------------------------------
- # Results of the LKJ-prior sensitivity analysis, with a two-level partialness factor derived from is_partial.
- sens_res <- mutate(
-   readRDS(file.path("estimates", "sensitivity_lkj_prior.rds")),
-   partialness = factor(
-     ifelse(is_partial, "Partial Correlations", "Correlations"),
-     levels = c("Correlations", "Partial Correlations")
-   )
- )
- # x-axis breaks: the powers of ten from 10^-3 to 10^3
- sens_xbreaks <- 10^seq(-3, 3, 1)
- # Panel "b" of the sensitivity figure: Rho (line) with its .lower/.upper band as a function of
- # the LKJ eta prior on a log10 x axis, one facet column per partialness level.
- pl$sens_post <- ggplot(sens_res, aes(eta, Rho, colour=model, fill=model)) +
-   geom_ribbon(aes(ymin=.lower, ymax=.upper), colour=NA, alpha=0.4) +
-   geom_line() +
-   # dashed marker at eta = 1.5
-   geom_vline(xintercept = 1.5, linetype="dashed") +
-   facet_grid(cols=vars(partialness)) +
-   scale_colour_manual(values=measure_cols) +
-   scale_fill_manual(values=measure_cols) +
-   scale_x_continuous(
-     trans="log10",
-     breaks=sens_xbreaks,
-     limits=range(sens_xbreaks),
-     expand=c(0,0)
-   ) +
-   scale_y_continuous(
-     breaks = seq(-0.06, 0.06, 0.02),
-     limits = c(-0.06, NA)
-   ) +
-   labs(
-     x = "LKJ η Prior",
-     y = "ρ",
-     colour = NULL,
-     fill = NULL,
-     tag = "b"
-   ) +
-   theme(
-     legend.position = "bottom",
-     legend.position.inside = c(0.3, 0.1),
-     legend.key.height = unit(10, units="pt"),
-     legend.margin = margin(0,0,0,0),
-     panel.spacing.x = unit(25, "pt"),
-     plot.margin = margin(0,12.5,0,0, unit="pt")
-   )
- # Panel "a" of the sensitivity figure: the lkjcorr prior, marginalised with K = 3,
- # for eta at each power of ten from 10^-2 to 10^2, one facet column per eta.
- pl$sens_priors <- tibble(eta = 10^seq(-2, 2, 1)) |>
-   mutate(
-     # distribution specification that parse_dist() turns into a distribution object
-     prior_string = sprintf("lkjcorr(%g)", eta),
-     prior_label = sprintf("η=%s", scales::scientific(eta, digits=2))
-   ) |>
-   arrange(eta) |>
-   # keep the facets in ascending-eta order rather than alphabetical label order
-   mutate(prior_label = forcats::fct_inorder(prior_label)) |>
-   parse_dist(prior_string) |>
-   marginalize_lkjcorr(K = 3) |>
-   ggplot(aes(xdist = .dist_obj)) +
-   stat_slabinterval(
-     point_interval="median_hdi",
-     .width=c(0.5, 0.89),
-     justification=-0.05,
-     shape="|",
-     slab_colour="black",
-     slab_linewidth=0.5
-   ) +
-   scale_x_continuous(limits=c(-1, 1)) +
-   facet_grid(cols=vars(prior_label)) +
-   labs(x = "ρ", y = "Density", tag = "a") +
-   theme(
-     axis.ticks.y = element_blank(),
-     axis.text.y = element_blank(),
-     panel.spacing.x = unit(12, "pt")
-   )
- # Stack the prior panel above the posterior panel (patchwork `/`), then write the figure as PDF and as PNG.
- pl$sens <- pl$sens_priors / pl$sens_post
- ggsave(
-   filename=file.path("fig", "sensitivity_analysis.pdf"), plot=pl$sens,
-   width=6.4, height=4, device=cairo_pdf
- )
- ggsave(
-   filename=file.path("fig", "sensitivity_analysis.png"), plot=pl$sens,
-   width=6.4, height=4, device="png", type="cairo"
- )
- # tibble(prior_string = "lkjcorr(1.5)") |>
- # parse_dist(prior_string) |>
- # marginalize_lkjcorr(K = 3) |>
- # ggplot(aes(xdist = .dist_obj)) +
- # stat_slabinterval(point_interval="median_hdi", .width=c(0.5, 0.89), justification=-0.05, shape="|", slab_colour="black", slab_linewidth=0.5) +
- # scale_x_continuous(limits=c(-1, 1)) +
- # labs(
- # x = "ρ",
- # y = "Density"
- # ) +
- # theme(
- # axis.ticks.y = element_blank(),
- # axis.text.y = element_blank(),
- # panel.spacing.x = unit(12, "pt")
- # )
|