# Example: fit a brms model with within-chain parallelization enabled.
# Threading requires the cmdstanr backend and an explicit `threads` request.
fit <- brm(
  bf(
    AVAL ~ TRT01P + gp(AVISITN, by = TRT01P),
    autocor = ~ cosy(time = AVISIT, gr = SUBJID)
  ),
  data = analysis_data2,
  silent = 2,
  refresh = 0,
  seed = 45646,
  chains = 4,
  cores = 4,
  backend = "cmdstanr", # request cmdstanr as backend
  threads = threading(4) # request 4 threads per chain
)
17 Parallel computation
For large data-sets MCMC sampling can become very slow. The computational cost of any Stan program is dominated by the calculation of the gradient of the log likelihood function of the model. Most statistical models involve the summation over many independent contributions to the log likelihood. As the sum is associative - the result is invariant to re-orderings of individual terms of the sum - the log likelihood evaluation can be parallelized over multiple CPU cores.
brms
supports the within-chain parallelization feature of Stan in a way which is fully automatic for the user. However, it does require requesting threading support when calling the brm
command, since then the Stan model is written in a different manner using the reduce_sum
feature in Stan. Furthermore, the reduce_sum
feature is part of the Stan modeling language since version 2.23.0. In these vignettes, we use a cmdstan
installation with a sufficiently high Stan version of 2.32.2.
Thus, to take advantage of within-chain parallelization, a user only needs to slightly change the call to brm
like:
The above call will request to run 4 chains in parallel (cores=4
) and will allocate 4 threads per chain such that in total 16 cores will be used at the same time. It can be somewhat confusing to users that cores
is set to 4 and we still use 16 physical cores. This is due to historical reasons as the meaning of cores
refers to parallel chains running at the same time and threading was added at a later stage to Stan.
Within-chain parallelization is only useful if sufficient cores have been allocated. For example, on a machine with 4 cores, the potential benefits are limited. In such a setting, one may consider running 2 chains and using 2 threads per chain. This is a compromise given that we usually wish to run 4 chains.
If a computer cluster, or other distributed computing environment is available, within-chain parallelization may be advantageous, but users must take care to request an appropriate number of resources. One must submit a job to the main cluster which allocates a sufficient amount of CPU cores. The utility function cmq_brm
below facilitates running the posterior sampling on the cluster. The idea is to define first the brms
model in the current session, but not to perform any actual sampling in the current process by setting the argument chains=0
. The so defined model is given to the brm_update_cluster
function which then uses clustermq
to run the sampling on the registered backend of clustermq
. The clustermq
package then will sample either locally or on the cluster:
# Fit model as usual, just use "cmq_brm" instead of brm call, see
# definition at the bottom
model_poisson <- cmq_brm(
  y ~ 1 + x1 + x2 + (1 | g),
  data = fake,
  family = poisson(),
  iter = 2000,
  warmup = 1000,
  prior = prior(normal(0,1), class = b) +
    prior(constant(1), class = sd, group = g),
  # use the cmdstanr backend
  backend = "cmdstanr",
  # request 4 threads per chain
  threads = threading(4),
  seed = 345345,
  control = list(adapt_delta = 0.95)
)
It is important to note that between-chain parallelization is always more efficient than within-chain parallelization. Moreover, with increasing number of cores per chain, the efficiency of the parallelisation decreases. For a detailed discussion and instructions on how to tune the number of cores for a given problem, please refer to the vignette from brms
on threading available on CRAN.
Finally, an important remark on exact reproducibility of the results: In the default setting of threading exact reproducibility is not maintained! This stems from the fact that the large log likelihood sums are partitioned into blocks of varying size and these sums are moreover accumulated in a random order. Due to the limited floating point precision of CPU cores this leads to slightly varying results at the order of the machine precision (\(10^{-16}\) commonly). In case exactly reproducible results are required, Stan offers a static version of reduce_sum
. This variant can be requested by using the static=TRUE
argument of the threading function. However, when doing so, it is recommended to also provide a so-called grainsize
. The grainsize
is the number of terms which are included in every partial sum. This size should be large enough such that each partial sum formed represents a considerable amount of work, while being small enough to allow for good load balancing between the CPUs. The threading vignette linked above discusses strategies on how to define a reasonable grainsize
. The call with static
sum partitioning may then look like threading(2, 80, TRUE)
, which will request 2 threads per chain, a grainsize of \(80\) and static
sum partitioning.
17.1 Implementation of cmq_brm
# Executes brm with parallelization via clustermq: every chain is sampled as
# its own clustermq job and the per-chain fits are merged at the end.
#
# Arguments:
#   ...          Arguments forwarded to brm (formula, data, family, prior,
#                backend, threads, ...). They are captured as quosures and
#                evaluated once up front.
#   seed         Integer seed, validated to be a single value >= 1.
#   control      Sampler control list handed to the per-chain fit.
#   cores        Accepted so that brm-style calls keep working; it is not
#                used, each job always runs a single chain.
#   file         Optional file name for brms file caching. A trailing ".rds"
#                is stripped and "-<chain_id>" is appended per chain.
#   chains       Number of chains, i.e. number of jobs passed to clustermq.
#   .log_worker  Forwarded to clustermq::Q as `log_worker`.
#
# Returns the result of combine_models() over all per-chain fits with its
# `file` element removed. If only one chain with one core per chain is
# requested, the single fit is returned directly without using clustermq.
cmq_brm <- function(..., seed, control=list(adapt_delta=0.9), cores, file, chains=4, .log_worker=FALSE) {
    checkmate::assert_integer(as.integer(seed), lower=1, any.missing=FALSE, len=1)

    # snapshot the brms / cmdstanr options of this session so that they can
    # be re-applied inside each job
    brms_global <- options()[grep("^brms", names(options()), value=TRUE)]
    cmdstanr_global <- options()[grep("^cmdstanr", names(options()), value=TRUE)]

    dots <- rlang::enquos(...)
    brms_args <- lapply(dots, rlang::eval_tidy)

    # define the model without sampling (chains=0)
    suppressWarnings(model <- do.call(brm, modifyList(brms_args, list(chains=0, cores=1))))

    update_args <- list(object=model, seed=seed, cores=1, chains=1, control=control)
    if(!missing(file)) {
        file_stem <- sub("\\.rds$", "", file)
        update_args$file <- file_stem
        brms_args$file <- file_stem
    }

    master_lib_paths <- .libPaths()

    # Samples a single chain; this is the function run by each clustermq job.
    update_model <- function(chain_id) {
        .libPaths(master_lib_paths)
        library(brms)
        # in case file is part of extra-arguments, we add here the chain_id
        # to get correct by-chain file caching
        if("file" %in% names(update_args)) {
            update_args <- modifyList(update_args, list(file=paste0(update_args$file, "-", chain_id)))
        }
        if("file" %in% names(brms_args)) {
            brms_args <- modifyList(brms_args, list(file=paste0(brms_args$file, "-", chain_id)))
        }
        update_args$chain_id <- chain_id
        brms_args$chain_id <- chain_id
        # ensure the same brms & cmdstanr global options are set
        options(brms_global)
        options(cmdstanr_global)
        ## for the rstan backend we do an update while for cmdstanr we
        ## have to avoid this for whatever reason
        if(model$backend == "cmdstanr") {
            # seed and control are not part of `...`, so they must be added
            # explicitly here; otherwise they would be silently ignored for
            # the cmdstanr backend. They are read from update_args because
            # that list is what gets exported to the jobs.
            chain_args <- modifyList(brms_args,
                                     list(chains=1,
                                          seed=update_args$seed,
                                          control=update_args$control))
            msg <- capture.output(fit <- do.call(brm, chain_args))
        } else {
            msg <- capture.output(fit <- do.call(update, update_args))
        }
        list(fit=fit, msg=msg)
    }

    # concurrency: one job per chain, capped by mc.cores for local schedulers
    n_jobs <- chains
    backend <- getOption("clustermq.scheduler", "multiprocess")
    if(backend %in% c("multiprocess", "multicore")) {
        n_jobs <- min(chains, getOption("mc.cores", 1))
    }

    # cores requested per job follow the threads requested for the model
    cores_per_chain <- 1
    if(!is.null(model$threads$threads)) {
        cores_per_chain <- model$threads$threads
    }

    if(chains == 1 && cores_per_chain == 1) {
        ## looks like a debugging run...avoid clustermq
        return(update_model(1)$fit)
    }

    message("Starting ", chains, " chains with a concurrency of ", n_jobs, " and using ", cores_per_chain, " cores per chain with backend ", backend, "...\n")

    cluster_update <- clustermq::Q(update_model,
                                   chain_id=seq_len(chains),
                                   n_jobs=n_jobs,
                                   export=list(update_args=update_args,
                                               brms_args=brms_args,
                                               brms_global=brms_global,
                                               cmdstanr_global=cmdstanr_global,
                                               master_lib_paths=master_lib_paths,
                                               model=model),
                                   template=list(cores=cores_per_chain),
                                   log_worker=.log_worker)

    fit <- combine_models(mlist=lapply(cluster_update, "[[", "fit"))
    msg <- lapply(cluster_update, "[[", "msg")

    # replay the captured console output of each chain, skipping silent ones
    for(i in seq_along(msg)) {
        if(length(msg[[i]]) == 0)
            next
        cat(paste0("Output for chain ", i, ":\n"))
        cat(paste(msg[[i]], collapse="\n"), "\n")
    }

    # the combined fit does not correspond to any single cached file
    fit$file <- NULL
    fit
}