torchsurv.loss.weibull

Functions

log_hazard(new_log_params, new_time[, ...])

Log hazard of the Weibull Accelerated Failure Time (AFT) survival model.

neg_log_likelihood(log_params, event, time)

Negative of the log likelihood for the Weibull Accelerated Failure Time (AFT) survival model.

survival_function(new_log_params, new_time)

Survival function for the Weibull Accelerated Failure Time (AFT) survival model.

torchsurv.loss.weibull.neg_log_likelihood(log_params: Tensor, event: Tensor, time: Tensor, reduction: str = 'mean', checks: bool = True) -> Tensor

Negative of the log likelihood for the Weibull Accelerated Failure Time (AFT) survival model.

Parameters:
  • log_params (torch.Tensor, float) – Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). The first column corresponds to the log scale parameter. The second column corresponds to the log shape parameter. If the log shape parameter is missing, it is imputed with 0.

  • event (torch.Tensor, bool) – Event indicator of length n_samples (= True if event occurred).

  • time (torch.Tensor, float) – Event or censoring time of length n_samples.

  • reduction (str) – Method to reduce losses. Defaults to “mean”. Must be one of the following: “sum”, “mean”.

  • checks (bool) – Whether to perform input format checks. Enabling checks can help catch potential issues in the input data. Defaults to True.

Returns:

Negative of the log likelihood.

Return type:

(torch.Tensor, float)

Note

For each subject \(i \in \{1, \cdots, N\}\), denote \(X_i\) as the survival time and \(D_i\) as the censoring time. Survival data consist of the event indicator, \(\delta_i=1(X_i\leq D_i)\) (argument event) and the event or censoring time, \(T_i = \min(\{ X_i,D_i \})\) (argument time).

The log hazard function for the Weibull AFT survival model [Car03] of subject \(i\) at time \(t\) has the form:

\[\log h_i(t) = \log{\rho_i} - \log{\lambda_i} + (\rho_i -1) \left( \log{t} - \log{\lambda_i}\right)\]

where \(\log{\lambda_i}\) is the log scale parameter (first column of argument log_params) and \(\log{\rho_i}\) is the log shape parameter (second column of argument log_params). The cumulative hazard for the Weibull survival model of subject \(i\) at time \(t\) has the form:

\[H_i(t) = \left(\frac{t}{\lambda_i}\right)^{\rho_i}\]

The survival function for the Weibull survival model of subject \(i\) at time \(t\) has the form:

\[S_i(t) = 1 - F(t | \lambda_i, \rho_i)\]

where \(F(t | \lambda, \rho)\) is the cumulative distribution function (CDF) of the Weibull distribution given scale parameter \(\lambda\) and shape parameter \(\rho\).
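Because the Weibull CDF is \(F(t | \lambda, \rho) = 1 - \exp\{-(t/\lambda)^{\rho}\}\), the survival function can also be written as \(S_i(t) = \exp\{-H_i(t)\}\). The following is a minimal plain-Python sketch of that identity, not the library's implementation. The helper names `cumulative_hazard` and `survival` are hypothetical, and the parameter values are those of the first subject in the survival_function example further down this page.

```python
import math


def cumulative_hazard(log_scale, log_shape, t):
    """H(t) = (t / scale) ** shape, with both parameters given on the log scale."""
    return (t / math.exp(log_scale)) ** math.exp(log_shape)


def survival(log_scale, log_shape, t):
    """S(t) = 1 - F(t) = exp(-H(t)) for the Weibull distribution."""
    return math.exp(-cumulative_hazard(log_scale, log_shape, t))


# Hypothetical inputs: log scale 0.15 and log shape 0.25 for a single subject.
log_scale, log_shape = 0.15, 0.25
probs = [survival(log_scale, log_shape, t) for t in (1.0, 2.0, 3.0, 4.0)]
print([round(p, 4) for p in probs])
```

The printed values should agree, up to rounding, with the first row of the survival_function example output.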

The log likelihood of the Weibull survival model is

\[ll = \sum_{i: \delta_i = 1} \log h_i(T_i) - \sum_{i = 1}^N H_i(T_i)\]
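As a worked illustration of the formula above, the sketch below evaluates the negative log likelihood in plain Python. It is an assumption-laden stand-in rather than the torchsurv code: `neg_log_likelihood_sketch` is a hypothetical helper, it takes Python lists instead of tensors, it performs no input checks, and it assumes that "mean" divides the summed loss by the number of subjects.

```python
import math


def neg_log_likelihood_sketch(log_params, event, time, reduction="mean"):
    """Negative Weibull log likelihood from the documented formulas (hypothetical helper).

    log_params: list of (log_scale, log_shape) pairs, one per subject.
    event: list of bools, True if the event was observed.
    time: list of positive event or censoring times.
    """
    total = 0.0
    for (log_scale, log_shape), delta, t in zip(log_params, event, time):
        shape = math.exp(log_shape)
        log_h = log_shape - log_scale + (shape - 1.0) * (math.log(t) - log_scale)
        cum_h = (t / math.exp(log_scale)) ** shape
        # Observed events contribute log h(T) - H(T); censored subjects only -H(T).
        total += (log_h if delta else 0.0) - cum_h
    nll = -total
    if reduction == "sum":
        return nll
    if reduction == "mean":
        # Assumption: the mean is taken over all subjects.
        return nll / len(time)
    raise ValueError("reduction must be 'sum' or 'mean'")


# Hypothetical data: subject 1 has an event at t = 1, subject 2 is censored at t = 2.
log_params = [(0.15, 0.25), (0.1, 0.2)]
event = [True, False]
time = [1.0, 2.0]
nll_sum = neg_log_likelihood_sketch(log_params, event, time, reduction="sum")
nll_mean = neg_log_likelihood_sketch(log_params, event, time)
```

Note that the censored subject still enters the second sum: being event-free up to \(T_i\) carries information through \(H_i(T_i)\) even though no hazard term is added.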

Examples

>>> import torch
>>> from torchsurv.loss.weibull import neg_log_likelihood
>>> _ = torch.manual_seed(43)
>>> n = 4
>>> log_params = torch.randn((n, 2), dtype=torch.float)
>>> event = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
>>> time = torch.randint(low=1, high=100, size=(n,), dtype=torch.float)
>>> neg_log_likelihood(log_params, event, time)  # Default: mean of negative log likelihoods across subjects
tensor(143039.2656)
>>> neg_log_likelihood(log_params, event, time, reduction="sum")  # Sum of negative log likelihoods across subjects
tensor(572157.0625)
>>> neg_log_likelihood(
...     torch.randn((n, 1), dtype=torch.float), event, time
... )  # Missing shape: exponential distribution
tensor(67.4289)
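The last call passes a single column, so the log shape is imputed with 0 and the shape is \(\rho_i = \exp(0) = 1\). The log hazard formula from the Note then reduces to \(\log h_i(t) = -\log{\lambda_i}\), which is constant in time, as expected for an exponential distribution. A small plain-Python check of this reduction (the helper `weibull_log_hazard` and the value 0.15 are hypothetical, chosen only for illustration):

```python
import math


def weibull_log_hazard(log_scale, log_shape, t):
    """Weibull log hazard for one subject at one time, from the documented formula."""
    shape = math.exp(log_shape)
    return log_shape - log_scale + (shape - 1.0) * (math.log(t) - log_scale)


log_scale = 0.15  # hypothetical log scale parameter
# An imputed log shape of 0 means shape = exp(0) = 1.
log_hazards = [weibull_log_hazard(log_scale, 0.0, t) for t in (1.0, 5.0, 50.0)]
# Every entry equals -log_scale: the hazard 1 / scale does not depend on time.
```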

References

[Car03]

Kevin J. Carroll. On the use and utility of the Weibull model in the analysis of survival data. Controlled Clinical Trials, 24(6):682–701, December 2003.

torchsurv.loss.weibull.log_hazard(new_log_params: Tensor, new_time: Tensor, respective_times: bool = False, clamp_value: float = 10000000000.0) -> Tensor

Log hazard of the Weibull Accelerated Failure Time (AFT) survival model.

Parameters:
  • new_log_params (torch.Tensor, float) – Parameters of the Weibull distribution for new subjects, of shape = (n_samples_new, 1) or (n_samples_new, 2). The first column corresponds to the log scale parameter. The second column corresponds to the log shape parameter. If the log shape parameter is missing, it is imputed with 0.

  • new_time (torch.Tensor, float) – Times at which to evaluate the log hazard, of length n_times.

  • respective_times (bool, optional) – If True, new_time must have the same length as new_log_params. The subject-specific log hazard is then evaluated at each respective index in new_time. Defaults to False.

  • clamp_value (float, optional) – Maximum value to which the log hazard is clipped. This prevents numerical overflow or instability by capping extremely large values of the log hazard. Defaults to 1e10.

Returns:

Subject-specific log hazard evaluated at new_time. Shape = (n_samples_new, n_times) if respective_times is False. Shape = (n_samples_new,) if respective_times is True.

Return type:

(torch.Tensor, float)
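The two output shapes can be illustrated with a plain-Python sketch that uses nested lists in place of tensors. This is only a stand-in under stated assumptions: `log_hazard_sketch` is a hypothetical helper, it applies the log hazard formula from the neg_log_likelihood Note, and it leaves out the clamp_value clipping.

```python
import math


def log_hazard_sketch(new_log_params, new_time, respective_times=False):
    """Plain-Python stand-in for the documented shapes (hypothetical helper).

    new_log_params: list of (log_scale, log_shape) pairs, one per new subject.
    new_time: list of evaluation times.
    """

    def one(log_scale, log_shape, t):
        shape = math.exp(log_shape)
        return log_shape - log_scale + (shape - 1.0) * (math.log(t) - log_scale)

    if respective_times:
        # One time per subject: the result has length n_samples_new.
        if len(new_time) != len(new_log_params):
            raise ValueError("new_time must have one entry per subject")
        return [one(ls, lk, t) for (ls, lk), t in zip(new_log_params, new_time)]

    # Every subject at every time: n_samples_new rows, n_times columns.
    return [[one(ls, lk, t) for t in new_time] for ls, lk in new_log_params]


params = [(0.15, 0.25), (0.1, 0.2)]  # same values as the example below
times = [1.0, 2.0]
grid = log_hazard_sketch(params, times)
paired = log_hazard_sketch(params, times, respective_times=True)
# paired is the diagonal of grid, because subject i is matched with times[i].
```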

Examples

>>> import torch
>>> from torchsurv.loss.weibull import log_hazard
>>> new_log_params = torch.tensor([[0.15, 0.25], [0.1, 0.2]])  # 2 new subjects
>>> new_time = torch.tensor([1.0, 2.0])
>>> log_hazard(new_log_params, new_time)
tensor([[0.0574, 0.2543],
        [0.0779, 0.2313]])
>>> log_hazard(new_log_params, new_time, respective_times=True)
tensor([0.0574, 0.2313])

torchsurv.loss.weibull.survival_function(new_log_params: Tensor, new_time: Tensor) -> Tensor

Survival function for the Weibull Accelerated Failure Time (AFT) survival model.

Parameters:
  • new_log_params (torch.Tensor, float) – Parameters of the Weibull distribution for new subjects, of shape = (n_samples_new, 1) or (n_samples_new, 2). The first column corresponds to the log scale parameter. The second column corresponds to the log shape parameter. If the log shape parameter is missing, it is imputed with 0.

  • new_time (torch.Tensor, float) – Times at which to evaluate the survival probability, of length n_times.

Returns:

Individual survival probabilities for each new subject at new_time. Shape = (n_samples_new, n_times).

Return type:

(torch.Tensor, float)

Examples

>>> import torch
>>> from torchsurv.loss.weibull import survival_function
>>> new_log_params = torch.tensor([[0.15, 0.25], [0.1, 0.2]])  # 2 new subjects
>>> new_time = torch.tensor([1.0, 2.0, 3.0, 4.0])
>>> survival_function(new_log_params, new_time)  #  Survival at new times
tensor([[0.4383, 0.1342, 0.0340, 0.0075],
        [0.4127, 0.1270, 0.0338, 0.0081]])