torchsurv.loss.survival
- neg_log_likelihood(log_hz, event, time, eval_time, reduction='mean', checks=True)
Negative log-likelihood for a survival model.
- Parameters:
log_hz (torch.Tensor, float) – Log hazard rates of shape = (n_samples, n_eval_time). The entry at row i and column j is the log hazard for subject i at the j-th evaluation time, eval_time[j].
event (torch.Tensor, bool) – Event indicator (= True if event occurred) of shape = (n_samples,).
time (torch.Tensor, float) – Event or censoring time of shape = (n_samples,).
eval_time (torch.Tensor, float) – Times at which log_hz is evaluated, in increasing order, of shape = (n_eval_time,).
reduction (str, optional) – Method to reduce losses. Defaults to "mean". Must be one of the following: "sum", "mean".
checks (bool, optional) – Whether to perform input format checks. Enabling checks can help catch potential issues in the input data. Defaults to True.
- Returns:
Negative log-likelihood of the survival model.
- Return type:
torch.Tensor
Note
For each subject \(i \in \{1, \cdots, N\}\), denote \(X_i\) as the survival time and \(D_i\) as the censoring time. Survival data consist of the event indicator, \(\delta_i=1(X_i\leq D_i)\) (argument event), and the event or censoring time, \(T_i = \min(\{ X_i,D_i \})\) (argument time).
Further, let \(\tau_1 < \tau_2 < \cdots < \tau_M\) be the evaluation times (argument eval_time), and \(\log h_i(\tau)\) be the log hazard function for subject \(i\) at time \(\tau\) (argument log_hz).
The (continuous) negative log-likelihood for the survival model is given by:
\[\text{nll} = - \sum_{i=1}^N \left( \delta_i \log h_i(T_i) - \int_0^{T_i} h_i(u)\, du \right).\]
This is the value returned with reduction="sum"; the default reduction="mean" divides it by \(N\).
We approximate the cumulative hazard term using the trapezoidal rule evaluated at discrete times \(\{\tau_1, \tau_2, \ldots, \tau_M\}\):
\[\int_0^{T_i} h_i(u)\,du \;\approx\; \sum_{k=2}^{K_i} \frac{h_i(\tau_{k-1}) + h_i(\tau_k)}{2} \, (\tau_k - \tau_{k-1}),\]where \(K_i = \max\{\,k : \tau_k \le T_i\,\}\) is the index of the largest evaluation time not exceeding \(T_i\). The integration therefore begins at \(\tau_1\), which should represent the start of observation (often \(\tau_1 = 0\)).
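A small worked instance of the rule above, with illustrative numbers that are not taken from the library: let \(\tau = (0, 1, 3)\), hazards \(h_i(\tau_k) = (1, 3, 5)\) and \(T_i = 2.5\). Then \(K_i = 2\), because \(\tau_2 = 1 \le 2.5 < \tau_3 = 3\), so only one trapezoid enters the sum:
\[\int_0^{2.5} h_i(u)\,du \;\approx\; \frac{h_i(\tau_1) + h_i(\tau_2)}{2}\,(\tau_2 - \tau_1) = \frac{1 + 3}{2}\cdot 1 = 2.\]
The stretch from \(\tau_{K_i} = 1\) to \(T_i = 2.5\) contributes nothing, so a coarse evaluation grid effectively truncates each subject's cumulative hazard at \(\tau_{K_i}\).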
If \(T_i\) does not coincide exactly with any of the evaluation times, the log-hazard at \(T_i\) is approximated by the value corresponding to the nearest evaluation time not exceeding \(T_i\), that is:
\[\log h_i(T_i) \;\approx\; \log h_i(\tau_{K_i}), \quad \text{where } K_i = \max\{\,k : \tau_k \le T_i\,\}.\]
Examples
>>> import torch
>>> from torchsurv.loss.survival import neg_log_likelihood
>>> _ = torch.manual_seed(43)
>>> n, M = 4, 5
>>> eval_time = torch.linspace(0, 100, steps=M, dtype=torch.float)
>>> log_hz = torch.randn((n, M), dtype=torch.float)
>>> event = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
>>> time = torch.randint(low=1, high=100, size=(n,), dtype=torch.float)
>>> neg_log_likelihood(log_hz, event, time, eval_time)  # default: mean of negative log-likelihoods across subjects
tensor(54.0886)
>>> neg_log_likelihood(
...     log_hz, event, time, eval_time, reduction="sum"
... )  # sum of negative log-likelihoods across subjects
tensor(216.3546)
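For cross-checking, the formulas in the Note can be written out directly in PyTorch. The sketch below is a from-scratch reading of those equations, not torchsurv's own implementation: the name nll_trapezoid_sketch is hypothetical, it performs none of the input checks that checks=True enables, and it assumes every time[i] is at least eval_time[0]. It locates \(K_i\) with torch.searchsorted and builds the trapezoidal cumulative hazard with torch.cumsum.

```python
import torch


def nll_trapezoid_sketch(
    log_hz: torch.Tensor,     # (n_samples, n_eval_time) log hazards on the grid
    event: torch.Tensor,      # (n_samples,) bool event indicators
    time: torch.Tensor,       # (n_samples,) event or censoring times
    eval_time: torch.Tensor,  # (n_eval_time,) increasing evaluation times
    reduction: str = "mean",
) -> torch.Tensor:
    """Hypothetical re-derivation of the Note's formulas (not torchsurv's code)."""
    hz = torch.exp(log_hz)

    # Trapezoid areas between consecutive evaluation times; the running sum gives
    # cum_hz[:, k] ~ integral of h_i from eval_time[0] to eval_time[k].
    areas = 0.5 * (hz[:, 1:] + hz[:, :-1]) * torch.diff(eval_time)
    cum_hz = torch.cat([torch.zeros_like(hz[:, :1]), torch.cumsum(areas, dim=1)], dim=1)

    # 0-based K_i: index of the largest evaluation time not exceeding T_i.
    # Assumes every time[i] >= eval_time[0] (otherwise the index would be -1).
    k = torch.searchsorted(eval_time, time.to(eval_time.dtype), right=True) - 1
    k = k.unsqueeze(1)

    log_hz_at_t = log_hz.gather(1, k).squeeze(1)  # log h_i(tau_{K_i})
    cum_hz_at_t = cum_hz.gather(1, k).squeeze(1)  # trapezoidal integral up to tau_{K_i}

    # Per-subject term: -(delta_i * log h_i(T_i) - cumulative hazard).
    nll = -(event.to(log_hz.dtype) * log_hz_at_t - cum_hz_at_t)
    if reduction == "mean":
        return nll.mean()
    if reduction == "sum":
        return nll.sum()
    raise ValueError(f"reduction must be 'mean' or 'sum', got {reduction!r}")
```

For instance, with eval_time = (0, 1, 3), hazards (1, 3, 5) and an event at time 2.5, the per-subject term is 2 − log 3 ≈ 0.901, because only the trapezoid on [0, 1] enters the integral. Comparing this sketch with neg_log_likelihood on the same inputs (for example with torch.testing.assert_close) is one way to check whether this reading of the Note matches the installed library version.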
- survival_function(new_log_hz, new_time, eval_time, checks=True)
Compute the individual survival function of new subjects under the survival model.
- Parameters:
new_log_hz (torch.Tensor, float) – Log hazard rates for new subjects of shape = (n_samples_new, n_eval_time).
new_time (torch.Tensor, float) – Times at which to evaluate the survival probability, of shape = (n_times,).
eval_time (torch.Tensor, float) – Times at which new_log_hz is evaluated, of shape = (n_eval_time,).
checks (bool, optional) – Whether to perform input format checks. Enabling checks can help catch potential issues in the input data. Defaults to True.
- Returns:
Individual survival probabilities for each new subject at new_time, of shape = (n_samples_new, n_times).
- Return type:
torch.Tensor
Note
Let \(\tau_1 < \tau_2 < \cdots < \tau_M\) be the evaluation times (argument eval_time), and \(\log h^{\star}_i(\tau)\) be the log hazard function for new subject \(i\) at time \(\tau\) (argument new_log_hz).
The estimated survival function for new subject \(i\) under the survival model is given by:
\[\hat{S}_i(t) = \exp\left(- \int_0^t h_i^{\star}(u)\, du\right).\]
The cumulative hazard term, i.e. \(\int_0^t h_i^{\star}(u)\, du\), is approximated using the trapezoidal rule evaluated at discrete times \(\{\tau_1, \tau_2, \ldots, \tau_M\}\). The integration begins at \(\tau_1\), which should represent the start of observation (often \(\tau_1 = 0\)).
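A minimal sketch of this survival function in PyTorch, under stated assumptions: the name survival_function_sketch is hypothetical and this is not the library's code. The equation does not say how values of \(t\) that fall between evaluation times are handled; the sketch assumes the cumulative hazard is held at the largest evaluation time not exceeding \(t\), the same rule the Note of neg_log_likelihood states for \(T_i\).

```python
import torch


def survival_function_sketch(
    new_log_hz: torch.Tensor,  # (n_samples_new, n_eval_time) log hazards on the grid
    new_time: torch.Tensor,    # (n_times,) times at which to evaluate S_i(t)
    eval_time: torch.Tensor,   # (n_eval_time,) increasing evaluation times
) -> torch.Tensor:
    """Hypothetical re-derivation of S_i(t) = exp(-cumulative hazard)."""
    hz = torch.exp(new_log_hz)

    # Trapezoidal cumulative hazard at each evaluation time; column 0 is zero
    # because integration starts at eval_time[0].
    areas = 0.5 * (hz[:, 1:] + hz[:, :-1]) * torch.diff(eval_time)
    cum_hz = torch.cat([torch.zeros_like(hz[:, :1]), torch.cumsum(areas, dim=1)], dim=1)

    # Assumption: hold the cumulative hazard at the largest evaluation time
    # not exceeding t (no interpolation between grid points).
    k = torch.searchsorted(eval_time, new_time.to(eval_time.dtype), right=True) - 1

    return torch.exp(-cum_hz[:, k])  # (n_samples_new, n_times)
```

Applied to the inputs in the Examples for this function, the sketch reproduces the printed tensor to four decimal places; that agreement supports the step assumption but does not prove it for every input.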
Examples
>>> import torch
>>> from torchsurv.loss.survival import survival_function
>>> eval_time = torch.linspace(0, 4.5, steps=3, dtype=torch.float)
>>> new_log_hz = torch.tensor([[0.15, 0.175, 0.2], [0.25, 0.5, 0.75]])  # 2 new subjects
>>> new_time = torch.tensor([2.5, 4.5])
>>> survival_function(new_log_hz, new_time, eval_time)
tensor([[0.0708, 0.0047],
        [0.0369, 0.0005]])
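The first row of this output can be checked by hand from the Note (a worked calculation, not additional library output). With \(\tau = (0, 2.25, 4.5)\) and \(h_1^{\star}(\tau_k) = (e^{0.15}, e^{0.175}, e^{0.2}) \approx (1.1618, 1.1912, 1.2214)\):
\[\int_0^{2.25} h_1^{\star}(u)\,du \approx \frac{1.1618 + 1.1912}{2}\,(2.25 - 0) \approx 2.6472, \qquad e^{-2.6472} \approx 0.0708,\]
\[\int_0^{4.5} h_1^{\star}(u)\,du \approx 2.6472 + \frac{1.1912 + 1.2214}{2}\,(4.5 - 2.25) \approx 5.3614, \qquad e^{-5.3614} \approx 0.0047.\]
Since \(t = 2.5\) falls between \(\tau_2 = 2.25\) and \(\tau_3 = 4.5\), the printed 0.0708 corresponds to the cumulative hazard accumulated only up to \(\tau_2\); interpolating the cumulative hazard to 2.5 would give a smaller survival probability.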