Discrete parameters in Stan

stan
discrete parameters
Author

maj

Published

November 20, 2024

Modified

November 20, 2024

library(cmdstanr)
This is cmdstanr version 0.8.1
- CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
- CmdStan path: /Users/mark/.cmdstan/cmdstan-2.35.0
- CmdStan version: 2.35.0
library(data.table)
library(ggplot2)

Hamiltonian Monte Carlo can most directly estimate parameters for which the uncertainty is represented by continuous random variables. However, sometimes we are interested in discrete parameters.

HMC works with the log-density of the model and its gradient with respect to the continuous parameters. So, as long as we can still increment the log-density with an expression that no longer depends on the discrete parameters, we can fit models that contain them.

This typically requires us to marginalise out the discrete parameter(s) and then increment the log-density by that marginal density.

Ben Lambert has an example:

Consider a series of \(K\) experiments in which someone flips a coin \(n\) times. They use the same coin across all the experiments, so the probability of heads remains constant: \(\text{Pr}(\text{heads}) = \theta\). In each experiment we count the number of heads, giving us data \(X_k\) for experiment \(k\) and hence \(X = (X_1, X_2, \dots, X_K)\). Both \(\theta\) and \(n\) are unknown to us. \(\theta\) is continuous, but \(n\) is discrete and so not directly amenable to inference in Stan.

Assume the following data were observed, \(X = (2, 4, 3, 3, 3, 3, 3, 3, 4, 4)\), for \(K = 10\) experiments, and that we adopt independent priors on \(\theta\) and \(n\), namely:

\[ \begin{aligned} n &\sim \text{Discrete-Unif}(5, 8) \\ \theta &\sim \text{Unif}(0, 1) \end{aligned} \]

To proceed we first have to write down the joint posterior of \(\theta\) and \(n\), i.e. \(\text{Pr}(\theta, n | X)\), and then work towards marginalising out \(n\). Once we have an expression that excludes \(n\), which Stan cannot sample directly, we can have Stan use that expression to do the sampling.

\[ \begin{aligned} \text{Pr}(\theta | X) &= \sum_{n=5}^8 \text{Pr}(\theta, n | X) \end{aligned} \]

Stan runs on the log probability so what we actually have is:

\[ \begin{aligned} \log(\text{Pr}(\theta | X)) &= \log \left(\sum_{n=5}^8 \text{Pr}(\theta, n | X) \right) \\ &= \log \left(\sum_{n=5}^8 \exp( \log( \text{Pr}(\theta, n | X))) \right) \\ \end{aligned} \]

where the second line rewrites each term in the sum as the exponential of its log, so that the whole expression is written in terms of log-probabilities, which is what Stan works with.

The log-sum-exp operation (implemented in Stan as log_sum_exp) gives us what we need in a numerically stable way. Note that it is applied to the log-probabilities:

\[ \begin{aligned} \log(\text{Pr}(\theta | X)) &= \text{log\_sum\_exp}_{n=5}^8 \left( \log( \text{Pr}(\theta, n | X)) \right) \end{aligned} \]
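To see why the stable form matters, here is a minimal R sketch. The helper names and the log-probability values are made up for illustration; the stable version uses the usual trick of subtracting the maximum before exponentiating.

```r
# Naive version: exponentiate, sum, take the log
log_sum_exp_naive <- function(x) log(sum(exp(x)))

# Stable version: factor out the largest term first
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# Hypothetical log-probabilities, far too small to exponentiate directly
lp <- c(-1000, -1001, -1002, -1003)

naive  <- log_sum_exp_naive(lp)  # exp(-1000) underflows to 0, so this is -Inf
stable <- log_sum_exp(lp)        # finite, a little above max(lp)

print(c(naive = naive, stable = stable))
```

With many observations the log-likelihood is easily in the hundreds or thousands below zero, so the naive version would fail in exactly the situations we care about.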

The problem we now face is that we do not have \(\text{Pr}(\theta, n | X)\), but we can use Bayes' rule to determine what it is:

\[ \begin{aligned} \text{Pr}(\theta, n | X) &\propto \text{Pr}(X | \theta, n) \text{Pr}(\theta, n) \\ &= \text{Pr}(X | \theta, n) \text{Pr}(n)\text{Pr}(\theta) \end{aligned} \]

where the second line comes from the fact that we use independent priors. Taking logs turns the proportionality into an equality up to an additive constant that does not depend on \(\theta\) or \(n\):

\[ \begin{aligned} \log(\text{Pr}(\theta, n | X)) &= \log(\text{Pr}(X | \theta, n)) + \log(\text{Pr}(n)) + \log(\text{Pr}(\theta)) + \text{const} \end{aligned} \]

The first term on the RHS is the log-likelihood, for which we use the binomial distribution. The second term is the log of the discrete uniform prior on \(n\) that we defined earlier; for any \(n \in \{ 5,6,7,8 \}\) this is \(\log(1/4)\). Finally, the third term is the log of the standard uniform prior on \(\theta\). Because the third term does not contain \(n\), it factors out of the sum over \(n\), so we do not need to include it in the per-\(n\) terms. Instead, it is added to the log-density once, in the model block.
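The per-\(n\) terms can be sketched in base R for a single value of \(\theta\). The value `theta = 0.5` is arbitrary and only for illustration; in the model, Stan evaluates the same quantity at every proposed \(\theta\).

```r
X <- c(2, 4, 3, 3, 3, 3, 3, 3, 4, 4)
n_grid <- 5:8
theta <- 0.5  # arbitrary illustrative value, not an estimate

# log Pr(X | theta, n) + log Pr(n) for each candidate n
lq <- sapply(n_grid, function(n) {
  sum(dbinom(X, size = n, prob = theta, log = TRUE)) + log(1 / length(n_grid))
})

# marginalise over n with a stable log-sum-exp
m <- max(lq)
log_marginal <- m + log(sum(exp(lq - m)))

print(lq)
print(log_marginal)
```

The marginal is at least as large as the biggest single term and at most \(\log 4\) above it, since it sums four non-negative contributions.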

The above allows us to estimate models with discrete parameters by marginalising them out of the joint density. But, what if we want to do inference on the discrete parameters?

The answer to this is to write down the unnormalised density of \(n\) conditional on \(X\) and estimate via MCMC:

\[ \begin{aligned} q(n | X) \approx \frac{1}{B} \sum_{i = 1}^B q(n, \theta_i | X) \end{aligned} \]

where \(B\) is the number of MCMC samples and \(\theta_i\) are the posterior samples. Essentially, this is averaging over \(\theta_i\).

To get the normalised version, we need to form a simplex (the elements are non-negative and sum to 1) across the four possible values for \(n\), giving probabilities for \(n = 5\), \(n = 6\), \(n = 7\) and \(n = 8\). We can obtain this from:

\[ \begin{aligned} p(n | X) &\approx \frac{q(n | X)}{\sum_{n' = 5}^8 q(n' | X)} \\ &= \frac{\exp( \log( q(n | X) ) )}{ \exp( \text{log\_sum\_exp}_{n' = 5}^8 (\log( q(n' | X))) )} \end{aligned} \]

and in Stan, we would write this as:

\[ \begin{aligned} p(n | X) = \exp\left[ \log(q(n | X)) - \text{log\_sum\_exp}_{n' = 5}^8(\log( q(n' | X)) ) \right] \end{aligned} \]
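A minimal R sketch of the normalisation step follows. The three values in `theta_draws` are hypothetical stand-ins for posterior draws, not output from any fit. Each row of `lq` holds the unnormalised log-densities for one draw; each row is normalised to a simplex and the rows are then averaged.

```r
X <- c(2, 4, 3, 3, 3, 3, 3, 3, 4, 4)
n_grid <- 5:8
theta_draws <- c(0.45, 0.55, 0.65)  # hypothetical stand-ins for posterior draws

log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# one row per draw of theta, one column per candidate n
lq <- t(sapply(theta_draws, function(th) {
  sapply(n_grid, function(n) {
    sum(dbinom(X, size = n, prob = th, log = TRUE)) + log(1 / length(n_grid))
  })
}))

# normalise within each draw, then average over draws
p_per_draw <- t(apply(lq, 1, function(row) exp(row - log_sum_exp(row))))
p_n <- colMeans(p_per_draw)

print(round(p_per_draw, 3))
print(round(p_n, 3))
```

Subtracting `log_sum_exp(row)` before exponentiating means any constant added to every element of a row cancels, which is why the unnormalised log-density is all we need.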

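An implementation of the above is shown below. Note that the generated quantities block applies the normalisation within each draw of \(\theta\), so the posterior mean of p_n_X is the average of the per-draw probabilities.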



data {
  // num expt
  int<lower=0> K;
  array[K] int X;
}
transformed data{
  array[4] int n;
  // gives us 5, 6, 7, 8
  for(i in 1:4){
    n[i] = 4 + i;
  }
}
parameters {
  real<lower=0, upper=1> theta;
}
transformed parameters{
  // unnormalised density
  vector[4] lq;
  for(i in 1:4){
    // record the unnormalised density for every possible value 
    // of n
    
    // log pmf for the array of values in X conditional on a given n[i]
    // and the parameter theta PLUS the prior on n 
    lq[i] = binomial_lpmf(X | n[i], theta) + log(0.25);
  }
}
model {
  target += uniform_lpdf(theta | 0, 1);
  
  // marginalise out the troublesome n
  target += log_sum_exp(lq);
  
}
generated quantities{
  // probability of n given X, i.e. the distribution of n | x
  vector[4] p_n_X;
  p_n_X = exp(lq - log_sum_exp(lq));
  
}

Running the model with the assumed data gives us our parameter estimates for both \(\theta\) and \(n\).

m1 <- cmdstanr::cmdstan_model("stan/discrete-param-1.stan")

ld <- list(
  K = 10, X = c(2, 4, 3, 3, 3, 3, 3, 3, 4, 4)
)

f1 <- m1$sample(
    ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 4, chains = 4, refresh = 0, show_exceptions = FALSE,
    max_treedepth = 10)
Running MCMC with 4 parallel chains...

Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.

All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.2 seconds.
f1$summary()
# A tibble: 10 × 10
   variable     mean     median    sd      mad       q5     q95  rhat ess_bulk
   <chr>       <dbl>      <dbl> <dbl>    <dbl>    <dbl>   <dbl> <dbl>    <dbl>
 1 lp__     -14.9    -14.7      0.620 0.369    -1.61e+1 -14.4    1.01    1376.
 2 theta      0.573    0.585    0.104 0.112     3.88e-1   0.722  1.00    1191.
 3 lq[1]    -14.7    -13.7      2.21  0.854    -1.96e+1 -13.1    1.01    1245.
 4 lq[2]    -15.6    -15.0      1.77  1.27     -1.92e+1 -14.0    1.00    1819.
 5 lq[3]    -18.2    -17.0      3.66  3.19     -2.54e+1 -14.6    1.00    1489.
 6 lq[4]    -21.8    -20.5      5.97  6.45     -3.30e+1 -15.1    1.00    1263.
 7 p_n_X[1]   0.582    0.703    0.370 0.403     5.40e-3   0.993  1.01    1191.
 8 p_n_X[2]   0.219    0.177    0.176 0.217     6.90e-3   0.507  1.00    1779.
 9 p_n_X[3]   0.121    0.0212   0.158 0.0314    9.92e-6   0.436  1.01    1334.
10 p_n_X[4]   0.0782   0.000572 0.167 0.000848  4.87e-9   0.506  1.01    1191.
# ℹ 1 more variable: ess_tail <dbl>
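As a rough cross-check on the p_n_X means, this particular model can also be solved without MCMC. With a uniform prior on \(\theta\), integrating \(\theta\) out of the binomial likelihood gives a beta function, so \(\text{Pr}(n | X) \propto \text{Pr}(n) \prod_k \binom{n}{X_k} \, B(S + 1, Kn - S + 1)\), where \(S = \sum_k X_k\). The sketch below is not part of the Stan workflow and the variable names are my own.

```r
X <- c(2, 4, 3, 3, 3, 3, 3, 3, 4, 4)
n_grid <- 5:8
K <- length(X)
S <- sum(X)

# log of prior * integrated likelihood for each candidate n
log_w <- sapply(n_grid, function(n) {
  sum(lchoose(n, X)) + lbeta(S + 1, K * n - S + 1) + log(1 / length(n_grid))
})

# normalise on the log scale for stability
p_exact <- exp(log_w - max(log_w))
p_exact <- p_exact / sum(p_exact)
names(p_exact) <- paste0("n=", n_grid)

print(round(p_exact, 3))
```

The result should land close to the `mean` column for p_n_X[1] to p_n_X[4] above, with \(n = 5\) the most probable value.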

In the above, the data were created in an arbitrary fashion. The model can be revised slightly to accept any set of candidate values for \(n\), and then we can simulate data more formally and see what we recover.



data {
  // num expt
  int<lower=0> K;
  array[K] int X;
  int<lower=0> P;
  array[P] int n;
}
transformed data{
}
parameters {
  real<lower=0, upper=1> theta;
}
transformed parameters{
  // unnormalised density
  vector[P] lq;
  for(i in 1:P){
    // record the unnormalised density for every possible value 
    // of n
    
    // log pmf for the array of values in X conditional on a given n[i]
    // and the parameter theta PLUS the prior on n 
    lq[i] = binomial_lpmf(X | n[i], theta) + log(1.0 / P);
  }
}
model {
  target += uniform_lpdf(theta | 0, 1);
  
  // marginalise out the troublesome n
  target += log_sum_exp(lq);
  
}
generated quantities{
  // probability of n given X, i.e. the distribution of n | x
  vector[P] p_n_X;
  real mu_n;
  p_n_X = exp(lq - log_sum_exp(lq));
  
  {
    vector[P] tmp;
    for(i in 1:P){
      tmp[i] =  p_n_X[i] * n[i];
    }
    mu_n = sum(tmp);
  }
  
  
}
m2 <- cmdstanr::cmdstan_model("stan/discrete-param-2.stan")


set.seed(1)
theta <- 0.1
n <- 7
# 1000 experiments!!!
K <- 1e3
# K experiments, each with n = 7 and theta = 0.1, produce
# the data for X
X <- rbinom(K, n, theta)
# range(X)

ld <- list(
  K = K, X = X, P = 5, n = c(5, 6, 7, 8, 9)
)

f2 <- m2$sample(
    ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = FALSE,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 2 finished in 0.9 seconds.
Chain 1 finished in 1.0 seconds.

Both chains finished successfully.
Mean chain execution time: 1.0 seconds.
Total execution time: 1.1 seconds.
# p_n_X is giving you the probability of n given 
# the observations contained in our data X
f2$summary()
# A tibble: 13 × 10
   variable       mean    median      sd      mad        q5        q95  rhat
   <chr>         <dbl>     <dbl>   <dbl>    <dbl>     <dbl>      <dbl> <dbl>
 1 lp__     -1089.     -1.09e+ 3  0.721  4.74e- 1 -1.09e+ 3 -1089.      1.03
 2 theta        0.0916  8.74e- 2  0.0160 1.44e- 2  7.37e- 2     0.123   1.06
 3 lq[1]    -1154.     -1.16e+ 3 36.7    4.62e+ 1 -1.21e+ 3 -1093.      1.06
 4 lq[2]    -1115.     -1.11e+ 3 20.9    2.67e+ 1 -1.15e+ 3 -1087.      1.05
 5 lq[3]    -1099.     -1.09e+ 3 12.5    1.12e+ 1 -1.12e+ 3 -1086.      1.02
 6 lq[4]    -1098.     -1.09e+ 3 22.5    5.95e+ 0 -1.14e+ 3 -1086.      1.03
 7 lq[5]    -1110.     -1.09e+ 3 38.6    1.04e+ 1 -1.19e+ 3 -1086.      1.06
 8 p_n_X[1]     0.0456  3.17e-31  0.206  4.71e-31  5.59e-54     0.0206  1.06
 9 p_n_X[2]     0.115   3.68e-12  0.307  5.45e-12  1.76e-28     1.00    1.05
10 p_n_X[3]     0.229   2.16e- 4  0.394  3.20e- 4  2.74e-15     0.999   1.01
11 p_n_X[4]     0.310   1.29e- 2  0.421  1.92e- 2  4.46e-24     0.994   1.03
12 p_n_X[5]     0.301   1.04e- 3  0.433  1.54e- 3  1.92e-44     1.00    1.06
13 mu_n         7.71    8.00e+ 0  1.13   1.46e+ 0  5.98e+ 0     9.00    1.06
# ℹ 2 more variables: ess_bulk <dbl>, ess_tail <dbl>
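The mu_n row is the probability-weighted average of the candidate values of \(n\), computed within each draw. Because that average is linear in p_n_X, its posterior mean can be reproduced from the rounded p_n_X means in the table above, as this small sketch shows:

```r
n_grid <- 5:9

# posterior means of p_n_X[1..5], copied from the summary table (rounded)
p_mean <- c(0.0456, 0.115, 0.229, 0.310, 0.301)

# expected n: sum over candidates of n * Pr(n | X)
mu_n <- sum(p_mean * n_grid)

print(mu_n)  # about 7.71, matching the mean reported for mu_n
```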