torchsurv.loss.weibull
Functions

log_hazard | Log hazard of the Weibull Accelerated Failure Time (AFT) survival model.
neg_log_likelihood | Negative of the log likelihood for the Weibull Accelerated Failure Time (AFT) survival model.
survival_function | Survival function for the Weibull Accelerated Failure Time (AFT) survival model.
- torchsurv.loss.weibull.neg_log_likelihood(log_params: Tensor, event: Tensor, time: Tensor, reduction: str = 'mean', checks: bool = True) → Tensor
Negative of the log likelihood for the Weibull Accelerated Failure Time (AFT) survival model.
- Parameters:
log_params (torch.Tensor, float) – Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). The first column corresponds to the log scale parameter. The second column corresponds to the log shape parameter. If the log shape parameter is missing, it is imputed with 0.
event (torch.Tensor, bool) – Event indicator of length n_samples (= True if event occurred).
time (torch.Tensor, float) – Event or censoring time of length n_samples.
reduction (str) – Method to reduce losses. Defaults to “mean”. Must be one of the following: “sum”, “mean”.
checks (bool) – Whether to perform input format checks. Enabling checks can help catch potential issues in the input data. Defaults to True.
- Returns:
Negative of the log likelihood.
- Return type:
(torch.Tensor, float)
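When log_params has a single column, the missing log shape is imputed with 0, which fixes \(\rho_i = e^0 = 1\). The Weibull model then reduces to the exponential model with constant hazard \(1/\lambda_i\), which is why the missing-shape case in the examples of this section is labelled "exponential distribution". The following is a minimal sketch in plain PyTorch, assuming the imputation is simply an appended zero column as described; it checks the reduction against torch.distributions.Exponential and is not torchsurv's own code.

```python
import torch
from torch.distributions import Exponential

log_scale = torch.tensor([[0.5], [-0.2]])  # (n_samples, 1): log shape column missing
# Impute the missing log shape column with 0, i.e. rho_i = exp(0) = 1
log_params = torch.cat([log_scale, torch.zeros_like(log_scale)], dim=1)

t = torch.tensor([1.0, 5.0, 10.0])
ls, lk = log_params[:, 0:1], log_params[:, 1:2]  # (2, 1) columns broadcast against t
rho = lk.exp()
log_h = lk - ls + (rho - 1) * (t.log() - ls)  # Weibull log hazard, shape (2, 3)
H = torch.exp(rho * (t.log() - ls))           # cumulative hazard (t / lambda)^rho

# With rho = 1 the hazard is constant in t and equals 1 / lambda_i
rate = torch.exp(-ls)
print(torch.allclose(log_h, rate.log().expand_as(log_h)))
# log density log h - H should match the exponential log density with that rate
print(torch.allclose(log_h - H, Exponential(rate).log_prob(t), atol=1e-5))
```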
Note
For each subject \(i \in \{1, \cdots, N\}\), denote \(X_i\) as the survival time and \(D_i\) as the censoring time. Survival data consist of the event indicator, \(\delta_i=1(X_i\leq D_i)\) (argument event), and the event or censoring time, \(T_i = \min(\{ X_i,D_i \})\) (argument time).
The log hazard function for the Weibull AFT survival model [Car03] of subject \(i\) at time \(t\) has the form:
\[\log h_i(t) = \log{\rho_i} - \log{\lambda_i} + (\rho_i -1) \left( \log{t} - \log{\lambda_i}\right)\]
where \(\log{\lambda_i}\) is the log scale parameter (first column of argument log_params) and \(\log{\rho_i}\) is the log shape parameter (second column of argument log_params). The cumulative hazard for the Weibull survival model of subject \(i\) at time \(t\) has the form:
\[H_i(t) = \left(\frac{t}{\lambda_i}\right)^{\rho_i}\]
The survival function for the Weibull survival model of subject \(i\) at time \(t\) has the form:
\[S_i(t) = 1 - F(t | \lambda_i, \rho_i)\]
where \(F(t | \lambda, \rho)\) is the cumulative distribution function (CDF) of the Weibull distribution with scale parameter \(\lambda\) and shape parameter \(\rho\).
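The three quantities above are linked: \(S_i(t) = \exp(-H_i(t))\), and the density satisfies \(\log f_i(t) = \log h_i(t) - H_i(t)\). The sketch below evaluates the formulas directly with broadcasting (subjects as rows, times as columns) and cross-checks the density against torch.distributions.Weibull, which parameterizes the distribution by scale (\(\lambda\)) and concentration (\(\rho\)). It illustrates the formulas only and is not torchsurv's implementation.

```python
import torch
from torch.distributions import Weibull

log_scale = torch.tensor([0.15, 0.10])  # log lambda_i for two subjects
log_shape = torch.tensor([0.25, 0.20])  # log rho_i
t = torch.tensor([1.0, 2.0, 3.0])

# Broadcast subjects (rows) against times (columns): (2, 1) vs (3,) -> (2, 3)
ls, lk = log_scale[:, None], log_shape[:, None]
rho = lk.exp()

log_h = lk - ls + (rho - 1) * (t.log() - ls)  # log hazard
H = torch.exp(rho * (t.log() - ls))           # cumulative hazard (t / lambda)^rho
S = torch.exp(-H)                             # survival function

# Log density implied by the model: log f = log h + log S = log h - H
log_f = log_h - H

# Cross-check against PyTorch's Weibull(scale=lambda, concentration=rho)
dist = Weibull(scale=ls.exp(), concentration=rho)
print(torch.allclose(log_f, dist.log_prob(t), atol=1e-4))
```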
The log likelihood of the Weibull survival model is
\[ll = \sum_{i: \delta_i = 1} \log h_i(T_i) - \sum_{i = 1}^N H_i(T_i)\]
Examples
>>> import torch
>>> from torchsurv.loss.weibull import neg_log_likelihood
>>> _ = torch.manual_seed(43)
>>> n = 4
>>> log_params = torch.randn((n, 2), dtype=torch.float)
>>> event = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
>>> time = torch.randint(low=1, high=100, size=(n,), dtype=torch.float)
>>> neg_log_likelihood(log_params, event, time)  # Default: mean of negative log likelihoods across subjects
tensor(143039.2656)
>>> neg_log_likelihood(log_params, event, time, reduction="sum")  # Sum of negative log likelihoods across subjects
tensor(572157.0625)
>>> neg_log_likelihood(
...     torch.randn((n, 1), dtype=torch.float), event, time
... )  # Missing shape: exponential distribution
tensor(67.4289)
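The log likelihood in the note can also be evaluated by hand. The sketch below simulates latent survival times \(X_i\) and censoring times \(D_i\), forms \(T_i = \min(X_i, D_i)\) and \(\delta_i = 1(X_i \le D_i)\), and computes the negative log likelihood with a mean or sum reduction over subjects. weibull_nll is a hypothetical name, the fixed censoring time of 1.5 is an arbitrary assumption, and the function follows the formula above rather than torchsurv's source (it performs no input checks).

```python
import torch
from torch.distributions import Weibull

torch.manual_seed(0)
n = 6
log_params = torch.randn(n, 2) * 0.3  # column 0: log lambda_i, column 1: log rho_i
lam, rho = log_params[:, 0].exp(), log_params[:, 1].exp()

X = Weibull(scale=lam, concentration=rho).sample()  # latent survival times
D = torch.full((n,), 1.5)                            # fixed censoring time (assumption)
time = torch.minimum(X, D)                           # T_i = min(X_i, D_i)
event = X <= D                                       # delta_i = 1(X_i <= D_i)


def weibull_nll(log_params, event, time, reduction="mean"):
    """Hypothetical sketch of -ll, with ll = sum_{delta_i=1} log h_i(T_i) - sum_i H_i(T_i)."""
    log_scale, log_shape = log_params[:, 0], log_params[:, 1]
    shape = log_shape.exp()
    log_t = time.log()
    log_h = log_shape - log_scale + (shape - 1) * (log_t - log_scale)
    H = torch.exp(shape * (log_t - log_scale))
    # Each subject contributes -(delta_i * log h_i(T_i) - H_i(T_i)) to -ll
    per_subject = -(event.float() * log_h - H)
    if reduction == "mean":
        return per_subject.mean()
    if reduction == "sum":
        return per_subject.sum()
    raise ValueError("reduction must be 'mean' or 'sum'")


print(weibull_nll(log_params, event, time))
print(weibull_nll(log_params, event, time, reduction="sum"))
```

Events contribute the log density \(\log h_i - H_i\) and censored subjects contribute \(\log S_i = -H_i\), which is the usual right-censored likelihood written in hazard form.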
References
[Car03] Kevin J. Carroll. On the use and utility of the Weibull model in the analysis of survival data. Controlled Clinical Trials, 24(6):682–701, December 2003.
- torchsurv.loss.weibull.log_hazard(new_log_params: Tensor, new_time: Tensor, respective_times: bool = False, clamp_value: float = 10000000000.0) → Tensor
Log hazard of the Weibull Accelerated Failure Time (AFT) survival model.
- Parameters:
new_log_params (torch.Tensor, float) – Parameters of the Weibull distribution for new subjects, of shape = (n_samples_new, 1) or (n_samples_new, 2). The first column corresponds to the log scale parameter. The second column corresponds to the log shape parameter. If the log shape parameter is missing, it is imputed with 0.
new_time (torch.Tensor, float) – Times at which to evaluate the log hazard, of length n_times.
respective_times (bool, optional) – If True, new_time must have the same length as new_log_params, and the log hazard of each subject is evaluated at the time with the same index in new_time. Defaults to False.
clamp_value (float, optional) – Maximum value to which the log hazard is clipped. This prevents numerical overflow or instability by capping extremely large values of the log hazard. Defaults to 1e10.
- Returns:
Subject-specific log hazard evaluated at new_time. Shape = (n_samples_new, n_times) if respective_times is False. Shape = (n_samples_new,) if respective_times is True.
- Return type:
(torch.Tensor, float)
Examples
>>> import torch
>>> from torchsurv.loss.weibull import log_hazard
>>> new_log_params = torch.tensor([[0.15, 0.25], [0.1, 0.2]])  # 2 new subjects
>>> new_time = torch.tensor([1.0, 2.0])
>>> log_hazard(new_log_params, new_time)
tensor([[0.0574, 0.2543],
        [0.0779, 0.2313]])
>>> log_hazard(new_log_params, new_time, respective_times=True)
tensor([0.0574, 0.2313])
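The two output shapes can be reproduced with broadcasting: subjects index the rows and times the columns, and respective_times=True corresponds to keeping only the diagonal, so subject \(i\) is evaluated at new_time[i] alone. The sketch below illustrates that reading in plain PyTorch. log_hazard_sketch is a hypothetical name, and applying clamp_value as an upper bound only is an assumption drawn from the parameter description, not from torchsurv's source.

```python
import torch


def log_hazard_sketch(new_log_params, new_time, respective_times=False, clamp_value=1e10):
    """Hypothetical re-implementation of the Weibull AFT log hazard for illustration."""
    log_scale = new_log_params[:, 0]
    if new_log_params.size(1) > 1:
        log_shape = new_log_params[:, 1]
    else:
        log_shape = torch.zeros_like(log_scale)  # impute log shape = 0
    shape = log_shape.exp()
    log_t = new_time.log()
    if not respective_times:
        # Every subject at every time: (n_samples_new, 1) against (1, n_times)
        log_scale, log_shape, shape = log_scale[:, None], log_shape[:, None], shape[:, None]
        log_t = log_t[None, :]
    # With respective_times=True, subject i is paired elementwise with new_time[i]
    log_h = log_shape - log_scale + (shape - 1) * (log_t - log_scale)
    return log_h.clamp(max=clamp_value)  # assumption: clip from above only


new_log_params = torch.tensor([[0.15, 0.25], [0.1, 0.2]])
new_time = torch.tensor([1.0, 2.0])
print(log_hazard_sketch(new_log_params, new_time))                         # shape (2, 2)
print(log_hazard_sketch(new_log_params, new_time, respective_times=True))  # shape (2,)
```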
- torchsurv.loss.weibull.survival_function(new_log_params: Tensor, new_time: Tensor) → Tensor
Survival function for the Weibull Accelerated Failure Time (AFT) survival model.
- Parameters:
new_log_params (torch.Tensor, float) – Parameters of the Weibull distribution for new subjects, of shape = (n_samples_new, 1) or (n_samples_new, 2). The first column corresponds to the log scale parameter. The second column corresponds to the log shape parameter. If the log shape parameter is missing, it is imputed with 0.
new_time (torch.Tensor, float) – Times at which to evaluate the survival probability, of length n_times.
- Returns:
Individual survival probabilities for each new subject at new_time. Shape = (n_samples_new, n_times).
- Return type:
torch.Tensor
Examples
>>> import torch
>>> from torchsurv.loss.weibull import survival_function
>>> new_log_params = torch.tensor([[0.15, 0.25], [0.1, 0.2]])  # 2 new subjects
>>> new_time = torch.tensor([1.0, 2.0, 3.0, 4.0])
>>> survival_function(new_log_params, new_time)  # Survival at new times
tensor([[0.4383, 0.1342, 0.0340, 0.0075],
        [0.4127, 0.1270, 0.0338, 0.0081]])
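For the Weibull model, \(S_i(t) = 1 - F(t \mid \lambda_i, \rho_i) = \exp\{-(t/\lambda_i)^{\rho_i}\}\). The sketch below evaluates that closed form with broadcasting to produce an (n_samples_new, n_times) matrix. survival_sketch is a hypothetical name, and the code illustrates the formula rather than reproducing torchsurv's implementation.

```python
import torch


def survival_sketch(new_log_params, new_time):
    """Hypothetical sketch: S_i(t) = exp(-(t / lambda_i) ** rho_i)."""
    log_scale = new_log_params[:, 0:1]  # (n_samples_new, 1)
    if new_log_params.size(1) > 1:
        log_shape = new_log_params[:, 1:2]
    else:
        log_shape = torch.zeros_like(log_scale)  # impute log shape = 0
    # log H = rho * (log t - log lambda), broadcast to (n_samples_new, n_times)
    log_H = log_shape.exp() * (new_time.log()[None, :] - log_scale)
    return torch.exp(-log_H.exp())


new_log_params = torch.tensor([[0.15, 0.25], [0.1, 0.2]])
new_time = torch.tensor([1.0, 2.0, 3.0, 4.0])
surv = survival_sketch(new_log_params, new_time)
print(surv)
```

Working in log space and exponentiating once keeps the computation to a single power evaluation per entry; each row is non-increasing in time, as a survival curve must be.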