17  Parallel computation

Author

Sebastian Weber

For large data sets, MCMC sampling can become very slow. The computational cost of a Stan program is dominated by the calculation of the gradient of the log likelihood function of the model. Most statistical models sum over many independent contributions to the log likelihood. Because addition is associative and commutative (the result does not depend on how the individual terms are grouped or ordered), the log likelihood evaluation can be split into partial sums that are computed in parallel on multiple CPU cores.
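The idea of splitting the log likelihood into partial sums can be illustrated with a minimal base R sketch. The Poisson data and the four equally sized blocks are assumptions made purely for illustration; Stan's reduce_sum chooses its own partitioning and evaluates the blocks on separate threads, whereas here the blocks are simply summed one after another.

```r
# Simulated data and per-observation log likelihood terms (illustrative only)
set.seed(123)
y <- rpois(1000, lambda = 3)
ll_terms <- dpois(y, lambda = 3, log = TRUE)

# Reference: one sum over all terms
full <- sum(ll_terms)

# Partition the terms into 4 blocks, sum each block, then combine the
# partial sums; each block could in principle be handled by its own core
block_id <- rep(1:4, each = 250)
partial <- tapply(ll_terms, block_id, sum)

# The two results agree up to floating point tolerance
all.equal(sum(partial), full)
```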

brms supports the within-chain parallelization feature of Stan in a way that is fully automatic for the user. However, threading support must be requested explicitly when calling brm, because the Stan model is then written differently, using the reduce_sum feature of Stan. The reduce_sum feature has been part of the Stan modeling language since version 2.23.0. In these vignettes, we use a cmdstan installation with a sufficiently recent Stan version, 2.32.2.

Thus, to take advantage of within-chain parallelization, a user only needs to change the call to brm slightly, for example:

fit <- brm(
    bf(
        AVAL ~ TRT01P + gp(AVISITN, by = TRT01P),
        autocor = ~ cosy(time = AVISIT, gr = SUBJID)
    ),
    data = analysis_data2,
    silent = 2,
    refresh = 0,
    seed = 45646,
    chains = 4,
    cores = 4,
    backend = "cmdstanr",  # request cmdstanr as backend
    threads = threading(4) # request 4 threads per chain
)

The above call requests 4 chains to run in parallel (cores=4) and allocates 4 threads per chain, so that in total 16 cores are used at the same time. It can be somewhat confusing that cores is set to 4 while 16 physical cores are used. The reason is historical: cores refers to the number of chains running at the same time, and threading was added to Stan at a later stage.

Within-chain parallelization is only useful if sufficient cores are available. For example, on a machine with 4 cores, the potential benefits are limited. In such a setting, one may consider running 2 chains with 2 threads per chain. This is a compromise, given that we usually wish to run 4 chains.
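The bookkeeping behind such a choice is simple arithmetic: the number of cores in use is the number of chains running in parallel times the number of threads per chain. The two helper functions below are hypothetical and not part of brms; they only sketch how one might check a planned configuration against the cores of a machine.

```r
# Hypothetical helpers (not part of brms), for illustration only
cores_needed <- function(parallel_chains, threads_per_chain) {
  parallel_chains * threads_per_chain
}

fits_machine <- function(parallel_chains, threads_per_chain, available_cores) {
  cores_needed(parallel_chains, threads_per_chain) <= available_cores
}

# 2 chains with 2 threads each fill a 4 core machine exactly
cores_needed(2, 2)
fits_machine(2, 2, available_cores = 4)

# 4 chains with 4 threads each would need 16 cores
cores_needed(4, 4)
fits_machine(4, 4, available_cores = 4)
```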

If a computer cluster or another distributed computing environment is available, within-chain parallelization may be advantageous, but users must take care to request an appropriate amount of resources. One must submit a job to the cluster which allocates a sufficient number of CPU cores. The utility function cmq_brm below facilitates running the posterior sampling on the cluster. The idea is to first define the brms model in the current session without performing any actual sampling in the current process, which is achieved by setting the argument chains=0. The model defined in this way is then handed to clustermq, which runs the sampling on its registered backend. The clustermq package then will sample either locally or on the cluster:

# Fit model as usual, just use "cmq_brm" instead of brm call, see
# definition at the bottom
model_poisson <- cmq_brm(y ~ 1 + x1 + x2 + (1 | g),
                         data = fake, 
                         family = poisson(),
                         iter = 2000,
                         warmup = 1000,
                         prior = prior(normal(0,1), class = b) +
                             prior(constant(1), class = sd, group = g),
                         # use the cmdstanr backend
                         backend = "cmdstanr",
                         # request 4 threads per chain
                         threads = threading(4),
                         seed=345345,
                         control=list(adapt_delta=0.95)
                         )
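Where the chains actually run is controlled by global options that cmq_brm reads, as can be seen in its implementation at the end of this chapter. The sketch below shows a local setup; the values are assumptions for illustration. The commented cluster settings use the clustermq options clustermq.scheduler and clustermq.template, with a placeholder template path that must be replaced by a site-specific file.

```r
# Local execution: chains run as background processes on this machine.
# cmq_brm limits the number of concurrent chains to getOption("mc.cores").
options(clustermq.scheduler = "multiprocess", mc.cores = 4)

# Cluster execution (placeholder values, adapt to the local site):
# options(clustermq.scheduler = "slurm",
#         clustermq.template = "/path/to/template.tmpl")  # hypothetical path

# Concurrency as computed inside cmq_brm for a local scheduler
chains <- 4
n_jobs <- min(chains, getOption("mc.cores", 1))
n_jobs
```

Since cmq_brm passes the number of threads per chain to clustermq as the template field cores, a cluster template presumably needs to use that field in order to reserve the matching number of CPU cores per job.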

It is important to note that between-chain parallelization is always more efficient than within-chain parallelization. Moreover, as the number of cores per chain increases, the efficiency of the parallelization decreases. For a detailed discussion and instructions on how to tune the number of cores for a given problem, please refer to the brms vignette on threading, which is available on CRAN.
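One common way to reason about this loss of efficiency is Amdahl's law, which is used here as an illustrative assumption and is not a statement about any specific model in this chapter. If a fraction p of the work per gradient evaluation can be parallelized, the speedup with n threads is at most 1 / ((1 - p) + p / n). The value p = 0.9 below is hypothetical.

```r
# Amdahl-style sketch; `p` is a hypothetical parallelizable fraction,
# not a quantity measured for any model discussed here
amdahl_speedup <- function(threads, p) {
  1 / ((1 - p) + p / threads)
}

p <- 0.9
threads <- c(1, 2, 4, 8, 16)
speedup <- amdahl_speedup(threads, p)

# Efficiency: achieved speedup per thread used
efficiency <- speedup / threads

data.frame(threads,
           speedup = round(speedup, 2),
           efficiency = round(efficiency, 2))
```

Under this simplified model, the efficiency drops with every additional thread, which is in line with giving cores to additional chains first and to threads second.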

Finally, an important remark on exact reproducibility of the results: in the default setting of threading, exact reproducibility is not maintained! The large log likelihood sums are partitioned into blocks of varying size, and the partial sums are moreover accumulated in a random order. Because floating point arithmetic has limited precision, the results vary slightly, on the order of the machine precision (commonly \(10^{-16}\)). If exactly reproducible results are required, Stan offers a static version of reduce_sum, which can be requested with the static=TRUE argument of the threading function. When doing so, it is recommended to also provide a so-called grainsize, which is the number of terms included in every partial sum. The grainsize should be large enough that each partial sum represents a considerable amount of work, while being small enough to allow good load balancing between the CPU cores. The threading vignette mentioned above discusses strategies for choosing a reasonable grainsize. A call with static sum partitioning may then look like threading(2, 80, TRUE), which requests 2 threads per chain, a grainsize of \(80\) and static sum partitioning.
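The effect of the summation order can be reproduced in base R. The sketch below first shows that regrouping three numbers changes the last bits of the result, and then sums a vector in fixed blocks. The function static_block_sum is a hypothetical stand-in written for this illustration; it mimics the fixed partitioning idea only and is not how Stan implements the static variant of reduce_sum.

```r
# Regrouping the same three terms changes the result in the last bits
a <- (0.1 + 0.2) + 0.3
b <- 0.1 + (0.2 + 0.3)
identical(a, b)          # FALSE: the results differ at machine precision
isTRUE(all.equal(a, b))  # TRUE: they agree up to numerical tolerance

# Hypothetical helper: sum `terms` in consecutive blocks of `grainsize`
# elements, always combining the partial sums in the same order
static_block_sum <- function(terms, grainsize) {
  block_id <- ceiling(seq_along(terms) / grainsize)
  partial <- vapply(split(terms, block_id), sum, numeric(1))
  sum(partial)
}

set.seed(1)
terms <- dnorm(rnorm(1000), log = TRUE)

# A fixed partition and a fixed order give bit-identical results on repeat
r1 <- static_block_sum(terms, grainsize = 80)
r2 <- static_block_sum(terms, grainsize = 80)
identical(r1, r2)
```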

17.1 Implementation of cmq_brm

# executes brm with parallelization via clustermq
cmq_brm <- function(..., seed, control=list(adapt_delta=0.9), cores, file, chains=4, .log_worker=FALSE) {
    checkmate::assert_integer(as.integer(seed), lower=1, any.missing=FALSE, len=1)
    brms_global <- options()[grep("^brms", names(options()), value=TRUE)]
    cmdstanr_global <- options()[grep("^cmdstanr", names(options()), value=TRUE)]
    dots <- rlang::enquos(...)
    brms_args <- lapply(dots, rlang::eval_tidy)
    suppressWarnings(model <- do.call(brm, modifyList(brms_args, list(chains=0, cores=1))))
    update_args <- list(object=model, seed=seed, cores=1, chains=1, control=control)
    if(!missing(file)) {
        # strip a trailing ".rds"; the chain id is appended per chain below
        file_stem <- sub("\\.rds$", "", file)
        update_args$file <- file_stem
        brms_args$file <- file_stem
    }
    master_lib_paths <- .libPaths()
    update_model <- function(chain_id) {
        .libPaths(master_lib_paths)
        library(brms)
        # in case file is part of extra-arguments, we add here the chain_id
        # to get correct by-chain file caching
        if("file" %in% names(update_args)) {
            update_args <- modifyList(update_args, list(file=paste0(update_args$file, "-", chain_id)))
        }
        if("file" %in% names(brms_args)) {
            brms_args <- modifyList(brms_args, list(file=paste0(brms_args$file, "-", chain_id)))
        }
        update_args$chain_id <- chain_id
        brms_args$chain_id <- chain_id
        # ensure the same brms & cmdstanr global options are set
        options(brms_global)
        options(cmdstanr_global)
        ## for the rstan backend we do an update while for cmdstanr we
        ## have to avoid this for whatever reason
        if(model$backend == "cmdstanr") {
            msg <- capture.output(fit <- do.call(brm, modifyList(brms_args, list(chains=1))))
        } else {
            msg <- capture.output(fit <- do.call(update, update_args))
        }
        list(fit=fit, msg=msg)
    }
    n_jobs <- chains
    backend <- getOption("clustermq.scheduler", "multiprocess")
    if(backend %in% c("multiprocess", "multicore")) {
        n_jobs <- min(chains, getOption("mc.cores", 1))
    }
    cores_per_chain <- 1
    if(!is.null(model$threads$threads)) {
        cores_per_chain <- model$threads$threads
    }
    if(chains == 1 && cores_per_chain == 1) {
        ## looks like a debugging run...avoid clustermq
        return(update_model(1)$fit)
    }
    message("Starting ", chains, " chains with a concurrency of ", n_jobs, " and using ", cores_per_chain, " cores per chain with backend ", backend, "...\n")
    cluster_update <- clustermq::Q(update_model, chain_id=1:chains, n_jobs=n_jobs, export=list(update_args=update_args, brms_args=brms_args, brms_global=brms_global, cmdstanr_global=cmdstanr_global, master_lib_paths=master_lib_paths, model=model),
                                   template=list(cores=cores_per_chain),
                                   log_worker=.log_worker)
    fit <- combine_models(mlist=lapply(cluster_update, "[[", "fit"))
    msg <- lapply(cluster_update, "[[", "msg")
    for(i in seq_along(msg)) {
        if(length(msg[[i]]) == 0)
            next
        cat(paste0("Output for chain ", i, ":\n"))
        cat(paste(msg[[i]], collapse="\n"), "\n")
    }
    fit$file <- NULL
    fit
}