torchsurv.loss.momentum
Classes
Momentum | Survival framework that uses momentum-update learning to decouple the number of samples seen by the loss from the mini-batch size during model training.
- class torchsurv.loss.momentum.Momentum(backbone: Module, loss: Callable, batchsize: int = 16, steps: int = 4, rate: float = 0.999)
Survival framework that uses momentum-update learning to decouple the number of samples seen by the loss from the mini-batch size during model training. Two networks are trained concurrently: an online network and a target network. Outputs from successive batches are concatenated in a memory bank, which virtually increases the batch size seen by the loss.
The target network (k) is updated with an exponential moving average (EMA) of the parameters of the online network (q). The online network is trained using a memory bank of previously computed log hazards, but gradients are only tracked for the current batch.
\[\theta_k \leftarrow m \theta_k + (1 - m) \theta_q\]
where m is the momentum rate (the rate parameter, 0.999 by default).
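The EMA rule above can be sketched in plain PyTorch as an in-place update over paired parameters. This is a minimal illustration, not torchsurv's internal code; the helper name `ema_update` and the toy `torch.nn.Linear(3, 1)` networks are hypothetical.

```python
import copy

import torch


@torch.no_grad()  # the target network (k) never receives gradients
def ema_update(target: torch.nn.Module, online: torch.nn.Module, m: float = 0.999) -> None:
    """Hypothetical helper: theta_k <- m * theta_k + (1 - m) * theta_q, in place."""
    for p_k, p_q in zip(target.parameters(), online.parameters()):
        p_k.mul_(m).add_(p_q, alpha=1.0 - m)


online = torch.nn.Linear(3, 1)    # online network (q)
target = copy.deepcopy(online)    # target network (k): same architecture and weights
target.requires_grad_(False)

# Force distinct values so the update can be checked by hand
with torch.no_grad():
    online.weight.fill_(1.0)
    online.bias.fill_(1.0)
    target.weight.zero_()
    target.bias.zero_()

ema_update(target, online, m=0.9)  # target parameters become 0.9 * 0 + 0.1 * 1 = 0.1
```

With m close to 1, the target network changes slowly, so estimates stored from earlier batches stay consistent with the current target weights.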
The following Python pseudocode illustrates what happens under the hood.
model_k = copy.deepcopy(model_q)                   # Same architecture and same (random) initial weights
model_k.requires_grad_(False)                      # No gradient update for target network (k)
hz_memory_bank = deque(maxlen=n * batch_size)      # Double-ended queue of size n * batch_size
for batch in batches:
    hz_q = model_q(batch)                          # Compute current estimate w/ ONLINE network (q)
    hz_loss = torch.cat([*hz_memory_bank, hz_q])   # Combine memory bank and current log hz
    loss = loss_function(hz_loss)                  # Compute loss with pooled log hz
    optimizer.zero_grad()
    loss.backward()                                # Update online model (q) w/ PyTorch autograd
    optimizer.step()
    model_k.ema_update(model_q)                    # Update target model (k) with Exponential Moving Average (EMA)
    with torch.no_grad():
        hz_k = model_k(batch)                      # Compute batch estimate w/ TARGET network (k)
    hz_memory_bank.extend(hz_k.split(1))           # Append current batch; the oldest entries are dropped
Note
This code is inspired by MoCo [HFW+19] and its ability to decouple the number of samples used by the loss from the mini-batch size.
Note
A notable difference is that the memory bank is updated at the end of the step, not at the beginning. This is because MoCo uses the target network batch for the positive pair, while we use a rank-based loss function. The current batch must therefore be excluded from the memory bank when computing the loss.
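The ordering described in this note can be illustrated with a toy memory bank built on collections.deque. This is a sketch under stated assumptions, not torchsurv's implementation: random tensors stand in for the online and target outputs, and the names `bank` and `pooled_sizes` are hypothetical. Because the bank is pooled before it is updated, each sample of the current batch appears exactly once (as an online estimate, with gradient), and the pooled size is capped at `batchsize * steps + batchsize`.

```python
from collections import deque

import torch

batchsize, steps = 4, 2
bank = deque(maxlen=batchsize * steps)  # per-sample target estimates from PAST batches only

pooled_sizes = []
for step in range(4):
    # Stand-in for the online network output (q): carries gradient
    hz_q = torch.randn(batchsize, 1, requires_grad=True)
    # Pool BEFORE updating the bank, so the current batch is not duplicated
    pooled = torch.cat([*bank, hz_q])
    pooled_sizes.append(pooled.shape[0])
    # Stand-in for the target network output (k): no gradient
    hz_k = torch.randn(batchsize, 1)
    # Update at the END of the step; deque(maxlen=...) drops the oldest entries
    bank.extend(hz_k.split(1))
```

Here pooled_sizes grows from 4 to 12 and then stays at 12, since the deque never holds more than `batchsize * steps` stored estimates.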
References
[HFW+19] Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for unsupervised visual representation learning. 2019. URL: http://arxiv.org/abs/1911.05722.
- __init__(backbone: Module, loss: Callable, batchsize: int = 16, steps: int = 4, rate: float = 0.999)
Initialise the Momentum class. Users must provide their model as backbone.
- Parameters:
backbone (nn.Module) – Torch model to be used as backbone. The model must return either one output (Cox) or two outputs (Weibull).
loss (Callable) – Torchsurv loss function (Cox, Weibull)
batchsize (int, optional) – Number of samples per batch. Defaults to 16.
steps (int, optional) – Number of queued batches to be stored for training. Defaults to 4.
rate (float, optional) – Exponential moving average rate. Defaults to 0.999.
Examples
>>> import torch
>>> from torchsurv.loss import cox, weibull
>>> from torchsurv.loss.momentum import Momentum
>>> _ = torch.manual_seed(42)
>>> n = 4
>>> params = torch.randn((n, 16))
>>> events = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
>>> times = torch.randint(low=1, high=100, size=(n,))
>>> backbone = torch.nn.Sequential(torch.nn.Linear(16, 1))  # Cox expects one output
>>> model = Momentum(backbone=backbone, loss=cox.neg_partial_log_likelihood)
>>> model(params, events, times)
tensor(0.0978, grad_fn=<DivBackward0>)
>>> model.online(params)  # online network (q) - w/ gradient
tensor([[-0.7867],
        [ 0.3161],
        [-1.2158],
        [-0.8195]], grad_fn=<AddmmBackward0>)
>>> model.target(params)  # target network (k) - w/o gradient
tensor([[-0.7867],
        [ 0.3161],
        [-1.2158],
        [-0.8195]])
Note
self.target is recommended for inference. It refers to the target (momentum) network.
- forward(inputs: Tensor, event: Tensor, time: Tensor) → Tensor
Compute the loss for the current batch and update the memory bank.
- Parameters:
inputs (torch.Tensor) – Input tensors to the backbone model
event (torch.Tensor) – A boolean tensor indicating whether a patient experienced an event.
time (torch.Tensor) – A positive float tensor representing time to event (or censoring time)
- Returns:
A loss tensor for the current batch.
- Return type:
torch.Tensor
Examples
>>> import torch
>>> from torchsurv.loss import cox, weibull
>>> from torchsurv.loss.momentum import Momentum
>>> _ = torch.manual_seed(42)
>>> n = 128  # samples
>>> x = torch.randn((n, 16))
>>> y = torch.randint(low=0, high=2, size=(n,)).bool()
>>> t = torch.randint(low=1, high=100, size=(n,))
>>> backbone = torch.nn.Sequential(torch.nn.Linear(16, 1))  # (log hazards)
>>> model_cox = Momentum(backbone, loss=cox.neg_partial_log_likelihood)  # Cox loss
>>> with torch.no_grad():
...     model_cox.forward(x, y, t)
tensor(2.1366)
>>> backbone = torch.nn.Sequential(torch.nn.Linear(16, 2))  # (lambda, rho)
>>> model_weibull = Momentum(backbone, loss=weibull.neg_log_likelihood)  # Weibull loss
>>> with torch.no_grad():
...     torch.round(model_weibull.forward(x, y, t), decimals=2)
tensor(68.0400)
- infer(inputs: Tensor) → Tensor
Evaluate data with the target network.
- Parameters:
inputs (torch.Tensor) – Input tensors to the backbone model
- Returns:
Predictions from the target (momentum) network, without augmentation (.eval()) and without gradient.
- Return type:
torch.Tensor
Examples
>>> import torch
>>> from torchsurv.loss import weibull
>>> from torchsurv.loss.momentum import Momentum
>>> _ = torch.manual_seed(42)
>>> backbone = torch.nn.Sequential(torch.nn.Linear(8, 2))  # Weibull expects two outputs
>>> model = Momentum(backbone=backbone, loss=weibull.neg_log_likelihood)
>>> model.infer(torch.randn((3, 8)))
tensor([[ 0.5342,  0.0062],
        [ 0.6439,  0.7863],
        [ 0.9771, -0.8513]])
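The behaviour described for infer (target network, evaluation mode, no gradient) corresponds to a common PyTorch pattern, sketched below. The helper `infer_with_target` and the Dropout layer are hypothetical and only illustrate the idea; this is not torchsurv's implementation, and unlike a production helper it leaves the module in evaluation mode afterwards.

```python
import copy

import torch


def infer_with_target(target: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: evaluate the target (momentum) network without gradient."""
    target.eval()          # evaluation mode: e.g. Dropout becomes a no-op
    with torch.no_grad():  # no autograd graph is recorded
        return target(inputs)


torch.manual_seed(0)
online = torch.nn.Sequential(torch.nn.Linear(8, 2), torch.nn.Dropout(p=0.5))
target = copy.deepcopy(online)  # same architecture and initial weights as the online network
target.requires_grad_(False)

x = torch.randn(3, 8)
out = infer_with_target(target, x)  # shape (3, 2), e.g. two Weibull outputs per sample
```

Because Dropout is disabled in evaluation mode, repeated calls on the same inputs return identical predictions.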