19  Parallel computation

Author

Sebastian Weber

For large datasets MCMC sampling can become very slow. The computational cost of a Stan program is dominated by the evaluation of the gradient of the log likelihood of the model. Most statistical models involve a summation over many independent contributions to the log likelihood. Since summation is associative and commutative, the result is invariant to reordering and partitioning of the individual terms, so the log likelihood evaluation can be parallelized over multiple CPU cores.
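As a minimal sketch in plain R (the variable names are purely illustrative), block-wise partial sums recover the full sum, which is exactly the property that Stan's reduce_sum exploits:

# split the log likelihood terms into blocks and sum block-wise
log_lik_terms <- dnorm(rnorm(1e4), log = TRUE)
blocks <- split(log_lik_terms, cut(seq_along(log_lik_terms), 4))
partial_sums <- vapply(blocks, sum, numeric(1))
# agreement up to floating point rounding (see also the section on
# exact reproducibility below)
all.equal(sum(partial_sums), sum(log_lik_terms))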

19.1 Within-chain parallelization with threading

brms supports the within-chain parallelization feature of Stan in a way which is fully automatic for the user. However, threading support must be requested when calling the brm command, since the Stan model is then written in a different manner using the reduce_sum facility of Stan. The reduce_sum feature has been part of the Stan modeling language since version 2.23.0. In these vignettes we use a cmdstan installation with Stan version 2.32.2, which is sufficiently recent.

Thus, to take advantage of within-chain parallelization a user only needs to slightly change the call to brm:

fit <- brm(
  bf(
    AVAL ~ TRT01P + gp(AVISITN, by = TRT01P),
    autocor = ~ cosy(time = AVISIT, gr = SUBJID)
  ),
  data = analysis_data2,
  silent = 2,
  refresh = 0,
  seed = 45646,
  chains = 4,
  cores = 4,
  backend = "cmdstanr", # request cmdstanr as backend
  threads = threading(4) # request 4 threads per chain
)

The above call requests 4 chains running in parallel (cores = 4) and allocates 4 threads per chain, such that in total 16 cores are used at the same time. It can be confusing that cores is set to 4 while 16 physical cores are used. The reason is historical: cores refers to the number of chains running at the same time, and threading was only added to Stan at a later stage.
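Before launching such a run it can be useful to check that the machine actually provides this many cores; a small sketch using the parallelly package (also used further below):

# total concurrent core usage is chains times threads per chain
chains <- 4
threads_per_chain <- 4
stopifnot(chains * threads_per_chain <= parallelly::availableCores())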

19.2 Running on multiple machines

Within-chain parallelization is only useful if sufficient cores have been allocated. For example, on a machine with 4 cores the potential benefits are limited. In such a setting one may consider running 2 chains with 2 threads per chain, which is a compromise given that we usually wish to run 4 chains.
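On such a 4-core machine this compromise may look like the following sketch, assuming the same model and data as above:

fit <- brm(
  bf(
    AVAL ~ TRT01P + gp(AVISITN, by = TRT01P),
    autocor = ~ cosy(time = AVISIT, gr = SUBJID)
  ),
  data = analysis_data2,
  backend = "cmdstanr",
  chains = 2, # 2 chains in parallel...
  cores = 2,
  threads = threading(2) # ...with 2 threads each, using 4 cores in total
)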

If a computer cluster or other distributed computing environment is available, within-chain parallelization can be advantageous, but users must take care to request an appropriate amount of resources: the job submitted to the compute cluster must allocate a sufficient number of CPU cores.

A modern way to exploit parallelism in R is the mirai framework, which can use local and remote resources. While brms has not been written with mirai in mind, we can still make brms use mirai backends through the future and future.mirai packages:

# load future; brms will use it once the future option is set below
library(future)

# first instantiate mirai daemons; here we use all available cores on
# the local machine
mirai::daemons(parallelly::availableCores())

# then register the mirai cluster as future backend
plan(future.mirai::mirai_cluster)

# point brms by default to use the future backend
options(future = TRUE)

source(here::here("src", "simulate_fake_data.R"))
fake_data <- simulate_fake_data()

brm(
  y ~ 1 + x1 + x2 + (1 | g),
  data = fake_data,
  family = poisson(),
  iter = 2000,
  warmup = 1000,
  prior = prior(normal(0, 1), class = b) +
    prior(constant(1), class = sd, group = g),
  backend = "cmdstanr",
  # request 4 threads per chain
  threads = threading(4),
  chains = 4,
  seed = 345345,
  refresh = 0,
  control = list(adapt_delta = 0.95)
)

# terminate mirai daemons (or keep them running if used later again)
mirai::daemons(0)

In the above example a mirai cluster with locally started workers is used, running 4 parallel chains with 4 threads each. To run, for example, on an IBM LSF managed cluster, one can replace the above first call to mirai::daemons with:

# Remote worker setup for an LSF cluster which uses "modules" to
# load specific R versions, as common on HPC systems. Requesting 3000 MB
# of RAM (-M 3000) and 4 cores per job (-n 4) which all run on the
# very same machine (-R "span[hosts=1]").
lsf_config <- mirai::cluster_config(
  command = "bsub",
  options = '#BSUB -J mirai
             #BSUB -M 3000
             #BSUB -n 4
             #BSUB -R "span[hosts=1]"
             #BSUB -o job.out
             module load R/4.5.0',
  rscript = "Rscript"
)

mirai::daemons(n = 4, url = mirai::host_url(), remote = lsf_config)

For more options on how to spawn remote workers, please refer to the mirai documentation; in particular the section on remote daemons.

As HPC environments may not always have the latest R software installed, we provide here an alternative based on the well-established R package clustermq. This package is tailored to interact with the queuing system of a cluster, and below the utility function cmq_brm is provided, which facilitates running the posterior sampling on the cluster via clustermq. The idea is to first define the brms model in the current session without performing any sampling, by setting the argument chains = 0. The model so defined is handed to the internal update_model function, which clustermq runs once per chain. clustermq then samples either locally or on the cluster, depending on the scheduler registered for clustermq.
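The scheduler is registered through the clustermq.scheduler option before calling cmq_brm. As a minimal sketch for a local run (on a cluster one would instead use a scheduler name such as "lsf" or "slurm" together with a matching template):

# run the chains as local processes; in this case cmq_brm caps the
# number of concurrently sampling chains by the mc.cores option
options(clustermq.scheduler = "multiprocess", mc.cores = 4)

With the scheduler registered, the model fit itself looks as usual: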

# Fit model as usual, just use "cmq_brm" instead of brm call, see
# definition at the bottom
model_poisson <- cmq_brm(
  y ~ 1 + x1 + x2 + (1 | g),
  data = fake_data,
  family = poisson(),
  iter = 2000,
  warmup = 1000,
  prior = prior(normal(0, 1), class = b) +
    prior(constant(1), class = sd, group = g),
  # use the cmdstanr backend
  backend = "cmdstanr",
  # request 4 threads per chain
  threads = threading(4),
  seed = 345345,
  control = list(adapt_delta = 0.95)
)

It is important to note that between-chain parallelization is always more efficient than within-chain parallelization. Moreover, the efficiency of the parallelization decreases with an increasing number of cores per chain. For a detailed discussion and instructions on how to tune the number of cores for a given problem, please refer to the threading vignette of brms available on CRAN.

19.3 Exact reproducibility

In the default setting of threading, exact reproducibility is not maintained! This stems from the fact that the large log likelihood sums are partitioned into blocks of varying size, and these partial sums are moreover accumulated in a random order. Due to the limited floating point precision of CPUs this leads to slightly varying results at the order of the machine precision (commonly \(10^{-16}\)). In case exactly reproducible results are required, Stan offers a static version of reduce_sum. This variant can be requested via the static = TRUE argument of the threading function. When doing so, it is recommended to also provide a so-called grainsize, which is the number of terms included in every partial sum. The grainsize should be large enough that each partial sum represents a considerable amount of work, while being small enough to allow for good load balancing between the CPUs. The threading vignette linked above discusses strategies for choosing a reasonable grainsize. The call with static sum partitioning may then look like threading(2, 80, TRUE), which requests 2 threads per chain, a grainsize of \(80\), and static sum partitioning.
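Written out with named arguments, the same request reads (thread_spec is a purely illustrative name):

# 2 threads per chain, partial sums of 80 terms each, static partitioning
thread_spec <- brms::threading(threads = 2, grainsize = 80, static = TRUE)

The resulting object is passed to brm via the threads argument just as before.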

19.4 Implementation of cmq_brm

# executes brm with parallelization via clustermq
cmq_brm <- function(
  ...,
  seed,
  control = list(adapt_delta = 0.9),
  cores,
  file,
  chains = 4,
  .log_worker = FALSE
) {
  checkmate::assert_integer(
    as.integer(seed),
    lower = 1,
    any.missing = FALSE,
    len = 1
  )
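  # capture all brms and cmdstanr global options so that they can be
  # replayed on the workers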
  brms_global <- options()[grep("^brms", names(options()), value = TRUE)]
  cmdstanr_global <- options()[grep(
    "^cmdstanr",
    names(options()),
    value = TRUE
  )]
  dots <- rlang::enquos(...)
  brms_args <- lapply(dots, rlang::eval_tidy)
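  # define and compile the model up front without any sampling (chains = 0)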
  suppressWarnings(
    model <- do.call(brm, modifyList(brms_args, list(chains = 0, cores = 1)))
  )
  update_args <- list(
    object = model,
    seed = seed,
    cores = 1,
    chains = 1,
    control = control
  )
  if (!missing(file)) {
    file_base <- sub("\\.rds$", "", file)
    update_args$file <- file_base
    brms_args$file <- file_base
  }
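  # record the library paths of the main session so that workers load
  # the very same package versions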
  master_lib_paths <- .libPaths()
  update_model <- function(chain_id) {
    .libPaths(master_lib_paths)
    library(brms)
    # in case file is part of extra-arguments, we add here the chain_id
    # to get correct by-chain file caching
    if ("file" %in% names(update_args)) {
      update_args <- modifyList(
        update_args,
        list(file = paste0(update_args$file, "-", chain_id))
      )
    }
    if ("file" %in% names(brms_args)) {
      brms_args <- modifyList(
        brms_args,
        list(file = paste0(brms_args$file, "-", chain_id))
      )
    }
    update_args$chain_id <- chain_id
    brms_args$chain_id <- chain_id
    # ensure the same brms & cmdstanr global options are set
    options(brms_global)
    options(cmdstanr_global)
    ## for the rstan backend we do an update while for cmdstanr we
    ## have to avoid this for whatever reason
    if (model$backend == "cmdstanr") {
      msg <- capture.output(
        fit <- do.call(brm, modifyList(brms_args, list(chains = 1)))
      )
    } else {
      msg <- capture.output(fit <- do.call(update, update_args))
    }
    list(fit = fit, msg = msg)
  }
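  # run one clustermq job per chain; for local schedulers the
  # concurrency is capped by the mc.cores option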
  n_jobs <- chains
  backend <- getOption("clustermq.scheduler", "multiprocess")
  if (backend %in% c("multiprocess", "multicore")) {
    n_jobs <- min(chains, getOption("mc.cores", 1))
  }
  cores_per_chain <- 1
  if (!is.null(model$threads$threads)) {
    cores_per_chain <- model$threads$threads
  }
  if (chains == 1 && cores_per_chain == 1) {
    ## looks like a debugging run...avoid clustermq
    return(update_model(1)$fit)
  }
  message(
    "Starting ",
    chains,
    " chains with a concurrency of ",
    n_jobs,
    " and using ",
    cores_per_chain,
    " cores per chain with backend ",
    backend,
    "...\n"
  )
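  # dispatch the per-chain sampling jobs via clustermq, exporting
  # everything update_model needs on the workers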
  cluster_update <- clustermq::Q(
    update_model,
    chain_id = 1:chains,
    n_jobs = n_jobs,
    export = list(
      update_args = update_args,
      brms_args = brms_args,
      brms_global = brms_global,
      cmdstanr_global = cmdstanr_global,
      master_lib_paths = master_lib_paths,
      model = model
    ),
    template = list(cores = cores_per_chain),
    log_worker = .log_worker
  )
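  # collect the single-chain fits and merge them into one brmsfit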
  fit <- combine_models(mlist = lapply(cluster_update, "[[", "fit"))
  msg <- lapply(cluster_update, "[[", "msg")
  for (i in seq_along(msg)) {
    if (length(msg[[i]]) == 0) {
      next
    }
    cat(paste0("Output for chain ", i, ":\n"))
    cat(paste(msg[[i]], collapse = "\n"), "\n")
  }
  fit$file <- NULL
  fit
}

20 Example fake data: simulate_fake_data

simulate_fake_data <- function(N = 1E4, G = 1E3, P = 3, seed = 46765875) {
  withr::local_seed(seed)

  # regression coefficients
  beta <- rnorm(P)

  # sampled covariates
  fake <- matrix(rnorm(N * P), ncol = P)
  dimnames(fake) <- list(NULL, paste0("x", 1:P))

  # fixed effect part and sampled group membership
  fake <- transform(
    as.data.frame(fake),
    theta = fake %*% beta,
    g = sample.int(G, N, replace = TRUE)
  )

  # add random intercept by group
  fake <- merge(fake, data.frame(g = 1:G, eta = rnorm(G)), by = "g")

  # linear predictor
  fake <- transform(fake, mu = theta + eta)

  # sample Poisson data
  fake <- transform(fake, y = rpois(N, exp(mu)))

  # shuffle order of data rows to ensure even distribution of computational effort
  fake <- fake[sample.int(N, N), ]

  # drop row names, which are not needed
  rownames(fake) <- NULL

  return(fake)
}