Discrete parameters in Stan

stan
discrete parameters
Author

maj

Published

November 20, 2024

Modified

March 13, 2025

library(cmdstanr)
This is cmdstanr version 0.8.1
- CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
- CmdStan path: /Users/mark/.cmdstan/cmdstan-2.35.0
- CmdStan version: 2.35.0

A newer version of CmdStan is available. See ?install_cmdstan() to install it.
To disable this check set option or environment variable cmdstanr_no_ver_check=TRUE.
library(data.table)
library(ggplot2)

Hamiltonian Monte Carlo (HMC) is usually used to estimate continuous parameters. However, sometimes we are interested in discrete parameters.

HMC runs on log-probabilities. Therefore, as long as we can increment the log-probability associated with our model (whatever the model is), we can code it up, even if it uses discrete parameters. For models with discrete parameters, we have to ‘marginalise out’ the discrete parameter(s) and increment the log-density by the marginal density.

Ben Lambert has an example:

Consider a series of \(K\) experiments in which someone flips a coin a fixed number of times, \(n\), for each experiment, but we are never told \(n\). The individual uses the same coin across all the experiments and the probability of heads remains constant, \(\text{Pr}(X = \text{heads}) = \theta\). We define \(X_k\) as the number of heads obtained in experiment \(k\), yielding \((X_1, X_2, \dots, X_K)\). Again, both \(\theta\) and \(n\) are unknown to us; \(\theta\) is continuous, but \(n\) is discrete.

We want to be able to make inference on both \(\theta\) and \(n\), i.e. talk about probability intervals for \(\theta\) and the probability that \(n\) takes on certain values.

Assume that we know \(K = 10\) experiments were run and that the following data were observed: \(X_k = (2, 4, 3, 3, 3, 3, 3, 3, 4, 4)\).

We adopt independent priors on \(\theta\) and \(n\):

\[ \begin{aligned} n &\sim \text{Discrete-Unif}(5, 8) \\ \theta &\sim \text{Unif}(0, 1) \end{aligned} \]
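To make the setup concrete, here is a small simulation sketch of the generative process in R. The values of n_true and theta_true are illustrative only; in the actual problem both are unknown to us.

# simulate K experiments, each recording the number of heads in n_true flips
# (n_true and theta_true are hypothetical values chosen for illustration)
set.seed(1)
K <- 10
n_true <- 6
theta_true <- 0.5
X_sim <- rbinom(K, size = n_true, prob = theta_true)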

We can write down the joint posterior of \(\theta\) and \(n\), i.e. \(\text{Pr}(\theta, n | X)\), and then marginalise out the discrete parameter, \(n\). Once we have an expression that excludes \(n\), we can get Stan to use that expression to conduct the sampling we want it to do. We have:

\[ \begin{aligned} \text{Pr}(\theta | X) &= \sum_{n=5}^8 \text{Pr}(\theta, n | X) \end{aligned} \]

Stan runs on the log probability so we need to think in those terms:

\[ \begin{aligned} \log(\text{Pr}(\theta | X)) &= \log \left(\sum_{n=5}^8 \text{Pr}(\theta, n | X) \right) \\ &= \log \left(\sum_{n=5}^8 \exp( \log( \text{Pr}(\theta, n | X))) \right) \\ \end{aligned} \]

where the second line rewrites the summand so that we are working with log probabilities throughout.

In Stan, the above can be computed in a numerically stable way via the log_sum_exp function.

\[ \begin{aligned} \log(\text{Pr}(\theta | X)) &= \text{log\_sum\_exp}_{n=5}^8 \left( \log( \text{Pr}(\theta, n | X) ) \right) \end{aligned} \]
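For intuition, a minimal R version of log-sum-exp, written the standard numerically stable way (subtract the maximum before exponentiating), might look like the sketch below; Stan's built-in log_sum_exp does the equivalent.

# numerically stable log(sum(exp(x)))
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}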

Unfortunately, we do not have \(\text{Pr}(\theta, n | X)\), but we can use Bayes’ rule to determine what it is:

\[ \begin{aligned} \text{Pr}(\theta, n | X) &\propto \text{Pr}(X | \theta, n) \text{Pr}(\theta, n) \\ &= \text{Pr}(X | \theta, n) \text{Pr}(n)\text{Pr}(\theta) \end{aligned} \]

where the second line comes from the fact that we use independent priors. Taking logs, we get:

\[ \begin{aligned} \log(\text{Pr}(\theta, n | X)) &\propto \log(\text{Pr}(X | \theta, n)) + \log(\text{Pr}(n)) + \log(\text{Pr}(\theta)) \\ \end{aligned} \]
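As a sketch of this expression in R for the model above (binomial likelihood, discrete-uniform prior on \(n\) contributing \(\log(1/4)\), and uniform prior on \(\theta\) contributing 0 on \([0, 1]\)):

# unnormalised log joint posterior log q(theta, n | X)
log_q <- function(theta, n, X) {
  sum(dbinom(X, size = n, prob = theta, log = TRUE)) +  # log Pr(X | theta, n)
    log(0.25) +                                         # log Pr(n), discrete uniform on 5..8
    dunif(theta, min = 0, max = 1, log = TRUE)          # log Pr(theta)
}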

The above allows us to estimate models with discrete parameters by marginalising them out of the joint density. But, what if we want to do inference on the discrete parameters?

Answer: write down the unnormalised density of \(n\) conditional on \(X\) and estimate it via MCMC:

\[ \begin{aligned} q(n | X) &\approx \frac{1}{B} \sum_{i = 1}^B q(n, \theta_i | X) \end{aligned} \]

where \(B\) is the number of MCMC samples and \(\theta_i\) are the posterior samples. Essentially, this is averaging over \(\theta_i\).

To get the normalised version, we need to form a simplex (elements are non-negative and sum to 1) across the four possible values for \(n\), giving probabilities for \(n = 5\), \(n = 6\), \(n = 7\) and \(n = 8\). We can obtain this from:

\[ \begin{aligned} p(n | X) &\approx \frac{q(n | X)}{\sum_{n = 5}^8 q(n | X)} \\ &= \frac{\exp( \log( q(n | X) ) )}{ \exp( \text{log\_sum\_exp} (\log( q(n | X))) )} \end{aligned} \]

and in Stan, we would write this as:

\[ \begin{aligned} p(n | X) = \exp\left[ \log(q(n | X)) - \text{log\_sum\_exp}(\log( q(n | X)) ) \right] \end{aligned} \]
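To see the normalisation at work outside of Stan, here is a small R sketch using the log_q and log_sum_exp helpers from above, with \(\theta\) held fixed at an illustrative value; in the Stan program below, the same calculation is done per posterior draw in the generated quantities block.

X_obs <- c(2, 4, 3, 3, 3, 3, 3, 3, 4, 4)
# unnormalised log densities for n = 5, 6, 7, 8 (theta fixed for illustration)
lq <- sapply(5:8, function(n) log_q(theta = 0.5, n = n, X = X_obs))
# normalise into a simplex over the four values of n
p_n <- exp(lq - log_sum_exp(lq))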

An implementation for the above discussion is shown below:



data {
  // num expt
  int<lower=0> K;
  array[K] int X;
}
transformed data{
  array[4] int n;
  // these are the permissible values of n, 
  // i.e. 5, 6, 7, 8
  for(i in 1:4){
    n[i] = 4 + i;
  }
}
parameters {
  real<lower=0, upper=1> theta;
}
transformed parameters{
  // unnormalised density
  vector[4] lq;
  for(i in 1:4){
    // record the unnormalised density for every possible value 
    // of n
    
    // log pmf for the array of values in X conditional on a given n[i]
    // and the parameter theta PLUS the prior on the given n, which is 
    // a discrete uniform, i.e. for each n, the prior is 0.25
    lq[i] = binomial_lpmf(X | n[i], theta) + log(0.25);
  }
}
model {
  target += uniform_lpdf(theta | 0, 1);
  
  // marginalise out the troublesome n
  target += log_sum_exp(lq);
  
}
generated quantities{
  // probability of n given X, i.e. the distribution of n | x
  vector[4] p_n_X;
  p_n_X = exp(lq - log_sum_exp(lq));
  
}

Running the model with the assumed data gives us our parameter estimates for both \(\theta\) and \(n\).

m1 <- cmdstanr::cmdstan_model("stan/discrete-param-1.stan")

ld <- list(
  K = 10, X = c(2, 4, 3, 3, 3, 3, 3, 3, 4, 4)
)

f1 <- m1$sample(
    ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 4, chains = 4, refresh = 0, show_exceptions = FALSE,
    max_treedepth = 10)
Running MCMC with 4 parallel chains...

Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.

All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.2 seconds.
f1$summary()
# A tibble: 10 × 10
   variable     mean     median    sd      mad       q5     q95  rhat ess_bulk
   <chr>       <dbl>      <dbl> <dbl>    <dbl>    <dbl>   <dbl> <dbl>    <dbl>
 1 lp__     -14.9    -14.7      0.643 0.353    -1.61e+1 -14.4    1.00    1565.
 2 theta      0.576    0.589    0.103 0.107     3.89e-1   0.729  1.00    1017.
 3 lq[1]    -14.6    -13.7      2.17  0.802    -1.95e+1 -13.1    1.00    1371.
 4 lq[2]    -15.7    -15.0      1.83  1.27     -1.95e+1 -14.0    1.00    1630.
 5 lq[3]    -18.3    -17.2      3.76  3.24     -2.61e+1 -14.6    1.00    1159.
 6 lq[4]    -22.0    -20.8      6.08  6.38     -3.40e+1 -15.1    1.00    1063.
 7 p_n_X[1]   0.596    0.725    0.366 0.374     5.64e-3   0.995  1.00    1017.
 8 p_n_X[2]   0.215    0.171    0.174 0.208     5.28e-3   0.506  1.00    1503.
 9 p_n_X[3]   0.115    0.0179   0.156 0.0265    5.90e-6   0.435  1.00    1086.
10 p_n_X[4]   0.0742   0.000437 0.163 0.000648  2.23e-9   0.501  1.00    1017.
# ℹ 1 more variable: ess_tail <dbl>
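The posterior mean of each p_n_X[i] estimates \(\text{Pr}(n = 4 + i | X)\); these are the values reported in the mean column above. If we want them directly, one way (assuming the posterior package is available) is:

library(posterior)
# average p_n_X over all posterior draws to get Pr(n = 5), ..., Pr(n = 8)
colMeans(as_draws_matrix(f1$draws("p_n_X")))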

Obviously, the above is somewhat contrived. We generally do not know the bounds on the discrete parameters. For example, how did we know that the bounds of \(n\) were 5 and 8? How would we have modified the model to account for observing a 6 in the data?