Stan implementation of AFT survival model with log-logistic event times

survival analysis
bayesian
Author

maj

Published

March 13, 2025

Modified

March 30, 2025

Code
library(data.table)
library(ggplot2)
library(gt)
suppressPackageStartupMessages(library(flexsurv))

In a parametric AFT model, the effect of covariates is to speed or slow down time.

\[ \begin{aligned} \log(T) = X\gamma + \text{error} \end{aligned} \]

where \(T\) is the event time, \(X\) is the design matrix, \(\gamma\) is the vector of regression coefficients, and the error term is made up of a scale parameter \(\sigma\) and a random variable \(W\) with a specific distribution. In the usual setup, we observe an event/censoring indicator together with the associated event time \(T\) or censoring time \(C\), with the event and censoring processes assumed to be independent.

For the log-logistic model, the residual distribution is determined by the shape parameter. If \(\log(T) = X\gamma + \sigma W\) where \(W\) has a logistic distribution then \(T\) follows a log-logistic distribution with scale parameter \(\alpha = \exp(X\gamma)\) and shape parameter \(\beta = 1/\sigma\). For further reference see section 2.2.4 in [1], chapter 13 of [2], chapter 6 of [3] (possibly the clearest explanation) and [4].
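A quick simulation can make this relationship concrete. The sketch below, in base R, draws \(W\) from a standard logistic distribution, builds \(T = \exp(\mu + \sigma W)\), and compares the empirical distribution with the log-logistic CDF \(1 / (1 + (t/\alpha)^{-\beta})\). The values of `mu` and `sigma` and the seed are arbitrary choices for illustration, not values taken from any fitted model.

```r
# Sketch: log(T) = mu + sigma * W with W ~ logistic implies T ~ log-logistic
set.seed(1)                      # arbitrary seed
n     <- 1e5
mu    <- log(1080)               # illustrative linear predictor
sigma <- 0.5                     # illustrative scale of the error term

w   <- rlogis(n)                 # standard logistic errors
t_i <- exp(mu + sigma * w)       # event times on the original scale

alpha <- exp(mu)                 # implied log-logistic scale
beta  <- 1 / sigma               # implied log-logistic shape

# Closed-form log-logistic CDF
F_ll <- function(t) 1 / (1 + (t / alpha)^(-beta))

# Compare empirical and theoretical CDF at a few time points
tt  <- c(200, 500, 1080, 2000, 5000)
cmp <- cbind(time = tt, empirical = ecdf(t_i)(tt), theoretical = F_ll(tt))
print(round(cmp, 3))
```

The two columns should agree up to Monte Carlo error, and the empirical median should sit close to \(\alpha = \exp(\mu)\).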

For \(\beta > 1\), the hazard function associated with log-logistic event times is hump-shaped, a bit like the log-normal case but with longer tails: it initially increases, reaches a maximum and then decreases toward 0 as lifetimes become larger and larger. For \(\beta \le 1\) the hazard decreases monotonically. Unlike the log-normal, the log-logistic has a closed-form hazard function. Definitions for the density function can be found in the Stan docs: https://mc-stan.org/docs/functions-reference/positive_continuous_distributions.html#log-logistic-distribution and in the flexsurv help file, see ?flexsurv::dllogis. The density is

\[ \begin{aligned} f = \frac{(\beta/\alpha)(t/\alpha)^{\beta-1}}{(1 + (t/\alpha)^\beta)^2} \end{aligned} \]

with shape parameter \(\beta >0\) and scale parameter \(\alpha >0\). The cumulative distribution function is

\[ \begin{aligned} F = \frac{1}{1 + (t/\alpha)^{-\beta}} \end{aligned} \]

the survival function is \(1 - F\):

\[ \begin{aligned} S &= 1 - \frac{1}{1 + (t/\alpha)^{-\beta}} \\ &= \frac{1 + (t/\alpha)^{-\beta}}{1 + (t/\alpha)^{-\beta}} - \frac{1}{1 + (t/\alpha)^{-\beta}} \\ &= \frac{(t/\alpha)^{-\beta}}{1 + (t/\alpha)^{-\beta}} \\ &= \frac{1}{(t/\alpha)^\beta (1 + (t/\alpha)^{-\beta})} \\ &= \frac{1}{1 + (t/\alpha)^{\beta}} \\ \end{aligned} \]

the hazard function is \(f/S\):

\[ \begin{aligned} h &= \frac{\frac{(\beta/\alpha)(t/\alpha)^{\beta-1}}{(1 + (t/\alpha)^\beta)^2}}{\frac{1}{1 + (t/\alpha)^{\beta}}} \\ &= \frac{(\beta/\alpha)(t/\alpha)^{\beta-1} (1 + (t/\alpha)^{\beta}) }{(1 + (t/\alpha)^\beta)^2} \\ &= \frac{(\beta/\alpha)(t/\alpha)^{\beta-1}}{1 + (t/\alpha)^\beta} \quad \text{cancelling similar terms} \end{aligned} \]
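The algebra above can be checked numerically. The following base R sketch codes the three functions directly (the function names are mine, not from any package), confirms that \(f/S\) matches the simplified hazard, and locates the hazard peak for the illustrative values \(\alpha = 1080\), \(\beta = 2\). Setting the derivative of \(h\) to zero gives a peak at \(t = \alpha(\beta - 1)^{1/\beta}\) when \(\beta > 1\).

```r
# Density, survival and hazard of the log-logistic (scale a, shape b)
ll_dens <- function(t, a, b) (b / a) * (t / a)^(b - 1) / (1 + (t / a)^b)^2
ll_surv <- function(t, a, b) 1 / (1 + (t / a)^b)
ll_haz  <- function(t, a, b) (b / a) * (t / a)^(b - 1) / (1 + (t / a)^b)

a  <- 1080   # illustrative scale
b  <- 2      # illustrative shape
tt <- seq(1, 5000, by = 1)

# f / S should agree with the simplified hazard
print(all.equal(ll_dens(tt, a, b) / ll_surv(tt, a, b), ll_haz(tt, a, b)))

# Numerical location of the hazard maximum versus the analytic value
peak_num <- optimize(ll_haz, interval = c(1, 5000),
                     a = a, b = b, maximum = TRUE)$maximum
peak_ana <- a * (b - 1)^(1 / b)
print(c(numerical = peak_num, analytic = peak_ana))
```

With shape 2 the peak coincides with the scale parameter, so the hazard is still rising throughout the first 360 days under these values.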

Say that we want to simulate data where there was 10% cumulative incidence by day 360, e.g. in the first 360 days of life about 10% of infants will experience a medical attendance for RSV, a respiratory illness.

We want \(S(360) = \frac{1}{1 + (360/\alpha)^\beta} = \pi = 0.9\). As \(\alpha\) is the scale parameter, which is usually modelled on the log scale as a linear function of covariates (treatment effects etc.), assume that \(\beta\) is known and solve for \(\alpha\):

\[ \begin{aligned} \alpha = \frac{360}{((1/\pi) - 1)^{1/\beta}} \end{aligned} \]
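
For completeness, the rearrangement behind this expression, starting from \(S(360) = \pi\), is:

\[ \begin{aligned} \frac{1}{1 + (360/\alpha)^\beta} &= \pi \\ (360/\alpha)^\beta &= \frac{1}{\pi} - 1 \\ \frac{360}{\alpha} &= \left( \frac{1}{\pi} - 1 \right)^{1/\beta} \\ \alpha &= \frac{360}{((1/\pi) - 1)^{1/\beta}} \end{aligned} \]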

For example, say \(\beta = 2\); this implies \(\alpha = \frac{360}{((1/0.9) - 1)^{1/2}} = 1080\) and gives the functional forms as shown below. Setting the survival probability to 0.5 and solving for time gives the median survival time under these parameters, i.e. \(1080 \times 1^{1/\beta} = 1080\). These values are just for demonstration and can be calibrated to subject matter expertise as necessary for simulating trial designs etc.

Code
# log-logistic parameters
# shape parameter
b <- 2
# scale
a <- 360 / ( (1/0.9)-1 )^(1/b)    

# Create a data.table with days from 1 to 1080 (the median survival time)
dt <- data.table(day = 1:1080)

# Compute the survival function S(t) = 1 / (1 + (t/a)^b)
dt[, survival := 1 / (1 + (day / a)^b)]

# Compute the density f(t) = (b/a) * (t/a)^(b-1) / [1 + (t/a)^b]^2
dt[, density := (b / a) * (day / a)^(b - 1) / (1 + (day / a)^b)^2]

# Compute the hazard function h(t) = f(t) / S(t)
dt[, hazard := density / survival]

# Plot the survival curve
ggplot(dt, aes(x = day, y = survival)) +
  geom_line(color = "blue", lwd = 0.4) +
  geom_vline(xintercept = 360, lwd = 0.2) +
  labs(title = "Survival Curve for RSV (Log-Logistic Model)",
       x = "Day of Life", y = "Survival Probability S(t)") +
  scale_y_continuous("Survival S(t)", limits = c(0.5, 1), breaks = seq(0.5, 1, by = 0.1)) +
  theme_minimal()
Warning: Removed 1 row containing missing values or values outside the scale range
(`geom_line()`).
Code
# Plot the hazard function
ggplot(dt, aes(x = day, y = hazard)) +
  geom_line(color = "darkgreen", lwd = 0.4) +
  geom_vline(xintercept = 360, lwd = 0.2) +
  labs(title = "Hazard Function for RSV (Log-Logistic Model)",
       x = "Day of Life", y = "Hazard h(t)") +
  theme_minimal()
# Plot the density function
ggplot(dt, aes(x = day, y = density)) +
  geom_line(color = "red", lwd = 0.4) +
  geom_vline(xintercept = 360, lwd = 0.2) +
  labs(title = "Density Function for RSV (Log-Logistic Model)",
       x = "Day of Life", y = "Density f(t)") +
  theme_minimal()

An AFT model is appropriate when we are more concerned with direct assessment of event times than with instantaneous risk (it is also a way to work around non-proportional hazards). In this model the scale parameter is allowed to vary with the covariates, for example:

\[ \begin{aligned} \alpha_i &= \exp(\mu_i) \\ \mu_i &= \gamma_0 + \gamma_1 x_{i1} + \dots \end{aligned} \]

The density for observation \(i\) is then:

\[ \begin{aligned} f(t_i) &= \frac{\beta}{\exp(\mu_i)} \left( \frac{t_i}{\exp(\mu_i)}\right)^{\beta - 1} \left[ 1 + \left(\frac{t_i}{\exp(\mu_i)} \right)^\beta \right]^{-2} \end{aligned} \]

taking logs of this gives the log-likelihood for observation \(i\):

\[ \begin{aligned} \log f(t_i) &= \log(\beta) - \mu_i + (\beta - 1)\left[ \log(t_i) - \mu_i \right] -2 \log \left( 1 + (t_i/\exp(\mu_i))^\beta \right) \end{aligned} \]

for the right censored records, the survival function is used:

\[ \begin{aligned} S = \frac{1}{1 + (t_i/\exp(\mu_i))^\beta} \end{aligned} \]

taking logs:

\[ \begin{aligned} \log S &= \log(1) - \log \left[ 1 + \left( \frac{t_i}{\exp(\mu_i)} \right)^\beta \right] \\ &= -\log \left[ 1 + \left( \frac{t_i}{\exp(\mu_i)} \right)^\beta \right] \end{aligned} \]
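
To make the two likelihood contributions concrete, here is a small base R sketch (the function name `ll_loglik` is hypothetical) that returns \(\log f(t_i)\) for events and \(\log S(t_i) = -\log(1 + (t_i/\exp(\mu_i))^\beta)\) for right-censored records. It is checked against direct evaluation of the density and survival functions given earlier; the input values are illustrative only.

```r
# Per-observation log-likelihood for a right-censored log-logistic AFT model
# t: observed time, event: 1 = event, 0 = censored, mu: linear predictor, b: shape
ll_loglik <- function(t, event, mu, b) {
  z     <- b * (log(t) - mu)                  # log of (t / exp(mu))^b
  log_f <- log(b) - mu + (b - 1) * (log(t) - mu) - 2 * log1p(exp(z))
  log_S <- -log1p(exp(z))
  ifelse(event == 1, log_f, log_S)
}

# Illustrative values only
t_obs <- c(100, 360, 360, 250)
evt   <- c(1, 0, 0, 1)
mu    <- rep(log(1080), 4)
b     <- 2

ll <- ll_loglik(t_obs, evt, mu, b)

# Direct evaluation from the density and survival formulas
a        <- exp(mu)
f_direct <- (b / a) * (t_obs / a)^(b - 1) / (1 + (t_obs / a)^b)^2
S_direct <- 1 / (1 + (t_obs / a)^b)
print(all.equal(ll, ifelse(evt == 1, log(f_direct), log(S_direct))))
```

Working with `log(t) - mu` rather than forming `t / exp(mu)` first is a design choice for readability; both give the same value here.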

The Stan implementation of the model:

// Log-logistic AFT model
data {
  int<lower=0> N;             // Number of observations
  int<lower=0> P;             // Number of predictors
  matrix[N, P] X;             // Predictor matrix X[, 1] is intercept
  vector<lower=0>[N] y;       // Observed survival times
  vector<lower=0, upper=1>[N] event;  // Event indicator (1=event, 0=censored)
  
  int N_pred;
  vector[N_pred] t_surv;    // time to predict survival at
  
}

parameters {
  vector[P] gamma;             // Regression coefficients for scale
  real<lower=0> shape;        // Shape parameter (b in the formula)
}

transformed parameters {
  // Location parameter (log-scale)
  vector[N] mu;      
  
  mu = X * gamma;
}

model {
  // Priors - arbitrary at the moment
  target += normal_lpdf(gamma | 0, 2);
  target += gamma_lpdf(shape | 1, 0.1);
  
  // Likelihood
  for (i in 1:N) {
    if (event[i] == 1) {
      // For observed events, use the log-logistic density
      target += log(shape) - mu[i] + (shape - 1) * (log(y[i]) - mu[i]) - 
                2 * log1p(pow(y[i] / exp(mu[i]), shape));
    } else {
      // For censored observations, use the log survival function
      target += -log1p(pow(y[i] / exp(mu[i]), shape));
    }
  }
}

generated quantities {
  vector[N_pred] surv0;
  vector[N_pred] surv1;
  
  real med_surv_time0;
  real med_surv_time1;
  
  // median: S(t) = 0.5 implies (t / alpha)^shape = 1, so t = alpha = exp(mu)
  med_surv_time0 = exp(gamma[1]);
  med_surv_time1 = exp(gamma[1] + gamma[2]);
  
  for(i in 1:N_pred){
    surv0[i] =  1 / (1 + pow(t_surv[i]/exp(gamma[1]),  shape));
    surv1[i] =  1 / (1 + pow(t_surv[i]/exp(gamma[1] + gamma[2]),  shape));
  }
  
  
}
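
Since the priors are flagged as arbitrary, it may help to see what they imply. In Stan, `normal_lpdf(gamma | 0, 2)` uses a standard deviation of 2 and `gamma_lpdf(shape | 1, 0.1)` uses the shape/rate parameterisation, which corresponds to `qnorm(sd = )` and `qgamma(rate = )` in R. The sketch below summarises the implied prior for the control-arm median \(\exp(\gamma_1)\) (with \(\gamma_1\) being `gamma[1]`, the intercept) and for the shape. The comparison value \(\log(1080)\) is the scale from the earlier illustration.

```r
# Implied prior quantiles (2.5%, 50%, 97.5%) for quantities in the Stan model
probs <- c(0.025, 0.5, 0.975)

# Control-arm median survival time exp(gamma[1]) under gamma[1] ~ normal(0, 2)
prior_median_time <- exp(qnorm(probs, mean = 0, sd = 2))

# Shape under gamma(1, 0.1), i.e. an exponential distribution with mean 10
prior_shape <- qgamma(probs, shape = 1, rate = 0.1)

print(round(rbind(prior_median_time, prior_shape), 2))

# How many prior standard deviations from zero is log(1080)?
z_intercept <- (log(1080) - 0) / 2
print(round(z_intercept, 2))
```

An intercept near \(\log(1080) \approx 6.98\) sits roughly 3.5 prior standard deviations from zero, so a wider or re-centred prior on the intercept may be worth considering when adapting the model to smaller data sets.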

Fitting the model to simulated data gives the parameter estimates below.

Code
mod_01 <- cmdstanr::cmdstan_model("stan/log-logistic-aft-01.stan")

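# Simulation parameters (these match the defaults of simulate_data() below,
# which are the values actually used)
N <- 2000 

# True coefficients: intercept log(1080) and treatment effect 0.3
gamma_true <- c(log(1080), 0.3)  
# True shape parameter
shape_true <- 2  

# Simulate a two-arm trial with administrative censoring at t_cen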
simulate_data <- function(
    N = 2000,
    gamma_true = c(log(1080), 0.3)  ,
    shape_true = 2  ,
    t_cen = 360
    ) {
  
  d <- data.table(
    trt = rep(0:1, length  = N)
  )
  
  d[trt == 0, scale := exp(gamma_true[1])]
  d[trt == 1, scale := exp(gamma_true[1] + gamma_true[2])]
  
  d[, t_evt := flexsurv::rllogis(.N, shape = shape_true, scale = scale)]
  d[, evt := as.numeric(t_evt <= t_cen)]
  d[evt == 1, t_evt_obs := t_evt]
  d[evt == 0, t_evt_obs := t_cen]

  # d[, .N, keyby = .(trt, evt)]
  d
}

# Simulate data
d_sim <- simulate_data()
# d_sim[, .N, keyby = .(trt, evt)]

# Prepare data for Stan
ld <- list(
  N = nrow(d_sim),
  P = 2,
  X = cbind(1, d_sim$trt),
  y = d_sim$t_evt_obs,
  event = d_sim$evt,
  N_pred = 361,
  t_surv = 0:360
)


# Fit the Stan model - sink to remove the noise
# snk <- capture.output(
  m1 <- mod_01$sample(
      ld, iter_warmup = 1000, iter_sampling = 1000,
      parallel_chains = 4, chains = 4, refresh = 0, show_exceptions = FALSE,
      max_treedepth = 10)
Running MCMC with 4 parallel chains...

Chain 3 finished in 3.9 seconds.
Chain 4 finished in 4.0 seconds.
Chain 1 finished in 4.1 seconds.
Chain 2 finished in 4.2 seconds.

All 4 chains finished successfully.
Mean chain execution time: 4.1 seconds.
Total execution time: 4.3 seconds.
Code
# )
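
Before looking at the estimates, it can help to know how many events to expect under the simulation settings. The sketch below uses the data-generating defaults of `simulate_data()` (scale \(1080\), treatment coefficient \(0.3\), shape \(2\), administrative censoring at day 360) to compute the event probability in each arm, and shows an inverse-CDF sampler as a base R alternative to `flexsurv::rllogis`. The helper names `p_event` and `r_loglogis` and the seed are my own choices.

```r
# Event probability by day t_cen: F(t) = 1 - 1 / (1 + (t / a)^b)
p_event <- function(t_cen, a, b) 1 - 1 / (1 + (t_cen / a)^b)

gamma_true <- c(log(1080), 0.3)
shape_true <- 2
t_cen      <- 360

p_ctl <- p_event(t_cen, exp(gamma_true[1]), shape_true)
p_trt <- p_event(t_cen, exp(sum(gamma_true)), shape_true)
print(round(c(ctl = p_ctl, trt = p_trt), 4))

# Expected number of events with 1000 participants per arm
expected_events <- 1000 * (p_ctl + p_trt)
print(round(expected_events))

# Inverse-CDF sampler: solving F(t) = u for t gives t = a * (u / (1 - u))^(1 / b)
r_loglogis <- function(n, a, b) {
  u <- runif(n)
  a * (u / (1 - u))^(1 / b)
}

set.seed(42)  # arbitrary seed
t_sim <- r_loglogis(1e5, exp(gamma_true[1]), shape_true)
print(mean(t_sim <= t_cen))  # should be close to p_ctl
```

Under these settings about 10% of control and just under 6% of treated participants have an event by day 360, so the large majority of records are censored.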

Extract the parameters that we are interested in:

Code
m1$summary(variables = c("gamma", "shape", "med_surv_time0", "med_surv_time1"))
# A tibble: 5 × 10
  variable           mean  median      sd     mad      q5     q95  rhat ess_bulk
  <chr>             <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl> <dbl>    <dbl>
1 gamma[1]          6.86  6.85e+0 8.87e-2 8.76e-2 6.72e+0 7.01e+0  1.00    1512.
2 gamma[2]          0.353 3.49e-1 8.10e-2 8.03e-2 2.25e-1 4.96e-1  1.00    2144.
3 shape             2.30  2.29e+0 1.75e-1 1.73e-1 2.01e+0 2.59e+0  1.00    1532.
4 med_surv_time0  953.    9.44e+2 8.58e+1 8.25e+1 8.26e+2 1.11e+3  1.00    1512.
5 med_surv_time1 1361.    1.34e+3 1.68e+2 1.60e+2 1.13e+3 1.68e+3  1.00    1654.
# ℹ 1 more variable: ess_tail <dbl>

By way of a sanity check, run the equivalent model using the flexsurv package.

Code
m2 <- flexsurvreg(Surv(t_evt_obs, evt) ~ trt, data = d_sim, dist = "llogis")
print(m2)
Call:
flexsurvreg(formula = Surv(t_evt_obs, evt) ~ trt, data = d_sim, 
    dist = "llogis")

Estimates: 
       data mean  est       L95%      U95%      se        exp(est)  L95%      U95%    
shape        NA   2.28e+00  1.94e+00  2.67e+00  1.85e-01        NA        NA        NA
scale        NA   9.51e+02  7.92e+02  1.14e+03  8.85e+01        NA        NA        NA
trt    5.00e-01   3.46e-01  1.80e-01  5.12e-01  8.47e-02  1.41e+00  1.20e+00  1.67e+00

N = 2000,  Events: 146,  Censored: 1854
Total time at risk: 703384.6
Log-likelihood = -1337.121, df = 3
AIC = 2680.241

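Another option for model implementation might be through brms with a custom family (if that is possible). As a starting point, the Stan code that brms generates for built-in families with right censoring can be inspected: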

Code
d_sim[, censored := 1-evt]
# brms uses the opposite coding to the event indicator:
# for cens, specify 0 to indicate no censoring and 1 to indicate right censoring

brms::make_stancode(t_evt_obs | cens(censored) ~ trt, data = d_sim, family = lognormal())
brms::make_stancode(t_evt_obs | cens(censored) ~ trt, data = d_sim, family = weibull())

Exponentiating the \(\gamma_2\) parameter (the treatment coefficient, `gamma[2]` in the Stan model) gives the acceleration factor associated with the treatment effect. For example, if \(\gamma_2 > 0\) we can say that a change from the control to the treatment arm is associated with survival times being multiplied by a factor of \(\exp(\gamma_2) > 1\), indicating prolonged survival/delayed events. Similarly, if \(\gamma_2 < 0\) survival times are shortened (the time to event speeds up).

In a log-logistic AFT model with the current parameterisation, the median survival time for an individual with covariates \(x_i\) is given by \(\exp(x_i'\gamma) = \alpha_i = \text{scale}_i\). Median survival is a common measure used to contrast groups.
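
The acceleration factor interpretation applies to every quantile, not only the median. Solving \(F(t) = p\) gives the quantile \(t_p = \alpha\,(p/(1-p))^{1/\beta}\), so the ratio of treatment to control quantiles is \(\exp(\gamma_2)\) for all \(p\). A short base R sketch, using the simulation values as illustrative inputs (the function name `q_loglogis` and the vector `gam` are hypothetical):

```r
# Quantile function of the log-logistic: t_p = a * (p / (1 - p))^(1 / b)
q_loglogis <- function(p, a, b) a * (p / (1 - p))^(1 / b)

gam   <- c(log(1080), 0.3)   # illustrative intercept and treatment coefficient
shape <- 2                   # illustrative shape

p     <- c(0.1, 0.25, 0.5, 0.75)
q_ctl <- q_loglogis(p, exp(gam[1]), shape)
q_trt <- q_loglogis(p, exp(gam[1] + gam[2]), shape)

# The ratio is constant across quantiles and equals exp(gam[2])
res <- data.frame(p = p, ctl = q_ctl, trt = q_trt, ratio = q_trt / q_ctl)
print(res)
print(exp(gam[2]))
```

The row for `p = 0.5` reproduces the medians, which equal the scale parameters of the two arms.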

Produce a posterior for the survival curve:

Code
d_post <- data.table(m1$draws(variables = c("surv0", "surv1"), format = "matrix"))
d_post <- melt(d_post, measure.vars = names(d_post))
d_post[variable %like% "surv0", trt := 0]
d_post[variable %like% "surv1", trt := 1]
d_fig <- copy(d_post)

d_fig[, x := gsub(".*\\[", "", variable)]
d_fig[, x := gsub("\\]", "", x)]
# index i of surv0/surv1 corresponds to t_surv[i] = i - 1, so shift to days 0..360
d_fig[, x := as.numeric(x) - 1]

d_fig <- d_fig[
  , .(mu = mean(value),
      q_025 = quantile(value, probs = 0.025),
      q_975 = quantile(value, probs = 0.975)), keyby = .(trt, x)]
d_fig[, trt := factor(trt, levels = 0:1, labels = c("ctl", "trt"))]

ggplot(d_fig, aes(x = x, y = mu, group = trt, col = trt)) +
  geom_ribbon(aes(ymin = q_025, ymax = q_975, fill = trt), alpha = 0.1, col = NA) +
  geom_line() + 
  scale_y_continuous(limits = c(0.7, 1), breaks = seq(0.7, 1, by = 0.1)) +
  scale_color_discrete("") +
  scale_fill_discrete("") +
  theme_minimal() +
  theme(
    legend.position = "bottom"
  )

Posterior on the median survival time, the time at which 50% of the cohort has experienced the occurrence of the event, e.g. a medical attendance for RSV ARI.

Code
d_post <- data.table(m1$draws(variables = c("med_surv_time0", "med_surv_time1"), format = "matrix"))

d_fig <- melt(d_post, measure.vars = names(d_post), variable.name = "trt")

d_fig[, trt := factor(trt, levels = c("med_surv_time0", "med_surv_time1"), labels = c("ctl", "trt"))]

ggplot(d_fig, aes(x = value, group = trt, col = trt)) +
  geom_density() +
  scale_x_continuous("Median survival time") +
  scale_color_discrete("Treatment") +
  theme_minimal() +
  theme(
    legend.position = "bottom"
  )

Estimate the posterior for restricted mean survival time (by treatment group) by integrating under the survival curve function for each draw from the posterior.

The RMST can be interpreted as the average survival time (i.e. time without the event, here being occurrence of RSV) during a defined time period ranging from time 0 to a specific follow-up time point.

Code
# Survival function S(x) of the log-logistic, used as the integrand for the RMST
integrand_1 <- function(x, mu, shape) {
  a <- exp(mu)             # scale
  1 / (1 + (x / a)^shape)  # S(x)
}


d_post <- data.table(m1$draws(variables = c("gamma", "shape"), format = "matrix"))
names(d_post) <- c(paste0("gamma", 1:2), "shape")
# One row per posterior draw, one column per arm (control, treatment)
m_rmst <- matrix(NA_real_, ncol = 2, nrow = nrow(d_post))
for(i in 1:nrow(d_post)){
  m_rmst[i, 1] <- integrate(
    integrand_1, lower = 0, upper = 360,
    mu = d_post$gamma1[i], 
    shape = d_post$shape[i])$value  
  m_rmst[i, 2] <- integrate(
    integrand_1, lower = 0, upper = 360,
    mu = d_post$gamma1[i] + d_post$gamma2[i], 
    shape = d_post$shape[i])$value  
}



d_rmst <- data.table(m_rmst)
names(d_rmst) <- paste0(0:1)

rmst_diff <- d_rmst$`1` -  d_rmst$`0`


d_fig <- melt(d_rmst, measure.vars = names(d_rmst), variable.name = "trt")

d_fig[, trt := factor(trt, levels = 0:1, labels = c("ctl", "trt"))]

ggplot(d_fig, aes(x = value, group = trt, col = trt)) +
  geom_density() +
  scale_x_continuous("RMST") +
  scale_color_discrete("Treatment") +
  theme_minimal() +
  theme(
    legend.position = "bottom"
  )

Code
names(d_rmst) <- paste0("rmst", 0:1)
d_rmst[, diff := rmst1 - rmst0]
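
As a check on the numerical integration, the RMST has a closed form when the shape is exactly 2: \(\int_0^\tau \frac{1}{1 + (t/\alpha)^2}\,dt = \alpha \arctan(\tau/\alpha)\). The sketch below compares `integrate()` against this expression for the illustrative values \(\alpha = 1080\), \(\tau = 360\). The function name `surv_ll` is hypothetical and mirrors `integrand_1`.

```r
# Survival function on the same parameterisation as integrand_1
surv_ll <- function(x, mu, shape) 1 / (1 + (x / exp(mu))^shape)

alpha <- 1080   # illustrative scale
tau   <- 360    # restriction time in days

rmst_num <- integrate(surv_ll, lower = 0, upper = tau,
                      mu = log(alpha), shape = 2)$value
rmst_ana <- alpha * atan(tau / alpha)   # closed form, valid for shape = 2 only

print(c(numerical = rmst_num, analytic = rmst_ana))
```

Both values are about 347.5 days, i.e. an average of roughly 12.5 event-free days lost over the first 360 days under these parameters.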

From here, we could evaluate differences in the RMST between groups considering what level of improvement in the mean survival to 360 days would be warranted to decide on adopting the treatment over the control.

Note that I have assumed a log-logistic distribution here, primarily because I wanted something similar to a log-normal but more tractable. Other distributional assumptions might be more suitable. For example, if the data have a peaked hazard followed by a decline, then the standard log-logistic or a generalized log-logistic may work well. However, if the hazard function is more complex (e.g. bathtub shape, non-monotonic tail behavior etc.), the generalized F or Burr distributions might be better. Weibull or gamma models are simpler if only a monotonically increasing or decreasing hazard is required. All of these are reasonably straightforward to code up.

References

1. Sun J. Statistical analysis of interval censored failure time data. Springer; 2006.
2. Christensen R. Bayesian ideas and data analysis. CRC Press; 2011.
3. Collett D. Modelling survival data in medical research. CRC Press; 2015.
4. Cleves M. Introduction to survival analysis using Stata. Stata Press; 2010.