torchsurv.loss.momentum#

Classes

Momentum(backbone, loss[, batchsize, steps, ...])

Survival framework using momentum-update learning to decouple batch size during model training.

class torchsurv.loss.momentum.Momentum(backbone: Module, loss: Callable, batchsize: int = 16, steps: int = 4, rate: float = 0.999)[source]#

Survival framework using momentum-update learning to decouple batch size during model training. Two networks are trained concurrently: an online network and a target network. Output batches are accumulated in a memory bank and concatenated with the current batch, which virtually increases the batch size used by the loss.

The target network (k) is updated with an exponential moving average (EMA) of the online network's (q) parameters. The online network is trained using a memory bank of previously computed log hazards, but gradients are only tracked for the current batch.

\[\theta_k \leftarrow m \theta_k + (1 - m) \theta_q\]

with m = 0.999 by default (set by the rate parameter).
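A minimal sketch of this update rule (the helper name ema_update and its placement are illustrative assumptions, not the library's internal API), assuming the two networks share the same architecture:

import torch

@torch.no_grad()
def ema_update(model_k: torch.nn.Module, model_q: torch.nn.Module, m: float = 0.999) -> None:
    """Move the target parameters (k) towards the online parameters (q) with momentum m."""
    for param_k, param_q in zip(model_k.parameters(), model_q.parameters()):
        param_k.mul_(m).add_(param_q, alpha=1.0 - m)  # theta_k <- m * theta_k + (1 - m) * theta_q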

Here is pseudo Python code illustrating what happens under the hood.

model_k = copy.deepcopy(model_q)  # Target network (k) starts as a copy of the online network (q)
model_k.requires_grad_(False)     # No gradient updates for the target network (k)
hz_memory_bank = deque(maxlen=steps * batchsize)  # Double-ended queue of size steps * batchsize

for epoch in epochs:
    hz_q = model_q(batch)              # Compute current estimate w/ ONLINE network (q)
    hz_pooled = hz_memory_bank + hz_q  # Pool current log hazards with the memory bank
    loss = loss_function(hz_pooled)    # Compute loss on the pooled log hazards
    loss.backward()                    # Update online network (q) w/ PyTorch autograd
    model_k.ema_update(model_q)        # Update target network (k) with exponential moving average (EMA)
    hz_k = model_k(batch)              # Compute batch estimate w/ TARGET network (k)
    hz_memory_bank += hz_k             # Push current batch; the oldest batch is dropped from the memory bank

Note

This code is inspired by MoCo [HFW+19] and its ability to decouple batch size from training size.

Note

A notable difference is that the memory bank is updated at the end of the step, not at the beginning. This is because MoCo uses the target network batch for the positive pair, while we use a rank-based loss function. We therefore need to exclude the current batch from the memory bank to correctly compute our loss. A small, self-contained sketch of this ordering follows.
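The sketch below (illustrative only, not the library's implementation) uses collections.deque with maxlen: the loss is computed on the bank pooled with the current batch, and only afterwards is the target network estimate pushed, evicting the oldest samples.

from collections import deque

import torch

steps, batchsize = 4, 2
memory_bank = deque(maxlen=steps * batchsize)  # oldest samples are dropped automatically

for step in range(6):
    hz_q = torch.randn(batchsize)  # current batch log hazards (online network q)
    bank = list(memory_bank)       # previously stored log hazards; excludes the current batch
    pooled = torch.cat([torch.stack(bank), hz_q]) if bank else hz_q
    # ... the loss would be computed on `pooled` here ...
    hz_k = torch.randn(batchsize)  # stand-in for model_k(batch), target network (k)
    memory_bank.extend(hz_k)       # pushed only at the END of the step
    print(step, pooled.numel(), len(memory_bank))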

References

[HFW+19]

Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for unsupervised visual representation learning. 2019. URL: http://arxiv.org/abs/1911.05722.

__init__(backbone: Module, loss: Callable, batchsize: int = 16, steps: int = 4, rate: float = 0.999)[source]#

Initialise the momentum class. Users must provide their model as the backbone.

Parameters:
  • backbone (nn.Module) – Torch model to be used as the backbone. The model must return either one output (Cox) or two outputs (Weibull)

  • loss (Callable) – Torchsurv loss function (Cox, Weibull)

  • batchsize (int, optional) – Number of samples per batch. Defaults to 16.

  • steps (int, optional) – Number of queued batches to be stored for training. Defaults to 4.

  • rate (float, optional) – Exponential moving average rate. Defaults to 0.999.

Examples

>>> from torchsurv.loss import cox, weibull
>>> _ = torch.manual_seed(42)
>>> n = 4
>>> params = torch.randn((n, 16))
>>> events = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
>>> times = torch.randint(low=1, high=100, size=(n,))
>>> backbone = torch.nn.Sequential(torch.nn.Linear(16, 1))  # Cox expects one output
>>> model = Momentum(backbone=backbone, loss=cox.neg_partial_log_likelihood)
>>> model(params, events, times)
tensor(0.0978, grad_fn=<DivBackward0>)
>>> model.online(params)  # online network (q) - w/ gradient
tensor([[-0.7867],
        [ 0.3161],
        [-1.2158],
        [-0.8195]], grad_fn=<AddmmBackward0>)
>>> model.target(params)  # target network (k) - w/o gradient
tensor([[-0.7867],
        [ 0.3161],
        [-1.2158],
        [-0.8195]])

Note

self.encoder_k is recommended for inference. It refers to the target (momentum) network.
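A minimal training-loop sketch built on the example above (the optimizer choice, toy data, and number of epochs are illustrative assumptions, not prescribed by the class):

import torch
from torchsurv.loss import cox
from torchsurv.loss.momentum import Momentum

_ = torch.manual_seed(42)
backbone = torch.nn.Sequential(torch.nn.Linear(16, 1))  # Cox expects one output
model = Momentum(backbone=backbone, loss=cox.neg_partial_log_likelihood)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(5):  # illustrative toy loop
    x = torch.randn((16, 16))  # one batch of covariates
    events = torch.randint(low=0, high=2, size=(16,), dtype=torch.bool)
    times = torch.randint(low=1, high=100, size=(16,))
    optimizer.zero_grad()
    loss = model(x, events, times)  # pooled loss; memory bank is updated internally
    loss.backward()                 # gradients flow through the online network (q) only
    optimizer.step()

predictions = model.infer(torch.randn((4, 16)))  # inference with the target (momentum) network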

forward(inputs: Tensor, event: Tensor, time: Tensor) → Tensor[source]#

Compute the loss for the current batch and update the memory bank using the momentum class.

Parameters:
  • inputs (torch.Tensor) – Input tensors to the backbone model

  • event (torch.Tensor) – A boolean tensor indicating whether a patient experienced an event.

  • time (torch.Tensor) – A positive float tensor representing time to event (or censoring time)

Returns:

A loss tensor for the current batch.

Return type:

torch.Tensor

Examples

>>> from torchsurv.loss import cox, weibull
>>> _ = torch.manual_seed(42)
>>> n = 128  # samples
>>> x = torch.randn((n, 16))
>>> y = torch.randint(low=0, high=2, size=(n,)).bool()
>>> t = torch.randint(low=1, high=100, size=(n,))
>>> backbone = torch.nn.Sequential(torch.nn.Linear(16, 1))  # (log hazards)
>>> model_cox = Momentum(backbone, loss=cox.neg_partial_log_likelihood)  # Cox loss
>>> with torch.no_grad(): model_cox.forward(x, y, t)
tensor(2.1366)
>>> backbone = torch.nn.Sequential(torch.nn.Linear(16, 2))  # (lambda, rho)
>>> model_weibull = Momentum(backbone, loss=weibull.neg_log_likelihood)  # Weibull loss
>>> with torch.no_grad(): torch.round(model_weibull.forward(x, y, t), decimals=2)
tensor(68.0400)

infer(inputs: Tensor) → Tensor[source]#

Evaluate data with the target network.

Parameters:

inputs (torch.Tensor) – Input tensors to the backbone model

Returns:

Predictions from the target (momentum) network, in evaluation mode (.eval()) and without gradient tracking.

Return type:

torch.Tensor

Examples

>>> from torchsurv.loss import weibull
>>> _ = torch.manual_seed(42)
>>> backbone = torch.nn.Sequential(torch.nn.Linear(8, 2))  # Weibull expects two outputs
>>> model = Momentum(backbone=backbone, loss=weibull.neg_log_likelihood)
>>> model.infer(torch.randn((3, 8)))
tensor([[ 0.5342,  0.0062],
        [ 0.6439,  0.7863],
        [ 0.9771, -0.8513]])