MLM vs static regularisation

Published

July 29, 2024

Modified

July 16, 2024

A cursory examination of the differences between static regularisation via L1 and L2 penalties and a multilevel approach, which achieves dynamic shrinkage. There is definitely no free lunch. The model we consider simply examines discrete group means with differential allocation across the groups.
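Concretely, the data-generating model used throughout is

\[
y_i \sim \mathsf{Bernoulli}\!\left(\mathrm{logit}^{-1}(\eta_{j[i]})\right), \qquad
\eta_j = \mu + \mu_j, \qquad
\mu_j \sim \mathsf{N}(0, \sigma_y^2)
\]

where \(j[i]\) denotes the group of unit \(i\), \(\mu\) is the grand mean on the log-odds scale, and \(\sigma_y\) (s_y in the code) controls the between-group variation in the group-specific means.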

Parameter specification
get_par <- function(
    n_grp = 9,
    mu = 1,
    s_y = 0.3
    ){
  
  l <- list()
  l$n_grp <- n_grp

  # overall mean effect across all intervention types
  l$mu <- mu
  # between-group standard deviation of the intervention-specific means
  l$s_y <- s_y
  # intervention-type-specific mean offset from the overall mean
  l$mu_j <- rnorm(n_grp, 0, l$s_y)

  l$d_par <- CJ(
    j = factor(1:l$n_grp, levels = 1:l$n_grp)
  )
  l$d_par[, mu := l$mu]
  l$d_par[, mu_j := l$mu_j[j]]
  
  l
}
Data generation function
get_data <- function(
    N = 2000, 
    par = NULL,
    ff = function(par, j){
      eta = par$mu + par$mu_j[j] 
      eta
    }){
  
  d <- data.table()
  
  # intervention - uneven allocation across groups via softmax of
  # random weights (differential allocation)
  z <- rnorm(par$n_grp, 0, 0.5)
  d[, j := sample(1:par$n_grp, size = N, replace = T, prob = exp(z)/sum(exp(z)))]
  d[, eta := ff(par, j)]
  d[, y := rbinom(.N, 1, plogis(eta))]
  
  d  
}
Code
set.seed(1)
par <- get_par(n_grp = 3, mu = 1, s_y = 0.0)
# 100 people per group - will this overcome the prior?
n_per_grp <- 100
d <- get_data(N = n_per_grp * par$n_grp, par)
Code
m1 <- cmdstanr::cmdstan_model("stan/mlm-ex-04.stan")
m2 <- cmdstanr::cmdstan_model("stan/mlm-ex-05.stan")
m3 <- cmdstanr::cmdstan_model("stan/mlm-ex-06.stan")

ld <- list(
  N = nrow(d), 
  y = d$y,  J = length(unique(d$j)),  j = d$j, # intervention
  pri_s_norm = 1, 
  pri_s_exp = 1, 
  pri_r = 1/2, 
  # when alpha = 1, we have pure lasso regression (double exp)
  # when alpha = 0, we have pure ridge regression (normal with 1/lambda scale).
  pri_alpha = 0.6,
  pri_lambda = 1, # scale is 1/pri_lambda
  prior_only = 1
)

f1 <- m1$sample(
  ld, iter_warmup = 1000, iter_sampling = 5000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.

Both chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.2 seconds.
Code
f2 <- m2$sample(
  ld, iter_warmup = 1000, iter_sampling = 5000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.

Both chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.1 seconds.
Code
f3 <- m3$sample(
  ld, iter_warmup = 1000, iter_sampling = 5000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.

Both chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.1 seconds.

Next, consider the priors on the group-level offsets. The elastic-net prior is a pre-specified combination of the ridge (normal) and lasso (double-exponential) penalties.
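The Stan files are not reproduced here, but as a minimal sketch of the idea, assuming the standard elastic-net penalty form (the Stan models may combine the two penalties differently), the unnormalised log prior density interpolating between lasso (\(\alpha = 1\)) and ridge (\(\alpha = 0\)) can be written in R as:

# sketch only (assumed penalty form, not the actual Stan code): unnormalised
# elastic-net log density; alpha = 1 recovers lasso, alpha = 0 recovers ridge
lp_enet <- function(x, alpha, lambda) {
  -lambda * (alpha * abs(x) + (1 - alpha) * x^2 / 2)
}
# e.g. visualise the implied (unnormalised) prior for the settings above
curve(exp(lp_enet(x, alpha = 0.6, lambda = 1)), -3, 3)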

Implied priors on the group level deviations from the grand mean
d_1 <- data.table(f1$draws(variables = "z_eta", format = "matrix"))
d_1 <- melt(d_1, measure.vars = names(d_1))
d_1[, desc := "unpooled"]

d_2 <- data.table(f2$draws(variables = "z_eta", format = "matrix"))
d_2 <- melt(d_2, measure.vars = names(d_2))
d_2[, desc := "elastic-net"]

d_3 <- data.table(f3$draws(variables = "z_eta", format = "matrix"))
d_3 <- melt(d_3, measure.vars = names(d_3))
d_3[, desc := "mlm"]

d_fig <- rbind(d_1, d_2, d_3)
d_fig[, desc := factor(desc, levels = c("unpooled", "elastic-net", "mlm"))]

x <- seq(min(d_fig$value), max(d_fig$value), len = 1000)
# for ridge
den_2a <- dnorm(x, 0, 1/ld$pri_lambda)
# for lasso
den_2b <- extraDistr::dlaplace(x, 0, 1/ld$pri_lambda)

d_sta <- rbind(
  data.table(desc = "elastic-net", mod = "ridge", x = x, y = den_2a),
  data.table(desc = "elastic-net", mod = "lasso", x = x, y = den_2b)
)
d_sta[, desc := factor(desc, levels = c("unpooled", "elastic-net", "mlm"))]

ggplot(d_fig, aes(x = value, group = variable)) +
  geom_line(
    data = d_sta[mod == "lasso"],
    aes(x = x, y = y, group = desc), inherit.aes = F,
    col = 2, lty = 2, lwd = 0.7) +
  geom_line(
    data = d_sta[mod == "ridge"],
    aes(x = x, y = y, group = desc), inherit.aes = F,
    col = 3, lty = 5, lwd = 0.7) +
  geom_density(lwd = 0.2) +
  facet_wrap(~desc, ncol = 1)
Figure 1: Implied priors on the group level deviations from the grand mean
Code
set.seed(1)
par <- get_par(n_grp = 10, mu = 1, s_y = 0.5)
# 100 people per group - will this overcome the prior?
n_per_grp <- 100
d <- get_data(N = n_per_grp * par$n_grp, par)
Code
ld <- list(
  N = nrow(d), 
  y = d$y,  J = length(unique(d$j)),  j = d$j, # intervention
  pri_s_norm = 2, 
  pri_s_exp = 1, 
  pri_r = 1/2,  
  # when alpha = 1, we have pure lasso regression (double exp)
  # when alpha = 0, we have pure ridge regression (normal with 1/lambda scale).
  pri_alpha = 0.9,
  pri_lambda = 0.76, # scale is 1/pri_lambda for both ridge and lasso
  prior_only = 0
)

f1 <- m1$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 0.7 seconds.
Chain 2 finished in 0.7 seconds.

Both chains finished successfully.
Mean chain execution time: 0.7 seconds.
Total execution time: 0.8 seconds.
Code
f2 <- m2$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 0.5 seconds.
Chain 2 finished in 0.5 seconds.

Both chains finished successfully.
Mean chain execution time: 0.5 seconds.
Total execution time: 0.6 seconds.
Code
f3 <- m3$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 0.4 seconds.
Chain 2 finished in 0.5 seconds.

Both chains finished successfully.
Mean chain execution time: 0.4 seconds.
Total execution time: 0.5 seconds.

Now consider the treatment group posteriors. By adjusting \(\alpha\) and \(\lambda\) in the elastic-net model we can determine, a priori, the magnitude of regularisation we want to achieve. In practice, these quantities should be pre-specified, with a sensitivity analysis run on the main analysis by varying them, as sketched below.
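A minimal sketch of such a sensitivity analysis, assuming we simply re-fit the elastic-net model over a small (hypothetical) grid of \(\alpha\) and \(\lambda\) values:

# hypothetical sensitivity analysis: re-fit the elastic-net model across a
# grid of (alpha, lambda) and collect posterior summaries of the group effects
d_grid <- CJ(alpha = c(0.1, 0.5, 0.9), lambda = c(0.5, 1, 2))
l_res <- lapply(seq_len(nrow(d_grid)), function(i){
  ld_i <- modifyList(ld, list(pri_alpha = d_grid$alpha[i],
                              pri_lambda = d_grid$lambda[i]))
  f_i <- m2$sample(ld_i, iter_warmup = 1000, iter_sampling = 1000,
                   parallel_chains = 2, chains = 2, refresh = 0)
  s_i <- as.data.table(f_i$summary(variables = "eta"))
  s_i[, `:=`(alpha = d_grid$alpha[i], lambda = d_grid$lambda[i])]
  s_i
})
d_sens <- rbindlist(l_res)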

Posterior inference on the group level means
d_0 <- d[, .(desc = "observed", eta = qlogis(mean(y))), keyby = j]

d_1 <- data.table(f1$draws(variables = "eta", format = "matrix"))
d_1 <- melt(d_1, measure.vars = names(d_1), value.name = "eta")
d_1[, desc := "unpooled"]

d_2 <- data.table(f2$draws(variables = "eta", format = "matrix"))
d_2 <- melt(d_2, measure.vars = names(d_2), value.name = "eta")
d_2[, desc := "elastic-net"]

d_3 <- data.table(f3$draws(variables = "eta", format = "matrix"))
d_3 <- melt(d_3, measure.vars = names(d_3), value.name = "eta")
d_3[, desc := "mlm"]

d_fig <- rbind(d_1, d_2, d_3)
d_fig[, j := gsub("eta[", "", variable, fixed = T)]
d_fig[, j := gsub("]", "", j, fixed = T)]
d_fig[, variable := NULL]
d_fig[, desc := factor(
  desc, levels = c("observed", "unpooled", "elastic-net", "mlm"))]
d_fig <- rbind(d_fig, d_0)

d_fig <- d_fig[, .(
  mu = mean(eta), 
  q5 = quantile(eta, prob = 0.05),
  q95 = quantile(eta, prob = 0.95)), keyby = .(desc, j)]

d_fig[, j := factor(j, levels = paste0(1:par$n_grp))]

ggplot(d_fig, aes(x = j, y = mu, group = desc, col = desc)) +
  geom_point(position = position_dodge(width = 0.6)) +
  geom_linerange(
    aes(ymin = q5, ymax = q95),
    position = position_dodge2(width = 0.6)) +
  scale_color_discrete("") +
  geom_text(data = d[, .N, keyby = .(j)],
            aes(x = j, y = min(d_fig$q5, na.rm = T) - 0.2, 
                label = N), inherit.aes = F) 
Figure 2: Posterior inference on the group level means
Code
set.seed(2)
par <- get_par(n_grp = 10, mu = 1, s_y = 0.5)
n_per_grp <- 500
d <- get_data(N = n_per_grp * par$n_grp, par)
Code
ld <- list(
  N = nrow(d), 
  y = d$y,  J = length(unique(d$j)),  j = d$j, # intervention
  pri_s_norm = 2, 
  pri_s_exp = 1, 
  pri_r = 1/2,  
  # when alpha = 1, we have pure lasso regression (double exp)
  # when alpha = 0, we have pure ridge regression (normal with 1/lambda scale).
  pri_alpha = 0.9,
  pri_lambda = 0.76, # scale is 1/pri_lambda for both ridge and lasso
  prior_only = 0
)

f1 <- m1$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 3.7 seconds.
Chain 2 finished in 4.0 seconds.

Both chains finished successfully.
Mean chain execution time: 3.9 seconds.
Total execution time: 4.1 seconds.
Code
f2 <- m2$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 2 finished in 3.0 seconds.
Chain 1 finished in 3.1 seconds.

Both chains finished successfully.
Mean chain execution time: 3.1 seconds.
Total execution time: 3.2 seconds.
Code
f3 <- m3$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 3.4 seconds.
Chain 2 finished in 3.6 seconds.

Both chains finished successfully.
Mean chain execution time: 3.5 seconds.
Total execution time: 3.7 seconds.

However, note that larger data sets will still overwhelm the static regularisation, as shown in Figure 3: the penalty stays fixed while the likelihood contribution grows with the sample size.
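A rough numeric check of this (a sketch, not part of the figure code): compute the mean absolute gap between the posterior mean group effects and the observed log-odds; smaller gaps imply weaker effective regularisation.

# sketch: mean absolute gap between posterior mean group effects and
# observed log-odds, per model
d_obs <- d[, .(obs = qlogis(mean(y))), keyby = j]
gap <- function(f) {
  post <- as.data.table(f$summary(variables = "eta"))
  mean(abs(post$mean - d_obs$obs))
}
sapply(list(unpooled = f1, `elastic-net` = f2, mlm = f3), gap)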

Posterior inference on the group level means
d_0 <- d[, .(desc = "observed", eta = qlogis(mean(y))), keyby = j]

d_1 <- data.table(f1$draws(variables = "eta", format = "matrix"))
d_1 <- melt(d_1, measure.vars = names(d_1), value.name = "eta")
d_1[, desc := "unpooled"]

d_2 <- data.table(f2$draws(variables = "eta", format = "matrix"))
d_2 <- melt(d_2, measure.vars = names(d_2), value.name = "eta")
d_2[, desc := "elastic-net"]

d_3 <- data.table(f3$draws(variables = "eta", format = "matrix"))
d_3 <- melt(d_3, measure.vars = names(d_3), value.name = "eta")
d_3[, desc := "mlm"]

d_fig <- rbind(d_1, d_2, d_3)
d_fig[, j := gsub("eta[", "", variable, fixed = T)]
d_fig[, j := gsub("]", "", j, fixed = T)]
d_fig[, variable := NULL]
d_fig[, desc := factor(
  desc, levels = c("observed", "unpooled", "elastic-net", "mlm"))]
d_fig <- rbind(d_fig, d_0)

d_fig <- d_fig[, .(
  mu = mean(eta), 
  q5 = quantile(eta, prob = 0.05),
  q95 = quantile(eta, prob = 0.95)), keyby = .(desc, j)]

d_fig[, j := factor(j, levels = paste0(1:par$n_grp))]

ggplot(d_fig, aes(x = j, y = mu, group = desc, col = desc)) +
  geom_point(position = position_dodge(width = 0.6)) +
  geom_linerange(
    aes(ymin = q5, ymax = q95),
    position = position_dodge2(width = 0.6)) +
  scale_color_discrete("") +
  geom_text(data = d[, .N, keyby = .(j)],
            aes(x = j, y = min(d_fig$q5, na.rm = T) - 0.2, 
                label = N), inherit.aes = F) 
Figure 3: Posterior inference on the group level means
Code
set.seed(7)
par <- get_par(n_grp = 2, mu = 1, s_y = 0.5)
n_per_grp <- 200
d <- get_data(N = n_per_grp * par$n_grp, par)
Code
ld <- list(
  N = nrow(d), 
  y = d$y,  J = length(unique(d$j)),  j = d$j, # intervention
  pri_s_norm = 2, 
  pri_s_exp = 1, 
  pri_r = 1/2,  
  # when alpha = 1, we have pure lasso regression (double exp)
  # when alpha = 0, we have pure ridge regression (normal with 1/lambda scale).
  pri_alpha = 0.9,
  pri_lambda = 0.76, # scale is 1/pri_lambda for both ridge and lasso
  prior_only = 0
)

f1 <- m1$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 0.3 seconds.
Chain 2 finished in 0.3 seconds.

Both chains finished successfully.
Mean chain execution time: 0.3 seconds.
Total execution time: 0.4 seconds.
Code
f2 <- m2$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 0.2 seconds.
Chain 2 finished in 0.2 seconds.

Both chains finished successfully.
Mean chain execution time: 0.2 seconds.
Total execution time: 0.3 seconds.
Code
f3 <- m3$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 0.4 seconds.
Chain 2 finished in 0.4 seconds.

Both chains finished successfully.
Mean chain execution time: 0.4 seconds.
Total execution time: 0.4 seconds.

And with a smaller number of groups, the differences between the approaches are probably immaterial, as shown in Figure 4; with so few groups there is little information with which to estimate the between-group variance, so the multilevel model offers little advantage over a fixed penalty.
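One quick way to see this (a sketch, not part of the figure code) is to compare the 90% credible interval widths for the two group effects under each model:

# sketch: 90% credible interval widths for the group effects by model
iw <- function(f) {
  m <- f$draws(variables = "eta", format = "matrix")
  apply(m, 2, function(v) diff(quantile(v, probs = c(0.05, 0.95))))
}
round(sapply(list(unpooled = f1, `elastic-net` = f2, mlm = f3), iw), 3)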

Posterior inference on the group level means
d_0 <- d[, .(desc = "observed", eta = qlogis(mean(y))), keyby = j]

d_1 <- data.table(f1$draws(variables = "eta", format = "matrix"))
d_1 <- melt(d_1, measure.vars = names(d_1), value.name = "eta")
d_1[, desc := "unpooled"]

d_2 <- data.table(f2$draws(variables = "eta", format = "matrix"))
d_2 <- melt(d_2, measure.vars = names(d_2), value.name = "eta")
d_2[, desc := "elastic-net"]

d_3 <- data.table(f3$draws(variables = "eta", format = "matrix"))
d_3 <- melt(d_3, measure.vars = names(d_3), value.name = "eta")
d_3[, desc := "mlm"]

d_fig <- rbind(d_1, d_2, d_3)
d_fig[, j := gsub("eta[", "", variable, fixed = T)]
d_fig[, j := gsub("]", "", j, fixed = T)]
d_fig[, variable := NULL]
d_fig[, desc := factor(
  desc, levels = c("observed", "unpooled", "elastic-net", "mlm"))]
d_fig <- rbind(d_fig, d_0)

d_fig <- d_fig[, .(
  mu = mean(eta), 
  q5 = quantile(eta, prob = 0.05),
  q95 = quantile(eta, prob = 0.95)), keyby = .(desc, j)]

d_fig[, j := factor(j, levels = paste0(1:par$n_grp))]

ggplot(d_fig, aes(x = j, y = mu, group = desc, col = desc)) +
  geom_point(position = position_dodge(width = 0.6)) +
  geom_linerange(
    aes(ymin = q5, ymax = q95),
    position = position_dodge2(width = 0.6)) +
  scale_color_discrete("") +
  geom_text(data = d[, .N, keyby = .(j)],
            aes(x = j, y = min(d_fig$q5, na.rm = T) - 0.2, 
                label = N), inherit.aes = F) 
Figure 4: Posterior inference on the group level means