The central idea with modeling discrete heterogeneity in multilevel models is that the variation is assumed to arise within the context of a dependency structure. Under a complete pooling model, the variation is ignored and we assume that the effect is captured within a single parameter such as modeling a treatment effect across groups. For example, consider the model
where \(p\) parameterises a Bernoulli observational model conditional on some exposure \(x\) and we are taking a monolithic perspective on the treatment effect. To account for discrete heterogeneity we would replace \(\beta\) with \(\{ \beta_1, \beta_2, \dots, \beta_K \}\). However, to complete the model we need to place a prior on the \(\beta_k\). Under the assumption of no latent interactions (as in no pooling) we would adopt \(\pi(\beta_1, \dots, \beta_K) = \pi_1(\beta_1)\dots\pi_K(\beta_K)\). Under a partial pooling perspective, we would instead assume some structure in the prior that links the \(\beta_k\).
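One common way to write such a structured prior, given here only as an illustrative sketch, is to treat the \(\beta_k\) as exchangeable given shared hyperparameters. The symbols \(\mu\) and \(\tau\) and the normal form are assumptions introduced for this illustration:

\[
\pi(\beta_1, \dots, \beta_K) = \int \left[ \prod_{k=1}^{K} \pi(\beta_k \mid \mu, \tau) \right] \pi(\mu, \tau) \, d\mu \, d\tau,
\qquad \text{for example } \beta_k \mid \mu, \tau \sim N(\mu, \tau).
\]

Because \(\mu\) and \(\tau\) are shared and integrated over, the joint prior no longer factorises across \(k\), which is what allows information to flow between groups.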
REMAP-CAP uses a multilevel perspective for modeling treatment effects within a domain, across strata. I have translated this to the ROADMAP context, with some simplifications for the purposes of illustration.
Assume the following structure for the linear predictor
where the second term represents a joint specific effect of exposure \(x\) within the first domain (I am ignoring silo for the moment and treating everything as generic).
The parameters and priors are structured as follows
Parameters and priors

| Parameter | Domain | Joint | Prior | Note |
|---|---|---|---|---|
| \(\beta_{[d1[1];j]}\) | Surgery | knee, hip | 0 | Randomised DAIR (ref for rand comparisons) |
| \(\beta_{[d1[2];j]}\) | Surgery | knee, hip | \(N(\mu_{[d1[2+3];j]}, \sigma_{[d1[2+3]]})\) | Randomised REV (one-stage selected) |
| \(\beta_{[d1[3];j]}\) | Surgery | knee, hip | \(N(\mu_{[d1[2+3];j]}, \sigma_{[d1[2+3]]})\) | Randomised REV (two-stage selected) |
| \(\mu_{[d1[2+3];j]}\) | Surgery | knee, hip | \(N(\mu_{[d1[2+3]]}, \tau_{[d1[2+3]]})\) | Mean revision effect by joint |
| \(\mu_{[d1[2+3]]}\) | Surgery | knee, hip | \(N(0, 1)\) | Overall mean of revision effect |
| \(\tau_{[d1[2+3]]}\) | Surgery | | \(\text{Student}(3,0,2)\) | SD across joint variation |
| \(\sigma_{[d1[2+3]]}\) | Surgery | | \(\text{Student}(3,0,2)\) | SD within joint variation |
The idea is to produce a structure whereby estimation of the variance components produces dynamic shrinkage of the treatment effect estimates.
Revision effects
Specifically, when the variance components are estimated to be large then more variation in the strata specific treatment effects is permitted. However, to estimate the variance components well, you will need multiple groups (more than two).
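The mechanism can be illustrated with a minimal sketch under a normal approximation. This is an assumption made purely for illustration and is not the model fitted in this document. If a stratum-specific estimate `b_hat` has standard error `se` and the strata effects are drawn from \(N(\mu, \tau)\), then the conditional posterior mean is a precision-weighted average of `b_hat` and the shared mean. The function name `shrink` and all numeric values are hypothetical.

```r
# Conditional posterior mean of a stratum effect under a normal-normal model
# (illustrative approximation; names and values are hypothetical).
shrink <- function(b_hat, se, mu, tau) {
  w <- tau^2 / (tau^2 + se^2)   # weight given to the stratum's own estimate
  w * b_hat + (1 - w) * mu
}

b_hat <- c(0.2, 1.0, 1.8)  # stratum-specific estimates (log-odds scale)
se    <- 0.5               # common standard error of each estimate
mu    <- 1.0               # shared mean

small_tau <- shrink(b_hat, se, mu, tau = 0.1)  # strong pull towards mu
large_tau <- shrink(b_hat, se, mu, tau = 2.0)  # estimates left almost unchanged
round(rbind(b_hat, small_tau, large_tau), 2)
```

In the multilevel model the value of `tau` is not fixed but estimated from the data, which is why the amount of shrinkage is described as dynamic.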
Example - large number of groups and exposure levels
For the sake of an example, assume you have a setting with 8 treatment arms to which patients are randomised and that each treatment arm contains a mix of participants with respect to some characteristic that could influence the response. I am only using this many treatment arms because it allows us to see a clear difference between the different analysis approaches. We are interested in the overall treatment effect and in heterogeneity due to membership of some group.
Within this setting there are multiple levels of variation. There is the between treatment arm variation that characterises how different the treatment arms are. There is also the within treatment arm variation due to the subgroup.
We could adopt a range of assumptions to model the responses for each combination of treatment and subgroup.
Model the responses for each treatment arm by subgroup combination independently; no information is shared between any of the combinations. This approach will recover the observed point estimates and could be achieved using a logistic regression to estimate each combination’s mean response.
Model the treatment groups independently, but estimate the variation within each treatment arm due to subgroup membership by sharing a common variance parameter across all the treatment arms. This will reflect the observed mean response in each treatment arm but shrink the subgroup variation towards these means. It will only do this if there is sufficient data to inform the relevant variance estimate.
Model both the between treatment arm variation and the within treatment arm variation by partitioning the total variation. This will tend to shrink the treatment arm means towards an overall mean and the subgroup estimates towards each treatment arm mean. The within group variation could be modelled for each treatment arm independently or be shared across all treatments.
Plus variations on these themes. All of the variance partition approaches require that there are sufficient groups to inform the variance parameters.
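As a minimal sketch of the first approach (fully independent cells), the following simulates a small data set and fits a logistic regression with one coefficient per treatment-by-subgroup cell. The data, the seed and the variable names are made up for illustration and only base R is used.

```r
set.seed(10)
n <- 4000
j <- factor(sample(1:3, n, replace = TRUE))   # treatment arm
k <- factor(sample(1:2, n, replace = TRUE))   # subgroup
y <- rbinom(n, 1, plogis(1))                  # same response in every cell

# One coefficient per arm-by-subgroup cell (cell-means coding, no intercept)
cell <- interaction(j, k)
fit <- glm(y ~ 0 + cell, family = binomial)

# Observed log-odds in each cell
obs <- qlogis(tapply(y, cell, mean))

# The ML estimates coincide with the observed cell log-odds
cbind(mle = unname(coef(fit)), observed = as.numeric(obs))
```

Even though every cell here shares the same true response, the cell estimates differ through sampling noise alone, which is the behaviour the partial pooling approaches aim to moderate.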
Here I assume the following represents the true model
Assume that the true treatment group means are normally distributed around some non-zero mean with standard deviation \(s\) and that the subgroup means are normally distributed around each treatment group mean with a common standard deviation \(s_j\).
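Read together with the simulation functions `get_par` and `get_data`, this description corresponds to the following data-generating process. The index notation (\(j\) for treatment arm, \(k\) for subgroup, \(i\) for participant) is an assumption inferred from the text and code, and the second argument of \(N(\cdot, \cdot)\) denotes a standard deviation:

\[
\begin{aligned}
\mu_j &\sim N(\mu, s), & j &= 1, \dots, J \\
\mu_{jk} &\sim N(\mu_j, s_j), & k &= 1, \dots, K \\
y_i &\sim \text{Bernoulli}(p_i), & \text{logit}(p_i) &= \mu_{j[i]\,k[i]}
\end{aligned}
\]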
Parameter specification
```r
get_par <- function(
    n_grp = 4, n_trt = 9,
    mu = 1,
    s = 0.1, s_j = 0.3
    ){
  l <- list()
  l$n_grp <- n_grp
  l$n_trt <- n_trt
  # overall mean effect across all intervention types
  l$mu <- mu
  # between intervention type variation
  l$s <- s
  # intervention type specific mean
  l$mu_j <- l$mu + rnorm(n_trt, 0, l$s)
  # within intervention variation attributable to group membership
  l$s_j <- s_j
  # trt x group effects
  l$mu_j_k <- do.call(rbind, lapply(seq_along(l$mu_j), function(i){
    rnorm(n_grp, l$mu_j[i], l$s_j)
  }))
  colnames(l$mu_j_k) <- paste0("strata", 1:ncol(l$mu_j_k))
  rownames(l$mu_j_k) <- paste0(1:nrow(l$mu_j_k))
  l$d_par <- CJ(
    j = factor(1:l$n_trt, levels = 1:l$n_trt),
    k = factor(1:l$n_grp, levels = 1:l$n_grp)
  )
  l$d_par[, mu := l$mu]
  l$d_par[, mu_j := l$mu_j[j]]
  l$d_par[, mu_j_k := l$mu_j_k[cbind(j,k)]]
  l
}
```
Any single data set will not allow us to recover the parameters exactly, but the differences between the estimates from the various modelling assumptions are informative as to the general patterns that arise.
Data generation function
```r
get_data <- function(
    N = 2000,
    par = NULL,
    ff = function(par, j, k){
      m1 <- cbind(j, k)
      eta = par$mu_j_k[m1]
      eta
    }){
  # strata
  d <- data.table()
  # intervention - even allocation
  d[, j := sample(1:par$n_trt, size = N, replace = T)]
  # table(d$j)
  # uneven distribution of groups in the pop
  z <- rnorm(par$n_grp, 0, 0.5)
  d[, k := sample(1:par$n_grp, size = N, replace = T, prob = exp(z)/sum(exp(z)))]
  # d[, k := sample(1:par$n_grp, size = N, replace = T)]
  # table(d$j, d$k)
  d[, eta := ff(par, j, k)]
  d[, y := rbinom(.N, 1, plogis(eta))]
  d
}
```
Generate data assuming the parameters below with the underlying truth shown in Figure 1. The dashed line shows the overall mean response, the crosses show the treatment arm means and the points show the subgroup heterogeneity around the treatment arm means. Within this first setup, all the treatment arms have the same response and none of the subgroups show any treatment effect either.
Running MCMC with 2 parallel chains...
Chain 1 finished in 2.4 seconds.
Chain 2 finished in 2.4 seconds.
Both chains finished successfully.
Mean chain execution time: 2.4 seconds.
Total execution time: 2.6 seconds.
Figure 2 shows the estimated treatment arm means. The mlm correctly identifies the absence of between treatment variation and as a result, the means are pulled towards the grand mean.
Parameter estimates vs true values
```r
d_1 <- data.table(t(f1$draws(variables = "eta", format = "matrix")))
d_1 <- cbind(d, d_1)
d_1 <- melt(d_1, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_2 <- data.table(t(f2$draws(variables = "eta", format = "matrix")))
d_2 <- cbind(d, d_2)
d_2 <- melt(d_2, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_3 <- data.table(t(f3$draws(variables = "eta", format = "matrix")))
d_3 <- cbind(d, d_3)
d_3 <- melt(d_3, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_mu_j <- rbind(
  d_1[, .(mu = mean(value)), keyby = .(j, i_draw)][
    , .(desc = "no pooling", mean = mean(mu),
        q5 = quantile(mu, prob = 0.05),
        q95 = quantile(mu, prob = 0.95)), keyby = j],
  d_2[, .(mu = mean(value)), keyby = .(j, i_draw)][
    , .(desc = "partial pool (trt)", mean = mean(mu),
        q5 = quantile(mu, prob = 0.05),
        q95 = quantile(mu, prob = 0.95)), keyby = j],
  d_3[, .(mu = mean(value)), keyby = .(j, i_draw)][
    , .(desc = "partial pool (trt+subgrp)", mean = mean(mu),
        q5 = quantile(mu, prob = 0.05),
        q95 = quantile(mu, prob = 0.95)), keyby = j]
)

# MLE
d_mle <- cbind(desc = "mle", d_lm[, .(mean = mean(eta_hat)), keyby = .(j)])
d_mle <- merge(
  d_mle,
  d_lm_j[, .(q5 = quantile(value, prob = 0.05),
             q95 = quantile(value, prob = 0.95)), keyby = j],
  by = "j")
d_mu_j <- rbind(d_mu_j, d_mle, fill = T)

# Observed
# Due to the imbalance between the subgroups, need to weight the
# contributions to align (approx) with mle.
# Easy here because only one other variable that we need to average
# over.
d_obs <- merge(
  d[, .(N_j = .N), keyby = .(j)],
  d[, .(mu = qlogis(mean(y)), .N), keyby = .(j, k)],
  by = "j")
d_mu_j <- rbind(
  d_mu_j,
  d_obs[, .(desc = "observed", mean = sum(mu * N / N_j)), keyby = j],
  fill = T
)
d_mu_j[, desc := factor(desc, levels = c(
  "observed", "mle", "no pooling",
  "partial pool (trt)", "partial pool (trt+subgrp)"))]

d_mu <- d[, .(mu = qlogis(mean(y)), w = .N/nrow(d)), keyby = .(j, k)][
  , .(desc = "observed", mean = sum(mu * w))]

p_fig <- ggplot(d_mu_j, aes(x = j, y = mean, group = desc, col = desc)) +
  geom_point(position = position_dodge(width = 0.6)) +
  geom_linerange(aes(ymin = q5, ymax = q95),
                 position = position_dodge2(width = 0.6)) +
  geom_hline(data = d_mu,
             aes(yintercept = mean, lty = desc),
             lwd = 0.3, col = 1) +
  scale_x_discrete("Treatment arm") +
  scale_y_continuous("log-odds treatment success", breaks = seq(-3, 3, by = 0.2)) +
  scale_color_discrete("") +
  scale_linetype_discrete("") +
  geom_text(data = d[, .N, keyby = .(j)],
            aes(x = j, y = min(d_mu_j$q5, na.rm = T) - 0.1, label = N),
            inherit.aes = F) +
  theme(legend.position = "bottom", legend.box = "vertical", legend.margin = margin())

suppressWarnings(print(p_fig))
```
Figure 2: Parameter estimates vs true values
Figure 3 shows the estimated treatment by subgroup means. Similar to above, the mlm has determined that the within group variance is negligible and has pulled the subgroup estimates towards the treatment means. In contrast, the ML and independent models follow the data, which suggests some material subgroup effects even though none are present in the simulated truth.
The mlm partitions the variance into a between treatment variance part that characterises the variation in the treatment arms and a within treatment variance that characterises the variation due to the subgroups.
Figure 4 shows the prior (red) and the posterior (black) for the variance components (actually the standard deviations). Given the differentiation between the prior and posterior, it is clear that something has been learnt about the variation in the data and the posterior is well identified, although not fully concentrated on the true value of zero. Moreover, both the between group variation (variation due to treatment) and the within group variation (variation due to subgroups) are small, which leads to the dynamic shrinkage shown in the subgroup level means.
Variance components (revision)
```r
d_fig <- rbind(
  cbind(desc = "partial pool (trt)",
        data.table(f2$draws(variables = c("s_j"), format = "matrix"))),
  cbind(desc = "partial pool (trt+subgrp)",
        data.table(f3$draws(variables = c("s", "s_j"), format = "matrix"))
  ),
  fill = T)
d_fig <- melt(d_fig, id.vars = "desc")
d_fig[variable == "s", label := "variation b/w"]
d_fig[variable == "s_j", label := "variation w/in"]

d_pri <- CJ(
  variable = c("s", "s_j"),
  desc = c("partial pool (trt)", "partial pool (trt+subgrp)"),
  x = seq(min(d_fig$value, na.rm = T), max(d_fig$value, na.rm = T), len = 500))
d_pri[, y := fGarch::dstd(x, nu = 3, mean = 0, sd = 2)]
d_pri[variable == "s", label := "variation b/w"]
d_pri[variable == "s_j", label := "variation w/in"]
d_pri[desc == "partial pool (trt)" & variable == "s", y := NA]

# d_smry <- d_fig[, .(mu = mean(value)), keyby = .(label)]

p_fig <- ggplot(d_fig, aes(x = value, group = variable)) +
  geom_density() +
  geom_line(data = d_pri, aes(x = x, y = y), col = 2, lwd = 0.2) +
  scale_x_continuous("Standard deviation") +
  scale_y_continuous("Density") +
  facet_wrap(desc + label ~ ., ncol = 2)

suppressWarnings(print(p_fig))
```
Figure 4: Between and within SD
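When overlaying a prior on posteriors for standard deviations it helps to be explicit about the parameterisation. The following is a minimal base R sketch, assuming that \(\text{Student}(3,0,2)\) denotes a location-scale t distribution with 3 degrees of freedom, location 0 and scale 2, truncated to positive values because the parameters are standard deviations. Both the truncation and the helper name `dhalf_t` are assumptions made here, since the model code is not shown in this section.

```r
# Density of a half-Student-t with df degrees of freedom and a given scale
# (hypothetical helper; assumes a location-scale t truncated at zero).
dhalf_t <- function(x, df = 3, scale = 2) {
  ifelse(x >= 0, 2 * dt(x / scale, df = df) / scale, 0)
}

x <- seq(0, 3, length.out = 7)
round(dhalf_t(x), 3)

# Sanity check: the density integrates to one over the positive half-line
total <- integrate(dhalf_t, lower = 0, upper = Inf)$value
total
```

A density function parameterised by a standard deviation rather than a scale would give a different curve for the same nominal value of 2, so it is worth checking which parameterisation the fitted model uses before comparing prior and posterior.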
Example - small number of groups and exposure levels
Now repeat the same exercise with a small number of groups and exposure levels. Again, we simulate the data assuming no effects are present.
Code
```r
set.seed(1)
par <- get_par(n_grp = 2, n_trt = 2, mu = 1, s = 0.0, s_j = 0.0)
d <- get_data(N = 3000, par)
```
The posterior estimates for the variance components are now poorly informed by the data and therefore highly uncertain. Accordingly, the prior is having a much greater influence under this scenario.
Variance components (revision)
```r
d_fig <- rbind(
  cbind(desc = "partial pool (trt)",
        data.table(f2$draws(variables = c("s_j"), format = "matrix"))),
  cbind(desc = "partial pool (trt+subgrp)",
        data.table(f3$draws(variables = c("s", "s_j"), format = "matrix"))
  ),
  fill = T)
d_fig <- melt(d_fig, id.vars = "desc")
d_fig[variable == "s", label := "variation b/w"]
d_fig[variable == "s_j", label := "variation w/in"]

d_pri <- CJ(
  variable = c("s", "s_j"),
  desc = c("partial pool (trt)", "partial pool (trt+subgrp)"),
  x = seq(min(d_fig$value, na.rm = T), max(d_fig$value, na.rm = T), len = 500))
d_pri[, y := fGarch::dstd(x, nu = 3, mean = 0, sd = 2)]
d_pri[variable == "s", label := "variation b/w"]
d_pri[variable == "s_j", label := "variation w/in"]
d_pri[desc == "partial pool (trt)" & variable == "s", y := NA]

p_fig <- ggplot(d_fig, aes(x = value, group = variable)) +
  geom_density() +
  geom_line(data = d_pri, aes(x = x, y = y), col = 2, lwd = 0.2) +
  scale_x_continuous("Standard deviation") +
  scale_y_continuous("Density") +
  facet_wrap(desc + label ~ ., ncol = 2)

suppressWarnings(print(p_fig))
```
Figure 7: Between and within SD
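To see why two groups carry so little information about a between-group standard deviation, consider an idealised sketch in which the group means are observed without any sampling error. Even then, the sample SD of two values is far more variable than the sample SD of eight. The function name `sd_spread`, the seed and the true SD of 0.3 are illustrative assumptions.

```r
set.seed(42)

# Spread of the sample SD of n_grp group means drawn from N(0, true_sd)
sd_spread <- function(n_grp, true_sd = 0.3, n_sim = 5000) {
  est <- replicate(n_sim, sd(rnorm(n_grp, 0, true_sd)))
  quantile(est, probs = c(0.05, 0.5, 0.95))
}

q2 <- sd_spread(2)   # two groups
q8 <- sd_spread(8)   # eight groups
round(rbind(`2 groups` = q2, `8 groups` = q8), 2)
```

In a real analysis the group means are themselves estimated with error, so the information available about the variance components is weaker still.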
Based on the above, it is unclear whether the additional complexity of an mlm is warranted when only two groups are available since the results are basically analogous to those of simpler approaches.
The results are also somewhat unsatisfying when assuming non-zero effects between treatments and within subgroups as shown below.
Code
```r
set.seed(2)
par <- get_par(n_grp = 2, n_trt = 2, mu = 1, s = 0.4, s_j = 0.2)
d <- get_data(N = 3000, par)
```
Running MCMC with 2 parallel chains...
Chain 2 finished in 6.4 seconds.
Chain 1 finished in 7.4 seconds.
Both chains finished successfully.
Mean chain execution time: 6.9 seconds.
Total execution time: 7.5 seconds.
Warning: 81 of 2000 (4.0%) transitions ended with a divergence.
See https://mc-stan.org/misc/warnings for details.
For the treatment group means, the multi-level model produces estimates that are again basically equivalent to those of the simpler modelling approaches.
Parameter estimates vs true values
```r
d_1 <- data.table(t(f1$draws(variables = "eta", format = "matrix")))
d_1 <- cbind(d, d_1)
d_1 <- melt(d_1, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_2 <- data.table(t(f2$draws(variables = "eta", format = "matrix")))
d_2 <- cbind(d, d_2)
d_2 <- melt(d_2, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_3 <- data.table(t(f3$draws(variables = "eta", format = "matrix")))
d_3 <- cbind(d, d_3)
d_3 <- melt(d_3, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_mu_j <- rbind(
  d_1[, .(mu = mean(value)), keyby = .(j, i_draw)][
    , .(desc = "no pooling", mean = mean(mu),
        q5 = quantile(mu, prob = 0.05),
        q95 = quantile(mu, prob = 0.95)), keyby = j],
  d_2[, .(mu = mean(value)), keyby = .(j, i_draw)][
    , .(desc = "partial pool (trt)", mean = mean(mu),
        q5 = quantile(mu, prob = 0.05),
        q95 = quantile(mu, prob = 0.95)), keyby = j],
  d_3[, .(mu = mean(value)), keyby = .(j, i_draw)][
    , .(desc = "partial pool (trt+subgrp)", mean = mean(mu),
        q5 = quantile(mu, prob = 0.05),
        q95 = quantile(mu, prob = 0.95)), keyby = j]
)

# MLE
d_mle <- cbind(desc = "mle", d_lm[, .(mean = mean(eta_hat)), keyby = .(j)])
d_mle <- merge(
  d_mle,
  d_lm_j[, .(q5 = quantile(value, prob = 0.05),
             q95 = quantile(value, prob = 0.95)), keyby = j],
  by = "j")
d_mu_j <- rbind(d_mu_j, d_mle, fill = T)

# Observed
# Due to the imbalance between the subgroups, need to weight the
# contributions to align (approx) with mle.
# Easy here because only one other variable that we need to average
# over.
d_obs <- merge(
  d[, .(N_j = .N), keyby = .(j)],
  d[, .(mu = qlogis(mean(y)), .N), keyby = .(j, k)],
  by = "j")
d_mu_j <- rbind(
  d_mu_j,
  d_obs[, .(desc = "observed", mean = sum(mu * N / N_j)), keyby = j],
  fill = T
)
d_mu_j[, desc := factor(desc, levels = c(
  "observed", "mle", "no pooling",
  "partial pool (trt)", "partial pool (trt+subgrp)"))]

d_mu <- d[, .(mu = qlogis(mean(y)), w = .N/nrow(d)), keyby = .(j, k)][
  , .(desc = "observed", mean = sum(mu * w))]

p_fig <- ggplot(d_mu_j, aes(x = j, y = mean, group = desc, col = desc)) +
  geom_point(position = position_dodge(width = 0.6)) +
  geom_linerange(aes(ymin = q5, ymax = q95),
                 position = position_dodge2(width = 0.6)) +
  geom_hline(data = d_mu,
             aes(yintercept = mean, lty = desc),
             lwd = 0.3, col = 1) +
  scale_x_discrete("Treatment arm") +
  scale_y_continuous("log-odds treatment success", breaks = seq(-3, 3, by = 0.2)) +
  scale_color_discrete("") +
  scale_linetype_discrete("") +
  geom_text(data = d[, .N, keyby = .(j)],
            aes(x = j, y = min(d_mu_j$q5, na.rm = T) - 0.1, label = N),
            inherit.aes = F) +
  theme(legend.position = "bottom", legend.box = "vertical", legend.margin = margin())

suppressWarnings(print(p_fig))
```
Figure 8: Parameter estimates vs true values
For the subgroups, the estimates were shrunk towards the treatment group means for the data set simulated here.
However, while the variance is indicated as being possibly small, it is poorly informed by the data and therefore our uncertainty regarding these parameters is high.
Variance components (revision)
```r
d_fig <- rbind(
  cbind(desc = "partial pool (trt)",
        data.table(f2$draws(variables = c("s_j"), format = "matrix"))),
  cbind(desc = "partial pool (trt+subgrp)",
        data.table(f3$draws(variables = c("s", "s_j"), format = "matrix"))
  ),
  fill = T)
d_fig <- melt(d_fig, id.vars = "desc")
d_fig[variable == "s", label := "variation b/w"]
d_fig[variable == "s_j", label := "variation w/in"]

d_pri <- CJ(
  variable = c("s", "s_j"),
  desc = c("partial pool (trt)", "partial pool (trt+subgrp)"),
  x = seq(min(d_fig$value, na.rm = T), max(d_fig$value, na.rm = T), len = 500))
d_pri[, y := fGarch::dstd(x, nu = 3, mean = 0, sd = 2)]
d_pri[variable == "s", label := "variation b/w"]
d_pri[variable == "s_j", label := "variation w/in"]
d_pri[desc == "partial pool (trt)" & variable == "s", y := NA]

# d_smry <- d_fig[, .(mu = mean(value)), keyby = .(label)]

p_fig <- ggplot(d_fig, aes(x = value, group = variable)) +
  geom_density() +
  geom_line(data = d_pri, aes(x = x, y = y), col = 2, lwd = 0.2) +
  scale_x_continuous("Standard deviation") +
  scale_y_continuous("Density") +
  facet_wrap(desc + label ~ ., ncol = 2)

suppressWarnings(print(p_fig))
```
Figure 10: Between and within SD
We could certainly use a multi-level approach for ROADMAP, but given the small number of groups available, a fixed regularising prior might be more defensible and conceptually reasonable than a dynamic prior. The challenge would be to specify hyper-parameters that are sufficiently informative to moderate extreme parameter estimates.
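One way to reason about candidate hyper-parameters is to look at what a fixed prior implies on a more interpretable scale. The sketch below assumes a \(N(0, \text{sd})\) prior on a log-odds-ratio and reports the central 95% prior interval on the odds-ratio scale. The candidate values in `prior_sd` are purely illustrative and are not a recommendation for ROADMAP.

```r
# Central 95% prior interval on the odds-ratio scale implied by a
# N(0, prior_sd) prior on a log-odds-ratio (prior_sd values are illustrative).
prior_sd <- c(0.5, 1, 2.5)
or_range <- t(sapply(prior_sd, function(s) exp(qnorm(c(0.025, 0.975), 0, s))))
dimnames(or_range) <- list(paste0("sd = ", prior_sd), c("lower", "upper"))
round(or_range, 2)
```

A prior whose implied interval excludes clinically implausible odds ratios will moderate extreme estimates, whereas a very wide interval leaves the estimates close to those of the unregularised model.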