19 Parallel computation
For large data sets, MCMC sampling can become very slow. The computational cost of any Stan program is dominated by the calculation of the gradient of the log likelihood function of the model. Most statistical models involve a summation over many independent contributions to the log likelihood. As the sum is associative (the result is invariant to re-orderings of its individual terms), the log likelihood evaluation can be parallelized over multiple CPU cores.
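As a toy illustration in plain R (not how Stan implements this internally), the per-observation log likelihood terms of a Poisson model can be summed in chunks and the partial sums combined afterwards:

# per-observation log likelihood contributions of a small Poisson sample
ll_terms <- dpois(c(3, 1, 4, 1, 5, 9, 2, 6), lambda = 4, log = TRUE)

# full sum evaluated in one go
sum(ll_terms)

# essentially the same result (up to floating point rounding) when evaluated
# as two partial sums, which is what allows spreading the work over CPU cores
sum(ll_terms[1:4]) + sum(ll_terms[5:8])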
19.1 Within-chain parallelization with threading
brms supports the within-chain parallelization feature of Stan in a way that is fully automatic for the user. However, threading support must be requested when calling the brm command, since the Stan model is then written in a different manner, using the reduce_sum feature of Stan. The reduce_sum feature is part of the Stan modeling language since version 2.23.0. In these vignettes we use a cmdstan installation with a sufficiently recent Stan version, 2.32.2. Thus, to take advantage of within-chain parallelization, the user only needs to slightly change the call to brm like:
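fit <- brm(
  bf(
    AVAL ~ TRT01P + gp(AVISITN, by = TRT01P),
    autocor = ~ cosy(time = AVISIT, gr = SUBJID)
  ),
  data = analysis_data2,
  silent = 2,
  refresh = 0,
  seed = 45646,
  chains = 4,
  cores = 4,
  backend = "cmdstanr", # request cmdstanr as backend
  threads = threading(4) # request 4 threads per chain
)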
The above call requests 4 chains to run in parallel (cores = 4) and allocates 4 threads per chain, so that in total 16 CPU cores are used at the same time. It can be somewhat confusing to users that cores is set to 4 while 16 physical cores are in use. This is due to historical reasons: cores refers to the number of chains running at the same time, and threading was only added to Stan at a later stage.
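Before launching such a run, it can be useful to check that the machine actually offers chains times threads cores; a small sanity check using parallelly (which is also used below for the mirai setup) could look like:

chains_requested  <- 4
threads_per_chain <- 4

# total number of CPU cores used simultaneously by the above call
chains_requested * threads_per_chain
#> [1] 16

# number of CPU cores available on the current machine
parallelly::availableCores()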
19.2 Running on multiple machines
Within-chain parallelization is only useful if sufficient cores have been allocated. For example, on a machine with 4 cores the potential benefits are limited. In such a setting, one may consider running 2 chains with 2 threads per chain, which is a compromise given that we usually wish to run 4 chains.
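As a sketch of such a compromise, the call from above can be changed as follows (only the resource-related arguments differ; the object name fit_small is arbitrary):

# compromise on a 4-core machine: 2 chains in parallel with 2 threads each
fit_small <- brm(
  bf(
    AVAL ~ TRT01P + gp(AVISITN, by = TRT01P),
    autocor = ~ cosy(time = AVISIT, gr = SUBJID)
  ),
  data = analysis_data2,
  silent = 2,
  refresh = 0,
  seed = 45646,
  chains = 2,
  cores = 2,             # 2 chains run at the same time
  backend = "cmdstanr",
  threads = threading(2) # 2 threads per chain, 4 cores in total
)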
If a computer cluster or other distributed computing environment is available, within-chain parallelization may be advantageous, but users must take care to request an appropriate amount of resources. In particular, the job submitted to the compute cluster must allocate a sufficient number of CPU cores.
A modern way in R to exploit parallelism is the mirai framework, which can use local and remote resources. While brms has not been written with mirai in mind, we can still adapt brms to use mirai backends through the use of future and future.mirai as:
# loading future makes brms use it for parallelization
library(future)

# first instantiate mirai daemons; here we use all available cores on
# the local machine
mirai::daemons(parallelly::availableCores())

# then register the mirai cluster as future backend
plan(future.mirai::mirai_cluster)

# point brms by default to use the future backend
options(future = TRUE)

source(here::here("src", "simulate_fake_data.R"))
fake_data <- simulate_fake_data()

brm(
  y ~ 1 + x1 + x2 + (1 | g),
  data = fake_data,
  family = poisson(),
  iter = 2000,
  warmup = 1000,
  prior = prior(normal(0, 1), class = b) +
    prior(constant(1), class = sd, group = g),
  backend = "cmdstanr",
  # request 4 threads per chain
  threads = threading(4),
  chains = 4,
  seed = 345345,
  refresh = 0,
  control = list(adapt_delta = 0.95)
)

# terminate mirai daemons (or keep them running if used later again)
mirai::daemons(0)
In the above example, a mirai cluster is used with locally started workers, requesting 4 parallel chains with 4 threads each. To run, for example, on an IBM LSF managed cluster, one can replace the first call to mirai::daemons above with:
# Remote worker setup for an LSF cluster which uses "modules" to
# load specific R versions, common on HPCE systems. Requesting 3000 MB
# of RAM (-M 3000) and 4 cores per job (-n 4) which all run on the
# very same machine (-R "span[hosts=1]").
lsf_config <- mirai::cluster_config(
  command = "bsub",
  options = '#BSUB -J mirai
#BSUB -M 3000
#BSUB -n 4
#BSUB -R "span[hosts=1]"
#BSUB -o job.out
module load R/4.5.0',
  rscript = "Rscript"
)

mirai::daemons(n = 4, url = mirai::host_url(), remote = lsf_config)
For more options on how to spawn remote workers, please refer to the mirai documentation, in particular the section on remote daemons.
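As one further illustration, daemons can also be launched over SSH; the sketch below assumes a host named compute-node-01 that is reachable via SSH and has R with mirai installed:

# launch 4 daemons on a remote machine via SSH (the hostname is a placeholder)
mirai::daemons(
  n = 4,
  url = mirai::host_url(),
  remote = mirai::ssh_config(remotes = "ssh://compute-node-01")
)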
As HPCEs may not always have the latest R software installed, we provide here an alternative way based on the well-established R package clustermq. This R package is tailored to interact with the queueing system of a cluster, and the utility function cmq_brm provided below facilitates running the posterior sampling on the cluster via clustermq. The idea is to first define the brms model in the current session without performing any actual sampling, by setting the argument chains = 0. The model defined in this way is then handed to the update_model function, which uses clustermq to run the sampling on the registered clustermq backend. The clustermq package will then sample either locally or on the cluster:
# Fit model as usual, just use "cmq_brm" instead of the brm call; see
# definition at the bottom
model_poisson <- cmq_brm(
  y ~ 1 + x1 + x2 + (1 | g),
  data = fake_data,
  family = poisson(),
  iter = 2000,
  warmup = 1000,
  prior = prior(normal(0, 1), class = b) +
    prior(constant(1), class = sd, group = g),
  # use the cmdstanr backend
  backend = "cmdstanr",
  # request 4 threads per chain
  threads = threading(4),
  seed = 345345,
  control = list(adapt_delta = 0.95)
)
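The scheduler used by clustermq is selected through global options before calling cmq_brm; a minimal sketch is shown below (the LSF template file name is a placeholder, see the clustermq documentation for scheduler templates):

# run each chain as a local process (the default assumed by cmq_brm)
options(clustermq.scheduler = "multiprocess")

# or submit each chain as a job to an LSF queue via a scheduler template
# (the template file name is a placeholder)
# options(
#   clustermq.scheduler = "lsf",
#   clustermq.template = "lsf_clustermq.tmpl"
# )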
It is important to note that between-chain parallelization is always more efficient than within-chain parallelization. Moreover, with an increasing number of cores per chain, the efficiency of the parallelization decreases. For a detailed discussion and instructions on how to tune the number of cores for a given problem, please refer to the threading vignette of brms available on CRAN.
19.3 Exact reproducibility
In the default setting of threading, exact reproducibility is not maintained! This stems from the fact that the large log likelihood sums are partitioned into blocks of varying size, and these partial sums are moreover accumulated in a random order. Due to the limited floating point precision of CPU cores, this leads to slightly varying results at the order of the machine precision (commonly \(10^{-16}\)). In case exactly reproducible results are required, Stan offers a static version of reduce_sum. This variant can be requested by using the static = TRUE argument of the threading function. When doing so, it is recommended to also provide a so-called grainsize, the number of terms included in every partial sum. The grainsize should be large enough that each partial sum represents a considerable amount of work, while being small enough to allow for good load balancing between the CPUs. The threading vignette linked above discusses strategies on how to define a reasonable grainsize. A call with static sum partitioning may then look like threading(2, 80, TRUE), which requests 2 threads per chain, a grainsize of \(80\), and static sum partitioning.
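As a sketch, applying the reproducible variant to the fake-data model from above could look as follows (the object name fit_repro is arbitrary, and the grainsize of 80 is purely illustrative and should be tuned as described in the threading vignette):

# exactly reproducible within-chain parallelization:
# 2 threads per chain, grainsize of 80 terms, static sum partitioning
fit_repro <- brm(
  y ~ 1 + x1 + x2 + (1 | g),
  data = fake_data,
  family = poisson(),
  prior = prior(normal(0, 1), class = b) +
    prior(constant(1), class = sd, group = g),
  backend = "cmdstanr",
  threads = threading(threads = 2, grainsize = 80, static = TRUE),
  chains = 4,
  seed = 345345,
  refresh = 0
)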
19.4 Implementation of cmq_brm
# executes brm with parallelization via clustermq
cmq_brm <- function(
  ...,
  seed,
  control = list(adapt_delta = 0.9),
  cores,
  file,
  chains = 4,
  .log_worker = FALSE
) {
  checkmate::assert_integer(
    as.integer(seed),
    lower = 1,
    any.missing = FALSE,
    len = 1
  )
  brms_global <- options()[grep("^brms", names(options()), value = TRUE)]
  cmdstanr_global <- options()[grep(
    "^cmdstanr",
    names(options()),
    value = TRUE
  )]
  dots <- rlang::enquos(...)
  brms_args <- lapply(dots, rlang::eval_tidy)
  suppressWarnings(
    model <- do.call(brm, modifyList(brms_args, list(chains = 0, cores = 1)))
  )
  update_args <- list(
    object = model,
    seed = seed,
    cores = 1,
    chains = 1,
    control = control
  )
  if (!missing(file)) {
    update_args$file <- sub("\\.rds$", "", file)
  }
  if (!missing(file)) {
    brms_args$file <- sub("\\.rds$", "", file)
  }
  master_lib_paths <- .libPaths()
  update_model <- function(chain_id) {
    .libPaths(master_lib_paths)
    library(brms)
    # in case file is part of extra-arguments, we add here the chain_id
    # to get correct by-chain file caching
    if ("file" %in% names(update_args)) {
      update_args <- modifyList(
        update_args,
        list(file = paste0(update_args$file, "-", chain_id))
      )
    }
    if ("file" %in% names(brms_args)) {
      brms_args <- modifyList(
        brms_args,
        list(file = paste0(brms_args$file, "-", chain_id))
      )
    }
    update_args$chain_id <- chain_id
    brms_args$chain_id <- chain_id
    # ensure the same brms & cmdstanr global options are set
    options(brms_global)
    options(cmdstanr_global)
    ## for the rstan backend we do an update while for cmdstanr we
    ## have to avoid this for whatever reason
    if (model$backend == "cmdstanr") {
      msg <- capture.output(
        fit <- do.call(brm, modifyList(brms_args, list(chains = 1)))
      )
    } else {
      msg <- capture.output(fit <- do.call(update, update_args))
    }
    list(fit = fit, msg = msg)
  }
  n_jobs <- chains
  backend <- getOption("clustermq.scheduler", "multiprocess")
  if (backend %in% c("multiprocess", "multicore")) {
    n_jobs <- min(chains, getOption("mc.cores", 1))
  }
  cores_per_chain <- 1
  if (!is.null(model$threads$threads)) {
    cores_per_chain <- model$threads$threads
  }
  if (chains == 1 & cores_per_chain == 1) {
    ## looks like a debugging run...avoid clustermq
    return(update_model(1)$fit)
  }
  message(
    "Starting ", chains,
    " chains with a concurrency of ", n_jobs,
    " and using ", cores_per_chain,
    " cores per chain with backend ", backend,
    "...\n"
  )
  cluster_update <- clustermq::Q(
    update_model,
    chain_id = 1:chains,
    n_jobs = n_jobs,
    export = list(
      update_args = update_args,
      brms_args = brms_args,
      brms_global = brms_global,
      cmdstanr_global = cmdstanr_global,
      master_lib_paths = master_lib_paths,
      model = model
    ),
    template = list(cores = cores_per_chain),
    log_worker = .log_worker
  )
  fit <- combine_models(mlist = lapply(cluster_update, "[[", "fit"))
  msg <- lapply(cluster_update, "[[", "msg")
  for (i in seq_len(length(msg))) {
    if (length(msg[[i]]) == 0) {
      next
    }
    cat(paste0("Output for chain ", i, ":\n"))
    cat(paste(msg[[i]], collapse = "\n"), "\n")
  }
  fit$file <- NULL
  fit
}
20 Example fake data
simulate_fake_data <- function(N = 1E4, G = 1E3, P = 3, seed = 46765875) {
  withr::local_seed(seed)

  # regression coefficients
  beta <- rnorm(P)

  # sampled covariates, group means and fake data
  fake <- matrix(rnorm(N * P), ncol = P)
  dimnames(fake) <- list(NULL, paste0("x", 1:P))

  # fixed effect part and sampled group membership
  fake <- transform(
    as.data.frame(fake),
    theta = fake %*% beta,
    g = sample.int(G, N, replace = TRUE)
  )

  # add random intercept by group
  fake <- merge(fake, data.frame(g = 1:G, eta = rnorm(G)), by = "g")

  # linear predictor
  fake <- transform(fake, mu = theta + eta)

  # sample Poisson data
  fake <- transform(fake, y = rpois(N, exp(mu)))

  # shuffle order of data rows to ensure even distribution of computational effort
  fake <- fake[sample.int(N, N), ]

  # drop not needed row names
  rownames(fake) <- NULL

  return(fake)
}
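A quick usage check of the simulated data (output not shown) could be:

# simulate the default 10,000 observations across 1,000 groups and inspect
fake_data <- simulate_fake_data()
str(fake_data)
head(fake_data)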