torchsurv.loss.cox
- neg_partial_log_likelihood(log_hz, event, time, ties_method='efron', reduction='mean', strata=None, checks=True)
Compute the negative of the partial log likelihood for the Cox proportional hazards model.
- Parameters:
log_hz (torch.Tensor, float) – Log relative hazard of length n_samples.
event (torch.Tensor, bool) – Event indicator of length n_samples (= True if event occurred).
time (torch.Tensor, float) – Event or censoring time of length n_samples.
ties_method (str) – Method to handle ties in event time. Defaults to “efron”. Must be one of the following: “efron”, “breslow”.
reduction (str) – Method to reduce losses. Defaults to “mean”. Must be one of the following: “sum”, “mean”.
strata (torch.Tensor, int, optional) – Integer tensor of length n_samples representing the stratum of each subject, defined by combinations of categorical variables. This is useful if a categorical covariate does not obey the proportional hazards assumption. It is used similarly to the strata expression in R and lifelines. See http://courses.washington.edu/b515/l17.pdf.
checks (bool, optional) – Whether to perform input format checks. Enabling checks can help catch potential issues in the input data. Defaults to True.
- Returns:
Negative of the partial log likelihood.
- Return type:
(torch.Tensor, float)
Note
For each subject \(i \in \{1, \cdots, N\}\), denote \(X_i\) as the survival time and \(D_i\) as the censoring time. Survival data consist of the event indicator, \(\delta_i=1(X_i\leq D_i)\) (argument event), and the time-to-event or censoring, \(T_i = \min(\{ X_i,D_i \})\) (argument time).
The log hazard function for the Cox proportional hazards model has the form:
\[\log \lambda_i (t) = \log \lambda_{0}(t) + \log \theta_i\]
where \(\log \theta_i\) is the log relative hazard (argument log_hz).
No ties in event time. If the set \(\{T_i: \delta_i = 1\}_{i = 1, \cdots, N}\) contains only unique event times (i.e., no ties), the standard Cox partial likelihood can be used [Cox72]. Let \(\tau_1 < \tau_2 < \cdots < \tau_N\) be the ordered times and let \(R(\tau_i) = \{ j: \tau_j \geq \tau_i\}\) be the risk set at \(\tau_i\). The partial log likelihood is defined as:
\[pll = \sum_{i: \: \delta_i = 1} \left(\log \theta_i - \log\left(\sum_{j \in R(\tau_i)} \theta_j \right) \right)\]
Ties in event time handled with Breslow’s method. Breslow’s method [Bre75] applies the procedure described above unmodified, even when ties are present. If two subjects A and B have the same event time, subject A will be at risk for the event that happened to B, and B will be at risk for the event that happened to A. Let \(\xi_1 < \xi_2 < \cdots\) denote the unique ordered times (i.e., unique \(\tau_i\)). Let \(H_k\) be the set of subjects that have an event at time \(\xi_k\), such that \(H_k = \{i: \tau_i = \xi_k, \delta_i = 1\}\), and let \(m_k\) be the number of subjects that have an event at time \(\xi_k\), such that \(m_k = |H_k|\). The partial log likelihood is then:
\[pll = \sum_{k} \left( {\sum_{i\in H_{k}}\log \theta_i} - m_k \: \log\left(\sum_{j \in R(\xi_k)} \theta_j \right) \right)\]
Ties in event time handled with Efron’s method. An alternative approach that is considered to give better results is Efron’s method [Efr77]. As a compromise between Cox’s and Breslow’s methods, Efron suggested using the average risk among the subjects that have an event at time \(\xi_k\):
\[\bar{\theta}_{k} = {\frac {1}{m_{k}}}\sum_{i\in H_{k}}\theta_i\]
The Efron approximation of the partial log likelihood is defined by:
\[pll = \sum_{k} \left( {\sum_{i\in H_{k}}\log \theta_i} - \sum_{r =0}^{m_{k}-1} \log\left(\sum_{j \in R(\xi_k)}\theta_j-r\:\bar{\theta}_{k}\right)\right)\]
Stratified Cox model. When subjects come from different strata (argument strata), each stratum has its own baseline hazard function. Let \(\lambda_{0}^s(t)\) be the baseline hazard for stratum \(s\). The hazard function for patient \(i\) in stratum \(s\) becomes:
\[\log \lambda_i^s(t) = \log \lambda_{0}^s(t) + \log \theta_i\]
The partial log likelihood is computed separately within each stratum and then combined:
\[pll = \sum_{s} pll_{s}\]
where \(pll_{s}\) is the partial log likelihood contribution computed using only subjects in stratum \(s\).
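The Breslow, Efron and stratified formulas above can be traced by hand on small inputs. The following is a minimal plain-Python sketch written for illustration only; it is not torchsurv's implementation. The function name neg_pll_sketch is hypothetical, and the sketch returns the unreduced sum (it does not apply the reduction argument).

```python
import math
from collections import defaultdict


def neg_pll_sketch(log_hz, event, time, ties_method="efron", strata=None):
    """Negative partial log likelihood, summed over strata (illustrative only)."""
    if strata is None:
        strata = [0] * len(log_hz)  # treat all subjects as one stratum

    # Group subject indices by stratum; each stratum is handled on its own.
    by_stratum = defaultdict(list)
    for i, s in enumerate(strata):
        by_stratum[s].append(i)

    pll = 0.0
    for members in by_stratum.values():
        # Unique times xi_k with at least one event in this stratum.
        event_times = sorted({time[i] for i in members if event[i]})
        for xi in event_times:
            # H_k: subjects with an event at xi; R(xi_k): subjects with time >= xi.
            h_k = [i for i in members if event[i] and time[i] == xi]
            risk_sum = sum(math.exp(log_hz[j]) for j in members if time[j] >= xi)
            m_k = len(h_k)
            pll += sum(log_hz[i] for i in h_k)
            if ties_method == "breslow":
                pll -= m_k * math.log(risk_sum)
            elif ties_method == "efron":
                theta_bar = sum(math.exp(log_hz[i]) for i in h_k) / m_k
                for r in range(m_k):
                    pll -= math.log(risk_sum - r * theta_bar)
            else:
                raise ValueError("ties_method must be 'efron' or 'breslow'")
    return -pll


# Three subjects with equal hazards; the first two share an event time.
log_hz = [0.0, 0.0, 0.0]
event = [True, True, True]
time = [1.0, 1.0, 2.0]
breslow = neg_pll_sketch(log_hz, event, time, ties_method="breslow")  # 2 * log(3)
efron = neg_pll_sketch(log_hz, event, time, ties_method="efron")      # log(3) + log(2)
```

In this toy case the tied pair contributes \(2\log 3\) under Breslow and \(\log 3 + \log 2\) under Efron, because Efron removes one average tied risk \(\bar{\theta}_k\) from the risk set for the second tied event.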
Examples
>>> import torch
>>> from torchsurv.loss.cox import neg_partial_log_likelihood
>>> _ = torch.manual_seed(43)
>>> n = 4
>>> log_hz = torch.randn((n, 1), dtype=torch.float)
>>> event = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
>>> time = torch.randint(low=1, high=100, size=(n,), dtype=torch.float)
>>> neg_partial_log_likelihood(log_hz, event, time)  # default, mean of log likelihoods across patients
tensor(1.9908)
>>> neg_partial_log_likelihood(log_hz, event, time, reduction="sum")  # sum of log likelihoods across patients
tensor(5.9724)
>>> time[0] = time[1]  # Dealing with ties (default: Efron)
>>> neg_partial_log_likelihood(log_hz, event, time, ties_method="efron")
tensor(2.9877)
>>> neg_partial_log_likelihood(log_hz, event, time, ties_method="breslow")  # Dealing with ties (Breslow)
tensor(2.0247)
References
[Bre75] N. E. Breslow. Analysis of survival data under the proportional hazards model. International Statistical Review / Revue Internationale de Statistique, 43(1):45, April 1975.
[Cox72] D. R. Cox. Regression models and life‐tables. Journal of the Royal Statistical Society: Series B (Methodological), 34(2):187–202, January 1972.
[Efr77] Bradley Efron. The efficiency of Cox's likelihood function for censored data. Journal of the American Statistical Association, 72(359):557–565, September 1977.
- baseline_survival_function(log_hz, event, time, strata=None, checks=True)
Compute the baseline survival function for the Cox proportional hazards model with Breslow’s method.
- Parameters:
log_hz (torch.Tensor, float) – Log relative hazard of length n_samples.
event (torch.Tensor, bool) – Event indicator of length n_samples (= True if event occurred) used to fit the model.
time (torch.Tensor, float) – Event or censoring time of length n_samples used to fit the model.
strata (torch.Tensor, int, optional) – Integer tensor of length n_samples representing stratum for each subject defined by combinations of categorical variables.
checks (bool, optional) – Whether to perform input format checks. Enabling checks can help catch potential issues in the input data. Defaults to True.
- Returns:
- Dictionary with two entries:
“time” (torch.Tensor): Sorted unique time.
“baseline_survival” (torch.Tensor): Estimated baseline survival function evaluated at these times.
- Return type:
(dict)
Note
The baseline survival function, \(S_0(t)\), and the baseline cumulative hazard, \(H_0(t)\), under the Cox proportional hazards model are defined as:
\[S_0(t) = \exp\left(-H_0(t)\right), \quad H_0(t) = \int_{0}^{t} \lambda_0(u)\, du.\]
Using Breslow’s estimator [Bre72], we estimate the baseline cumulative hazard as:
\[\hat{H}_0(t) = \sum_{\xi_k \le t} \frac{m_k}{\sum_{j \in R(\xi_k)} \theta_j},\]
with \(\xi_k\), \(m_k\) and \(R(\xi_k)\) defined as in the note for neg_partial_log_likelihood. The estimated baseline survival function is then given by:
\[\hat{S}_0(t) = \exp\left(-\hat{H}_0(t)\right).\]
When strata are provided, the baseline cumulative hazard \(\hat{H}_{0}^s(t)\) and baseline survival function \(\hat{S}_{0}^s(t)\) are computed separately for each stratum \(s\), using only subjects from the same stratum.
Examples
>>> import torch
>>> from torchsurv.loss.cox import baseline_survival_function
>>> log_hz = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
>>> event = torch.tensor([1, 0, 0, 1, 1], dtype=torch.bool)
>>> time = torch.tensor([1.0, 2.0, 3.0, 4.0, 4.0])
>>> baseline_survival_function(log_hz, event, time)
{'time': tensor([1., 2., 3., 4.]), 'baseline_survival': tensor([0.8636, 0.8636, 0.8636, 0.4568])}
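To make the Breslow estimator concrete, here is a minimal plain-Python sketch of the cumulative-hazard sum for a single stratum. It is an illustration under stated assumptions, not the library's code: the name baseline_survival_sketch is hypothetical, and the sketch evaluates the estimate at every sorted unique time, including times where only censoring occurred.

```python
import math


def baseline_survival_sketch(log_hz, event, time):
    """Breslow estimate of S_0 at each sorted unique time (single stratum)."""
    theta = [math.exp(v) for v in log_hz]
    unique_times = sorted(set(time))
    cum_hazard = 0.0
    survival = []
    for xi in unique_times:
        # m_k: number of events at xi; risk_sum: total relative hazard still at risk.
        m_k = sum(1 for e, t in zip(event, time) if e and t == xi)
        risk_sum = sum(th for th, t in zip(theta, time) if t >= xi)
        cum_hazard += m_k / risk_sum  # adds nothing when no event occurs at xi
        survival.append(math.exp(-cum_hazard))
    return {"time": unique_times, "baseline_survival": survival}


out = baseline_survival_sketch(
    log_hz=[0.1, 0.2, 0.3, 0.4, 0.5],
    event=[True, False, False, True, True],
    time=[1.0, 2.0, 3.0, 4.0, 4.0],
)
print([round(s, 4) for s in out["baseline_survival"]])
```

On the inputs of the example above, the sketch gives the same four values to four decimal places: the estimate drops only at times 1.0 and 4.0, where events occur, and stays flat at the censoring times 2.0 and 3.0.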
References
[Bre72] N. E. Breslow. Discussion on Professor Cox's paper. Journal of the Royal Statistical Society Series B: Statistical Methodology, 34(2):202–220, January 1972.
- survival_function(baseline_survival, new_log_hz, new_time, new_strata=None)
Compute the individual survival function for new subjects for the Cox proportional hazards model.
- Parameters:
baseline_survival (dict) – Output of baseline_survival_function.
new_log_hz (torch.Tensor, float) – Log relative hazard for new subjects of length n_samples_new.
new_time (torch.Tensor, float) – Times at which to evaluate the survival probability, of length n_times.
new_strata (torch.Tensor, int, optional) – Integer tensor of length n_samples_new representing stratum for each new subject defined by combinations of categorical variables.
- Returns:
Individual survival probabilities for each new subject at new_time, of shape (n_samples_new, n_times).
- Return type:
torch.Tensor
Note
The estimated survival function for new subject \(i\) under the Cox proportional hazards model is given by:
\[\hat{S}_i(t) = \hat{S}_0(t)^{\theta_i^{\star}},\]
where \(\hat{S}_0(t)\) is the estimated baseline survival function and \(\log \theta_i^{\star}\) is the log relative hazard of new subjects (argument new_log_hz).
When strata are provided for both the original model fitting and new subject prediction (argument new_strata), the survival function uses the baseline survival function specific to the subject’s stratum, \(\hat{S}_{0}^s(t)\).
Examples
>>> import torch
>>> from torchsurv.loss.cox import baseline_survival_function, survival_function
>>> event = torch.tensor([1, 0, 0, 1, 1], dtype=torch.bool)  # original subjects
>>> time = torch.tensor([1.0, 2.0, 3.0, 4.0, 4.0])
>>> log_hz = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
>>> baseline_survival = baseline_survival_function(log_hz, event, time)
>>> new_log_hz = torch.tensor([0.15, 0.25])  # 2 new subjects
>>> new_time = torch.tensor([2.5, 4.5])
>>> survival_function(baseline_survival, new_log_hz, new_time)
tensor([[0.8433, 0.4024],
        [0.8283, 0.3657]])
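The relation \(\hat{S}_i(t) = \hat{S}_0(t)^{\theta_i^{\star}}\) can be sketched in plain Python as follows. This is an illustration, not the library's implementation: the name survival_sketch is hypothetical, and two behaviours are assumptions of the sketch rather than documented facts, namely that the baseline is read as a right-continuous step function between the stored times and that it equals 1 before the first stored time. The baseline values are copied, rounded, from the example above.

```python
import bisect
import math


def survival_sketch(baseline, new_log_hz, new_time):
    """Return rows of S_0(t) ** theta_i, one row per new subject (single stratum)."""
    times = baseline["time"]
    s0 = baseline["baseline_survival"]
    rows = []
    for lh in new_log_hz:
        theta = math.exp(lh)
        row = []
        for t in new_time:
            k = bisect.bisect_right(times, t)   # number of stored times <= t
            base = s0[k - 1] if k > 0 else 1.0  # assumption: S_0 = 1 before the first stored time
            row.append(base ** theta)
        rows.append(row)
    return rows


# Rounded baseline values taken from the documented example output.
baseline = {
    "time": [1.0, 2.0, 3.0, 4.0],
    "baseline_survival": [0.8636, 0.8636, 0.8636, 0.4568],
}
probs = survival_sketch(baseline, new_log_hz=[0.15, 0.25], new_time=[2.5, 4.5])
```

Because the baseline is rounded to four decimals here, probs agrees with the documented output only to about three decimal places. A larger log relative hazard gives a lower survival probability at every time, as the second row shows.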