User-defined Probability Distributions in Stan

Categories: stan, bayes

Author: maj

Published: September 25, 2024

Modified: September 30, 2024

Overview

Some of this material can be found in the Stan User's Guide; this post is solely a reference in my own words.

To implement a custom distribution, all you need is a function that increments the total log-probability (Stan's *target*) appropriately.
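A minimal R sketch of why a single log-density function is enough (this is R, not Stan, and a standard normal stands in for the custom density purely for illustration): for independent observations, the log of the joint density equals the sum of the individual log-densities, so adding one term per observation to an accumulator, as *target +=* does, yields the same total.

```r
# Illustration only: a standard normal stands in for an arbitrary density.
set.seed(1)
y <- rnorm(5)

# Stan-style accumulation: start at zero, add one log-density term per observation
target <- 0
for (i in seq_along(y)) {
  target <- target + dnorm(y[i], log = TRUE)
}

# The same quantity computed as the log of the joint (product) density
log_joint <- log(prod(dnorm(y)))

all.equal(target, log_joint)
#> [1] TRUE
```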

Note

When a function whose name ends in *_lpdf* or *_lpmf* is defined, the Stan compiler automatically makes the corresponding *_lupdf* or *_lupmf* version available. Only normalised custom distributions are permitted.

Assume that we want to create a custom distribution with density:

\[ \begin{aligned} f(x) &= (1-a) x^{-a} \end{aligned} \]

defined for \(a \in [0,1)\) and \(x \in [0,1]\), with cdf:

\[ \begin{aligned} F(x) &= x^{1-a} \end{aligned} \]
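Integrating the density shows where the cdf comes from and gives the inverse needed for simulation. The integral only converges at zero when \(a < 1\), which is why \(a = 1\) is excluded:

\[ \begin{aligned} F(x) &= \int_0^x (1-a)\, t^{-a}\, dt = \Big[ t^{1-a} \Big]_0^x = x^{1-a}, \qquad F(1) = 1, \\ u = F(x) = x^{1-a} \quad &\Longrightarrow \quad F^{-1}(u) = u^{1/(1-a)}. \end{aligned} \]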

We can generate draws from this distribution using the inverse cdf method:

library(data.table)
library(ggplot2)
library(cmdstanr)
This is cmdstanr version 0.8.1
- CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
- CmdStan path: /Users/mark/.cmdstan/cmdstan-2.35.0
- CmdStan version: 2.35.0
f_x <- function(x, a){
  # density: f(x) = (1 - a) x^(-a)
  if(a < 0 || a >= 1) stop("only defined for a in [0,1)")
  if(any(x < 0 | x > 1)) stop("only defined for x in [0,1]")
  (1-a) * x ^ -a
}
F_x <- function(x, a){
  # cdf: F(x) = x^(1 - a)
  if(a < 0 || a >= 1) stop("only defined for a in [0,1)")
  if(any(x < 0 | x > 1)) stop("only defined for x in [0,1]")
  x^(1-a)
}
F_inv_x <- function(u, a){
  # inverse cdf: x = u^(1 / (1 - a)), maps uniform draws to draws of x
  if(a < 0 || a >= 1) stop("only defined for a in [0,1)")
  if(any(u < 0 | u > 1)) stop("only defined for u in [0,1]")
  u ^ (1 / (1-a))
}

a <- 0.35
x <- seq(0, 1, len = 1000)
d_fig <- data.table(x = x, y = f_x(x, a))
d_sim <- data.table(
  y_sim = F_inv_x(runif(1e6), a)
)

# histogram of the inverse-cdf draws overlaid with the analytic density
ggplot(d_fig, aes(x = x, y = y)) +
  geom_histogram(data = d_sim, aes(x = y_sim, y = after_stat(density)),
               inherit.aes = FALSE, fill = 1, alpha = 0.2,
               binwidth = density(d_sim$y_sim)$bw) + 
  geom_line() +
  theme_bw()
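Beyond the visual check, a few numerical checks are cheap. The sketch below is self-contained, so it redefines the density, cdf and inverse cdf under hypothetical local names (`dens`, `cdf`, `qinv`) rather than reusing the functions above. It confirms that the density integrates to one, that the cdf matches the integral of the density, and that inverse-cdf draws have an empirical cdf close to \(F\) at a few arbitrary points.

```r
# Self-contained check of the inverse-cdf sampler; names are local to this sketch.
set.seed(2024)
a <- 0.35
dens <- function(x, a) (1 - a) * x^(-a)   # f(x)
cdf  <- function(x, a) x^(1 - a)          # F(x)
qinv <- function(u, a) u^(1 / (1 - a))    # F^{-1}(u)

# 1. The density integrates to one over [0, 1] (the singularity at 0 is integrable)
tot <- integrate(dens, lower = 0, upper = 1, a = a)$value

# 2. The cdf equals the integral of the density, checked at x = 0.5
F_half <- integrate(dens, lower = 0, upper = 0.5, a = a)$value

# 3. Draws from F^{-1}(U), U ~ Uniform(0, 1), should follow F
draws <- qinv(runif(1e5), a)
grid <- c(0.1, 0.25, 0.5, 0.75, 0.9)
max_gap <- max(abs(ecdf(draws)(grid) - cdf(grid, a)))

c(total = tot, F_half_numeric = F_half, F_half_exact = cdf(0.5, a), max_gap = max_gap)
```

With 1e5 draws the empirical cdf should sit within a few thousandths of \(F\); a large gap would point to an error in the inverse.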

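functions {
  // log density of f(x) = (1 - alpha) x^(-alpha), summed over the elements of x
  real custom_lpdf(vector x, real alpha) {
    int n_x = num_elements(x);
    vector[n_x] lpdf;
    for(i in 1:n_x){
      // log f(x_i) = log(1 - alpha) - alpha * log(x_i)
      lpdf[i] = log1m(alpha) - alpha * log(x[i]);
    }
    return sum(lpdf);
  }
}
data {
  int N;
  vector[N] y;
}

parameters {
  real<lower=0, upper = 1> a;
}
model {
  // prior, implicitly truncated to [0, 1] by the bounds on a
  target += exponential_lpdf(a | 1);
  // likelihood via the user-defined lpdf (called with the | syntax)
  target += custom_lpdf(y | a);
}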
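One way to sanity check the Stan function before sampling is to mirror it in R and compare against the log of the density evaluated directly. The R function names below are hypothetical and exist only for this check. The same algebra shows that the Stan loop could also be written in vectorised form as `num_elements(x) * log1m(alpha) - alpha * sum(log(x))`.

```r
# R mirror of the Stan custom_lpdf (hypothetical names, for checking only)
custom_lpdf_loop <- function(x, alpha) {
  lpdf <- numeric(length(x))
  for (i in seq_along(x)) {
    # log1p(-alpha) plays the role of Stan's log1m(alpha)
    lpdf[i] <- log1p(-alpha) - alpha * log(x[i])
  }
  sum(lpdf)
}

# Vectorised equivalent: n * log(1 - alpha) - alpha * sum(log(x))
custom_lpdf_vec <- function(x, alpha) {
  length(x) * log1p(-alpha) - alpha * sum(log(x))
}

set.seed(3)
alpha <- 0.35
x_chk <- runif(50)^(1 / (1 - alpha))   # inverse-cdf draws on (0, 1)

# Reference: log of the density evaluated directly, summed over observations
ref <- sum(log((1 - alpha) * x_chk^(-alpha)))
c(loop = custom_lpdf_loop(x_chk, alpha), vec = custom_lpdf_vec(x_chk, alpha), ref = ref)
```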
m1 <- cmdstanr::cmdstan_model("stan/custom-dist-1.stan")

ld <- list(
  N = 1000, 
  y = d_sim$y_sim[1:1000]
)

f1 <- m1$sample(
  ld, iter_warmup = 1000, iter_sampling = 1000,
  parallel_chains = 1, chains = 1, refresh = 0, show_exceptions = FALSE,
  max_treedepth = 10)
Running MCMC with 1 chain...

Chain 1 finished in 0.2 seconds.
f1$summary(variables = c("a"))
# A tibble: 1 × 10
  variable  mean median     sd    mad    q5   q95  rhat ess_bulk ess_tail
  <chr>    <dbl>  <dbl>  <dbl>  <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
1 a        0.363  0.363 0.0202 0.0191 0.328 0.394  1.01     435.     217.
post <- data.table(f1$draws(variables = "a", format = "matrix"))
hist(post$a)
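Because the model has a single parameter, the posterior can also be approximated on a grid in R as an independent check on the sampler. Setting the derivative of the log-likelihood \(n \log(1-a) - a \sum_i \log y_i\) to zero gives the closed-form maximum likelihood estimate \(\hat{a} = 1 + n / \sum_i \log y_i\), and the Fisher information \(1/(1-a)^2\) per observation suggests a posterior sd of roughly \((1-a)/\sqrt{n} \approx 0.02\) for \(n = 1000\), in line with the summary above. The sketch simulates its own data with an arbitrary seed (`y_chk` is a hypothetical name), so its numbers will not reproduce the Stan output exactly; substituting `ld$y` for `y_chk` should give a grid posterior mean and sd that agree with the Stan summary up to Monte Carlo error.

```r
# Grid approximation of p(a | y): exponential(1) prior on a, restricted to (0, 1)
set.seed(4)
a_true <- 0.35
y_chk <- runif(1000)^(1 / (1 - a_true))   # simulated data; swap in ld$y to compare with Stan

a_grid <- seq(0.0005, 0.9995, by = 0.001)  # avoids a = 1, where log(1 - a) = -Inf
log_post <- sapply(a_grid, function(a) {
  dexp(a, rate = 1, log = TRUE) +                    # target += exponential_lpdf(a | 1)
    length(y_chk) * log1p(-a) - a * sum(log(y_chk))  # target += custom_lpdf(y | a)
})

# Normalise on the grid (subtract the max first for numerical stability)
w <- exp(log_post - max(log_post))
w <- w / sum(w)
post_mean <- sum(a_grid * w)
post_sd <- sqrt(sum((a_grid - post_mean)^2 * w))

# Closed-form maximum likelihood estimate: a_hat = 1 + n / sum(log(y))
a_hat <- 1 + length(y_chk) / sum(log(y_chk))

c(post_mean = post_mean, post_sd = post_sd, mle = a_hat)
```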

References