torchsurv.loss.survival

neg_log_likelihood(log_hz, event, time, eval_time, reduction='mean', checks=True)

Negative log-likelihood for a survival model.

Parameters:
  • log_hz (torch.Tensor, float) – Log hazard rates of shape = (n_samples, n_eval_time). The entry at row i and column j is the log hazard for subject i at the j-th evaluation time.

  • event (torch.Tensor, bool) – Event indicator (= True if event occurred) of shape = (n_samples,).

  • time (torch.Tensor, float) – Event or censoring time of shape = (n_samples,).

  • eval_time (torch.Tensor, float) – Times at which log_hz is evaluated of shape = (n_eval_time,).

  • reduction (str, optional) – Method to reduce losses. Defaults to “mean”. Must be one of the following: “sum”, “mean”.

  • checks (bool, optional) – Whether to perform input format checks. Enabling checks can help catch potential issues in the input data. Defaults to True.

Returns:

Negative log-likelihood of the survival model, reduced across subjects according to reduction.

Return type:

torch.Tensor

Note

For each subject \(i \in \{1, \cdots, N\}\), denote \(X_i\) as the survival time and \(D_i\) as the censoring time. Survival data consist of the event indicator, \(\delta_i=1(X_i\leq D_i)\) (argument event) and the time-to-event or censoring, \(T_i = \min(\{ X_i,D_i \})\) (argument time).

Further, let \(\tau_1 < \tau_2 < \cdots < \tau_M\) be the evaluation times (argument eval_time), and \(\log h_i(\tau)\) be the log hazard function for subject \(i\) at time \(\tau\) (argument log_hz).

The (continuous) negative log-likelihood for the survival model is given by:

\[\text{nll} = - \sum_{i=1}^N \left( \delta_i \log h_i(T_i) - \int_0^{T_i} h_i(u)\,du \right).\]

We approximate the cumulative hazard term using the trapezoidal rule evaluated at discrete times \(\{\tau_1, \tau_2, \ldots, \tau_M\}\):

\[\int_0^{T_i} h_i(u)\,du \;\approx\; \sum_{k=2}^{K_i} \frac{h_i(\tau_{k-1}) + h_i(\tau_k)}{2} \, (\tau_k - \tau_{k-1}),\]

where \(K_i = \max\{\,k : \tau_k \le T_i\,\}\) is the index of the largest evaluation time not exceeding \(T_i\). The integration therefore begins at \(\tau_1\), which should represent the start of observation (often \(\tau_1 = 0\)).

If \(T_i\) does not coincide exactly with any of the evaluation times, the log-hazard at \(T_i\) is approximated by the value corresponding to the nearest evaluation time not exceeding \(T_i\), that is:

\[\log h_i(T_i) \;\approx\; \log h_i(\tau_{K_i}), \quad \text{where } K_i = \max\{\,k : \tau_k \le T_i\,\}.\]
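The formulas above can be sketched in plain PyTorch as follows. This is an illustrative re-derivation under stated assumptions, not the library's implementation: the helper name `nll_sketch` is hypothetical, the evaluation times are assumed to be sorted, and every \(T_i\) is assumed to be at least \(\tau_1\). Its output is not guaranteed to match neg_log_likelihood numerically.

```python
import math

import torch


def nll_sketch(log_hz, event, time, eval_time, reduction="mean"):
    """Hypothetical helper that mirrors the formulas in the note above."""
    hz = torch.exp(log_hz)  # h_i(tau_k), shape (n_samples, n_eval_time)

    # Trapezoid area of each interval [tau_{k-1}, tau_k].
    dt = eval_time[1:] - eval_time[:-1]
    areas = 0.5 * (hz[:, 1:] + hz[:, :-1]) * dt

    # Cumulative hazard at every tau_k; it is zero at tau_1 by construction.
    cum_hz = torch.cat([hz.new_zeros(hz.shape[0], 1), areas.cumsum(dim=1)], dim=1)

    # K_i: index of the largest evaluation time not exceeding T_i.
    k = torch.searchsorted(eval_time, time, right=True) - 1
    k = k.clamp(min=0, max=eval_time.numel() - 1).unsqueeze(1)

    log_hz_at_t = log_hz.gather(1, k).squeeze(1)  # log h_i(tau_{K_i})
    cum_hz_at_t = cum_hz.gather(1, k).squeeze(1)  # approximate cumulative hazard

    # Per-subject negative log-likelihood contribution.
    loss = -(event.to(log_hz.dtype) * log_hz_at_t - cum_hz_at_t)
    return loss.mean() if reduction == "mean" else loss.sum()


# Constant hazard of 0.1 on the grid 0, 1, 2, 3, 4.
eval_time = torch.arange(0.0, 5.0)
log_hz = torch.full((2, 5), math.log(0.1))
event = torch.tensor([True, False])
time = torch.tensor([2.0, 3.5])

total = nll_sketch(log_hz, event, time, eval_time, reduction="sum")
mean = nll_sketch(log_hz, event, time, eval_time)
print(total, mean)  # approximately 2.8026 and 1.4013
```

With a constant hazard the trapezoidal rule is exact on the grid, so the first subject contributes \(-(\log 0.1 - 0.2)\) and the censored second subject contributes \(0.3\): its cumulative hazard is taken at \(\tau = 3\), the largest evaluation time not exceeding \(T = 3.5\).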

Examples

>>> import torch
>>> from torchsurv.loss.survival import neg_log_likelihood
>>> _ = torch.manual_seed(43)
>>> n, M = 4, 5
>>> eval_time = torch.linspace(0, 100, steps=M, dtype=torch.float)
>>> log_hz = torch.randn((n, M), dtype=torch.float)
>>> event = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
>>> time = torch.randint(low=1, high=100, size=(n,), dtype=torch.float)
>>> neg_log_likelihood(log_hz, event, time, eval_time)  # default, mean of negative log-likelihoods across patients
tensor(54.0886)
>>> neg_log_likelihood(
...     log_hz, event, time, eval_time, reduction="sum"
... )  # sum of negative log-likelihoods across patients
tensor(216.3546)

survival_function(new_log_hz, new_time, eval_time, checks=True)

Compute the individual survival function for new subjects for the survival model.

Parameters:
  • new_log_hz (torch.Tensor, float) – Log hazard rates for new subjects of shape = (n_samples_new, n_eval_time).

  • new_time (torch.Tensor, float) – Times at which to evaluate the survival probability, of shape = (n_times,).

  • eval_time (torch.Tensor, float) – Times at which new_log_hz is evaluated of shape = (n_eval_time,).

  • checks (bool, optional) – Whether to perform input format checks. Enabling checks can help catch potential issues in the input data. Defaults to True.

Returns:

Individual survival probabilities for each new subject at new_time of shape = (n_samples_new, n_times).

Return type:

torch.Tensor

Note

Let \(\tau_1 < \tau_2 < \cdots < \tau_M\) be the evaluation times (argument eval_time), and \(\log h^{\star}_i(\tau)\) be the log hazard function for new subject \(i\) at time \(\tau\) (argument new_log_hz).

The estimated survival function for new subject \(i\) under the survival model is given by:

\[\hat{S}_i(t) = \exp\left(- \int_0^t h_i^{\star}(u)\,du\right).\]

The cumulative hazard term, i.e. \(\int_0^t h_i^{\star}(u) du\), is approximated using the trapezoidal rule evaluated at discrete times \(\{\tau_1, \tau_2, \ldots, \tau_M\}\). The integration begins at \(\tau_1\), which should represent the start of observation (often \(\tau_1 = 0\)).

Examples

>>> import torch
>>> from torchsurv.loss.survival import survival_function
>>> eval_time = torch.linspace(0, 4.5, steps=3, dtype=torch.float)
>>> new_log_hz = torch.tensor([[0.15, 0.175, 0.2], [0.25, 0.5, 0.75]])  # 2 new subjects
>>> new_time = torch.tensor([2.5, 4.5])
>>> survival_function(new_log_hz, new_time, eval_time)
tensor([[0.0708, 0.0047],
        [0.0369, 0.0005]])
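
The values in the example above can be reproduced by hand with a short plain-PyTorch sketch of the trapezoidal approximation. The helper name `survival_sketch` is hypothetical and this is not the library's implementation. In particular, taking the cumulative hazard at the largest evaluation time not exceeding each requested time is an assumption inferred from the example output (the value at t = 2.5 equals the value at the grid point 2.25).

```python
import torch


def survival_sketch(new_log_hz, new_time, eval_time):
    """Hypothetical helper: trapezoidal cumulative hazard, then exp(-H)."""
    hz = torch.exp(new_log_hz)  # h*_i(tau_k), shape (n_samples_new, n_eval_time)

    # Trapezoid area of each interval [tau_{k-1}, tau_k].
    dt = eval_time[1:] - eval_time[:-1]
    areas = 0.5 * (hz[:, 1:] + hz[:, :-1]) * dt

    # Cumulative hazard at every tau_k; it is zero at tau_1 by construction.
    cum_hz = torch.cat([hz.new_zeros(hz.shape[0], 1), areas.cumsum(dim=1)], dim=1)

    # Assumption: use the largest evaluation time not exceeding each new_time.
    idx = torch.searchsorted(eval_time, new_time, right=True) - 1
    idx = idx.clamp(min=0, max=eval_time.numel() - 1)

    return torch.exp(-cum_hz[:, idx])  # shape (n_samples_new, n_times)


eval_time = torch.tensor([0.0, 2.25, 4.5])
new_log_hz = torch.tensor([[0.15, 0.175, 0.2], [0.25, 0.5, 0.75]])
new_time = torch.tensor([2.5, 4.5])

surv = survival_sketch(new_log_hz, new_time, eval_time)
print(surv)
```

For the first subject, the area of the first interval is \(2.25 \times (e^{0.15} + e^{0.175})/2 \approx 2.647\), and \(\exp(-2.647) \approx 0.0708\), which agrees with the first entry of the example output.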