"""Source code for torchsurv.metrics.brier_score."""

import copy
import warnings
from typing import Optional

import torch
from scipy import stats

from torchsurv.tools.validate_data import (
    validate_new_time,
    validate_survival_data,
)

__all__ = ["BrierScore"]


class BrierScore:
    r"""Compute the time-dependent Brier Score for survival models.

    The Brier score at time :math:`t` is the mean squared error between the
    observed event status at :math:`t` and the predicted survival probability
    :cite:p:`Graf1999`. Optional inverse-probability-of-censoring weights
    (IPCW) adjust the estimate for right censoring.
    """

    def __init__(self, checks: bool = True):
        """Initialize a BrierScore instance for survival model evaluation.

        Args:
            checks (bool): Whether to perform input format checks on the
                survival data. Enabling checks can help catch potential
                issues in the input data. Defaults to True.

        Examples:
            >>> _ = torch.manual_seed(52)
            >>> n = 10
            >>> time = torch.randint(low=5, high=250, size=(n,), dtype=torch.float)
            >>> event = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
            >>> estimate = torch.rand((n, len(time)), dtype=torch.float)
            >>> brier_score = BrierScore()
            >>> brier_score(estimate, event, time)
            tensor([0.2463, 0.2740, 0.3899, 0.1964, 0.3608, 0.2821, 0.1932, 0.2978, 0.1950,
                    0.1668])
            >>> brier_score.integral()  # integrated brier score
            tensor(0.2862)
        """
        self.checks = checks

        # Internal state populated by __call__ (consumed by integral,
        # confidence_interval, p_value, compare and the bootstrap helpers).
        self.order_time = None
        self.time = None
        self.event = None
        self.weight = None
        self.new_time = None
        self.weight_new_time = None
        self.estimate = None
        self.brier_score = None
        self.residuals = None

    def __call__(
        self,
        estimate: torch.Tensor,
        event: torch.Tensor,
        time: torch.Tensor,
        new_time: Optional[torch.Tensor] = None,
        weight: Optional[torch.Tensor] = None,
        weight_new_time: Optional[torch.Tensor] = None,
        instate: bool = True,
    ) -> torch.Tensor:
        r"""Compute the time-dependent Brier Score.

        Args:
            estimate (torch.Tensor): Estimated probability of remaining
                event-free (i.e., survival function). Of shape
                ``(n_samples, n_samples)`` if subject-specific survival is
                evaluated at ``time`` (entry :math:`(i, j)` is the survival
                for subject :math:`i` at the ``time`` of subject :math:`j`),
                or ``(n_samples, n_times)`` if evaluated at ``new_time``
                (entry :math:`(i, j)` is the survival for subject :math:`i`
                at the :math:`j`-th ``new_time``).
            event (torch.Tensor, bool): Event indicator of size n_samples
                (True if the event occurred).
            time (torch.Tensor, float): Event or censoring time of size
                n_samples.
            new_time (torch.Tensor, float, optional): Time points at which to
                evaluate the Brier score, of size n_times. Defaults to the
                unique values of ``time``.
            weight (torch.Tensor, optional): Sample weight evaluated at
                ``time``, of size n_samples (e.g., inverse probability of
                censoring weights derived from the training data). Defaults
                to 1.
            weight_new_time (torch.Tensor, optional): Sample weight evaluated
                at ``new_time``, of size n_times. Must be provided when both
                ``new_time`` and ``weight`` are specified.
            instate (bool): Whether to store the (time-sorted) inputs and the
                results on the instance for later use. Defaults to True.

        Returns:
            torch.Tensor: Brier score evaluated at each ``new_time``.

        Note:
            The censoring-adjusted Brier score estimate at time :math:`t` is

            .. math::

                \hat{BS}(t) = \frac{1}{n}\sum_i \omega(T_i)
                1(T_i \leq t, \delta_i = 1) (0 - \hat{S}_i(t))^2
                + \omega(t) 1(T_i > t) (1 - \hat{S}_i(t))^2

            where :math:`\omega` are the optional inverse probability of
            censoring weights :cite:p:`Graf1999`. With unit weights this
            reduces to the unadjusted Brier score.

        Examples:
            >>> _ = torch.manual_seed(52)
            >>> n = 10
            >>> time = torch.randint(low=5, high=250, size=(n,), dtype=torch.float)
            >>> event = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
            >>> estimate = torch.rand((n, len(time)), dtype=torch.float)
            >>> brier_score = BrierScore()
            >>> brier_score(estimate, event, time)
            tensor([0.2463, 0.2740, 0.3899, 0.1964, 0.3608, 0.2821, 0.1932, 0.2978, 0.1950,
                    0.1668])

        References:
            .. bibliography::
                :filter: False

                Graf1999
        """
        # ensure event and time are 1-dimensional
        event = event.squeeze()
        time = time.squeeze()

        # mandatory input format checks
        BrierScore._validate_brier_score_inputs(
            estimate, time, new_time, weight, weight_new_time
        )

        # derive new_time / align estimate and weights as required
        (
            estimate,
            new_time,
            weight,
            weight_new_time,
        ) = BrierScore._update_brier_score_new_time(
            estimate, time, new_time, weight, weight_new_time
        )
        weight, weight_new_time = BrierScore._update_brier_score_weight(
            time, new_time, weight, weight_new_time
        )

        # optional consistency checks on the survival data
        if self.checks:
            validate_survival_data(event, time)
            validate_new_time(new_time, time, within_follow_up=False)

        # per-subject, per-time squared residuals: cases (event before t)
        # contribute S(t)^2, controls (still at risk at t) contribute
        # (1 - S(t))^2; each term carries its IPCW weight
        residuals = torch.zeros_like(estimate)
        for index, new_time_i in enumerate(new_time):
            est = estimate[:, index]
            is_case = ((time <= new_time_i) & (event)).int()
            is_control = (time > new_time_i).int()

            residuals[:, index] = (
                torch.square(est) * is_case * weight
                + torch.square(1.0 - est) * is_control * weight_new_time[index]
            )

        # Brier score at each time point = mean residual over subjects
        brier_score = torch.mean(residuals, dim=0)

        # Create/overwrite internal attribute states
        if instate:
            # sort subject-level objects by time (bootstrap helpers rely on
            # self.estimate/event/time/weight being aligned)
            self.order_time = torch.argsort(time, dim=0)
            self.time = time[self.order_time]
            self.event = event[self.order_time]
            self.weight = weight[self.order_time]
            self.new_time = new_time
            self.weight_new_time = weight_new_time
            self.estimate = torch.index_select(estimate, 0, self.order_time)
            self.brier_score = brier_score
            self.residuals = residuals

        return brier_score

    def integral(self) -> torch.Tensor:
        r"""Compute the integrated Brier Score.

        The integrated Brier score is the integral of the time-dependent
        Brier score over :math:`[t_1, t_2]`, where :math:`t_1` and
        :math:`t_2` are the smallest and largest evaluation times
        :cite:p:`Graf1999`. The integral is estimated with the trapezoidal
        rule and normalized by the interval length.

        Returns:
            torch.Tensor: Integrated Brier Score.

        Examples:
            >>> _ = torch.manual_seed(52)
            >>> n = 10
            >>> time = torch.randint(low=5, high=250, size=(n,), dtype=torch.float)
            >>> event = torch.randint(low=0, high=2, size=(n,), dtype=torch.bool)
            >>> estimate = torch.rand((n, len(time)), dtype=torch.float)
            >>> brier_score = BrierScore()
            >>> _ = brier_score(estimate, event, time)
            >>> brier_score.integral()  # integrated brier score
            tensor(0.2862)
        """
        if len(self.new_time) == 1:
            # single evaluation time: the integral degenerates to the point value
            return self.brier_score[0]
        return torch.trapezoid(self.brier_score, self.new_time) / (
            self.new_time[-1] - self.new_time[0]
        )

    def confidence_interval(
        self,
        method: str = "parametric",
        alpha: float = 0.05,
        alternative: str = "two_sided",
        n_bootstraps: int = 999,
    ) -> torch.Tensor:
        """Compute the confidence interval of the Brier Score.

        The pointwise ("parametric") interval assumes the Brier score is
        normally distributed and uses empirical standard errors; the
        bootstrap interval uses the Efron percentile method on bootstrap
        samples.

        Args:
            method (str): Method for computing the confidence interval.
                Must be one of "parametric", "bootstrap". Defaults to
                "parametric".
            alpha (float): Significance level. Defaults to 0.05.
            alternative (str): Alternative hypothesis. Must be one of
                "two_sided", "greater", "less". Defaults to "two_sided".
            n_bootstraps (int): Number of bootstrap samples. Defaults to 999.
                Ignored if ``method`` is not "bootstrap".

        Returns:
            torch.Tensor([lower, upper]): Lower and upper bounds of the
            confidence interval, one column per evaluation time.

        Raises:
            ValueError: If ``alternative`` or ``method`` is not recognized.
        """
        assert hasattr(self, "brier_score") and self.brier_score is not None, (
            "Error: Please calculate brier score using `BrierScore()` before calling `confidence_interval()`."
        )

        if alternative not in ["less", "greater", "two_sided"]:
            raise ValueError("'alternative' parameter must be one of ['less', 'greater', 'two_sided'].")

        if method == "bootstrap":
            conf_int = self._confidence_interval_bootstrap(alpha, alternative, n_bootstraps)
        elif method == "parametric":
            conf_int = self._confidence_interval_parametric(alpha, alternative)
        else:
            raise ValueError(f"Method {method} not implemented. Please choose either 'parametric' or 'bootstrap'.")
        return conf_int

    def p_value(
        self,
        method: str = "bootstrap",
        alternative: str = "two_sided",
        n_bootstraps: int = 999,
        null_value: Optional[float] = None,
    ) -> torch.Tensor:
        """Perform a one-sample hypothesis test on the Brier score.

        Tests the null hypothesis that the estimated Brier score equals
        ``null_value`` (bs0), the score expected if the model were no better
        than chance. The parametric p-value assumes normality and uses the
        empirical standard errors (``null_value`` is required); the bootstrap
        p-value permutes the survival predictions to estimate the sampling
        distribution under the null.

        Args:
            method (str): Method for computing the p-value. Must be one of
                "parametric", "bootstrap". Defaults to "bootstrap".
            alternative (str): Alternative hypothesis. Must be one of
                "two_sided", "greater", "less". Defaults to "two_sided".
            n_bootstraps (int): Number of bootstrap samples. Defaults to 999.
                Ignored if ``method`` is not "bootstrap".
            null_value (float, optional): Brier score under the null
                hypothesis (bs0). Required when ``method`` is "parametric".

        Returns:
            torch.Tensor: p-value of the statistical test at each
            evaluation time.

        Raises:
            ValueError: If ``alternative``/``method`` is not recognized, or
                ``method`` is "parametric" without a ``null_value``.
        """
        assert hasattr(self, "brier_score") and self.brier_score is not None, (
            "Error: Please calculate the brier score using `BrierScore()` before calling `p_value()`."
        )

        if alternative not in ["less", "greater", "two_sided"]:
            raise ValueError("'alternative' parameter must be one of ['less', 'greater', 'two_sided'].")

        if method == "parametric" and null_value is None:
            raise ValueError("Error: If the method is 'parametric', you must provide the 'null_value'.")

        if method == "parametric":
            pvalue = self._p_value_parametric(alternative, null_value)
        elif method == "bootstrap":
            pvalue = self._p_value_bootstrap(alternative, n_bootstraps)
        else:
            raise ValueError(f"Method {method} not implemented. Please choose either 'parametric' or 'bootstrap'.")
        return pvalue

    def compare(self, other, method: str = "parametric", n_bootstraps: int = 999) -> torch.Tensor:
        """Compare two Brier scores computed on the same survival data.

        Null hypothesis: brierscore1 = brierscore2; alternative:
        brierscore1 < brierscore2. Uses either a Student t-test for paired
        samples (assuming normality, with empirical standard errors) or a
        two-sample bootstrap test.

        Args:
            other (BrierScore): Another BrierScore instance representing
                brierscore2 (must be computed on the same ``event``/``time``
                and ``new_time``).
            method (str): Statistical test. Must be one of "parametric",
                "bootstrap". Defaults to "parametric".
            n_bootstraps (int): Number of bootstrap samples. Defaults to 999.
                Ignored if ``method`` is not "bootstrap".

        Returns:
            torch.Tensor: p-value of the statistical test at each
            evaluation time.

        Raises:
            ValueError: If the two Brier scores were not computed on the same
                data, or ``method`` is not recognized.
        """
        assert hasattr(self, "brier_score") and self.brier_score is not None, (
            "Error: Please calculate the brier score using `BrierScore()` before calling `compare()`."
        )

        # assert that the same data were used to compute the two brier scores
        if torch.any(self.event != other.event) or torch.any(self.time != other.time):
            raise ValueError(
                "Mismatched survival data: 'time' and 'event' should be the same for both brier score computations."
            )
        if torch.any(self.new_time != other.new_time):
            raise ValueError(
                "Mismatched evaluation times: 'new_time' should be the same for both brier score computations."
            )

        if method == "parametric":
            pvalue = self._compare_parametric(other)
        elif method == "bootstrap":
            pvalue = self._compare_bootstrap(other, n_bootstraps)
        else:
            raise ValueError("Method not implemented. Please choose either 'parametric' or 'bootstrap'.")
        return pvalue

    def _brier_score_se(self) -> torch.Tensor:
        """Empirical standard error of the Brier score at each evaluation time."""
        return torch.std(self.residuals, dim=0) / (self.time.shape[0] ** 0.5)

    def _confidence_interval_parametric(self, alpha: float, alternative: str) -> torch.Tensor:
        """Normal-approximation confidence interval using empirical standard errors."""
        alpha = alpha / 2 if alternative == "two_sided" else alpha
        brier_score_se = self._brier_score_se()

        # zero SE means the residuals are constant across subjects at that
        # time; the interval collapses onto the point estimate there
        if torch.any(brier_score_se == 0):
            time_index = torch.where(brier_score_se == 0)
            warnings.warn(
                f"The standard error of the brier score at time index: {time_index} are zero. This indicates that the brier score is constant across all samples. Confidence interval will equal to the point estimate",
                stacklevel=2,
            )

        ci = -torch.distributions.normal.Normal(0, 1).icdf(torch.tensor(alpha)) * brier_score_se
        # clamp to the Brier score's valid range [0, 1]
        lower = torch.max(torch.tensor(0.0), self.brier_score - ci)
        upper = torch.min(torch.tensor(1.0), self.brier_score + ci)

        if alternative == "less":
            lower = torch.zeros_like(lower)
        elif alternative == "greater":
            upper = torch.ones_like(upper)

        return torch.stack([lower, upper], dim=0)

    def _confidence_interval_bootstrap(self, alpha: float, alternative: str, n_bootstraps: int) -> torch.Tensor:
        """Bootstrap confidence interval using the Efron percentile method.

        References:
            Efron, Bradley; Tibshirani, Robert J. (1993). An introduction to
            the bootstrap, New York: Chapman & Hall, software.
        """
        # brier score given bootstrap distribution
        brier_score_bootstrap = self._bootstrap_brier_score(
            metric="confidence_interval", n_bootstraps=n_bootstraps
        )

        lower = torch.zeros_like(self.brier_score)
        upper = torch.zeros_like(self.brier_score)

        # percentile bounds per evaluation time
        for index_t in range(len(self.brier_score)):
            if alternative == "two_sided":
                lower[index_t], upper[index_t] = torch.quantile(
                    brier_score_bootstrap[:, index_t],
                    torch.tensor([alpha / 2, 1 - alpha / 2], device=self.brier_score.device),
                )
            elif alternative == "less":
                upper[index_t] = torch.quantile(
                    brier_score_bootstrap[:, index_t],
                    torch.tensor(1 - alpha, device=self.brier_score.device),
                )
                lower[index_t] = torch.tensor(0.0, device=self.brier_score.device)
            elif alternative == "greater":
                lower[index_t] = torch.quantile(
                    brier_score_bootstrap[:, index_t],
                    torch.tensor(alpha, device=self.brier_score.device),
                )
                upper[index_t] = torch.tensor(1.0, device=self.brier_score.device)

        return torch.stack([lower, upper], dim=0)

    def _p_value_parametric(self, alternative: str, null_value: float = 0.5) -> torch.Tensor:
        """Normal-approximation one-sample p-value using empirical standard errors."""
        brier_score_se = self._brier_score_se()

        # zero SE makes the z-statistic degenerate at that time point
        if torch.any(brier_score_se == 0):
            time_index = torch.where(brier_score_se == 0)
            warnings.warn(
                f"The standard error of the brier score at time index: {time_index} are zero. This indicates that the brier score is constant across all samples. The p-value is degenerate at these time points.",
                stacklevel=2,
            )

        # left-tail probability of the z-statistic
        p = torch.distributions.normal.Normal(0, 1).cdf((self.brier_score - null_value) / brier_score_se)
        if alternative == "two_sided":
            mask = self.brier_score >= 0.5
            p[mask] = 1 - p[mask]
            p *= 2
            # cap at 1 in case the critical value is below 0.5
            p = torch.min(torch.tensor(1.0, device=self.brier_score.device), p)
        elif alternative == "greater":
            p = 1 - p
        return p

    def _p_value_bootstrap(self, alternative, n_bootstraps) -> torch.Tensor:
        """Permutation-based one-sample p-value.

        Shuffles the survival predictions across subjects to estimate the
        sampling distribution under the null hypothesis that the predictions
        are uninformative.
        """
        # brier score bootstraps given null distribution
        brierscore0 = self._bootstrap_brier_score(metric="p_value", n_bootstraps=n_bootstraps)

        p_values = torch.zeros_like(self.brier_score)
        for index_t, brier_score_t in enumerate(self.brier_score):
            # add-one estimate of P(BS0 <= observed)
            p = (1 + torch.sum(brierscore0[:, index_t] <= brier_score_t)) / (n_bootstraps + 1)
            if alternative == "two_sided":
                if brier_score_t >= 0.5:
                    p = 1 - p
                p *= 2
                # cap at 1 in case a very small bootstrap sample size is used
                p = torch.min(torch.tensor(1.0, device=self.brier_score.device), p)
            elif alternative == "greater":
                p = 1 - p
            p_values[index_t] = p
        return p_values

    def _compare_parametric(self, other):
        """Student t-test for paired samples on the per-subject residuals."""
        n_samples = self.time.shape[0]
        p_values = torch.zeros_like(self.brier_score)

        for index_t, brier_score_t in enumerate(self.brier_score):
            # standard error of the mean paired residual difference
            paired_se = torch.std(self.residuals[:, index_t] - other.residuals[:, index_t]) / (
                n_samples ** 0.5
            )
            t_stat = (brier_score_t - other.brier_score[index_t]) / paired_se
            # student-t cdf not available on torch
            p_values[index_t] = torch.tensor(
                stats.t.cdf(t_stat, df=n_samples - 1),
                dtype=self.brier_score.dtype,
                device=self.brier_score.device,
            )
        return p_values

    def _compare_bootstrap(self, other, n_bootstraps) -> torch.Tensor:
        """Bootstrap two-sample test to compare two Brier scores."""
        # bootstrap brier scores under the null hypothesis that brierscore1
        # and brierscore2 come from the same (pooled) distribution
        brier_score1_null = self._bootstrap_brier_score(metric="compare", other=other, n_bootstraps=n_bootstraps)
        brier_score2_null = self._bootstrap_brier_score(metric="compare", other=other, n_bootstraps=n_bootstraps)

        # bootstrapped and observed test statistics
        t_boot = brier_score1_null - brier_score2_null
        t_obs = self.brier_score - other.brier_score

        p_values = torch.zeros_like(self.brier_score)
        for index_t, _ in enumerate(self.brier_score):
            p_values[index_t] = (1 + torch.sum(t_boot[:, index_t] <= t_obs[index_t])) / (n_bootstraps + 1)
        return p_values

    def _bootstrap_brier_score(self, metric: str, n_bootstraps: int, other=None) -> torch.Tensor:
        """Compute bootstrap samples of the Brier Score.

        Args:
            metric (str): One of "confidence_interval" (resample subjects
                from the data distribution), "compare" (resample from the
                pooled sample under the null brierscore1 = brierscore2), or
                "p_value" (permute predictions under the null that they are
                uninformative).
            n_bootstraps (int): Number of bootstrap samples.
            other (BrierScore, optional): The other BrierScore instance.
                Only required if ``metric`` is "compare".

        Returns:
            torch.Tensor: Bootstrap samples of the Brier score, of shape
            (n_bootstraps, n_times).
        """
        brier_scores = []
        for _ in range(n_bootstraps):
            if metric == "confidence_interval":
                # sample subjects with replacement from the data distribution
                index = torch.randint(
                    low=0,
                    high=self.estimate.shape[0],
                    size=(self.estimate.shape[0],),
                )
                brier_scores.append(
                    self(
                        self.estimate[index, :],
                        self.event[index],
                        self.time[index],
                        self.new_time,
                        self.weight[index],
                        self.weight_new_time,
                        instate=False,  # run without saving internal state
                    )
                )
            elif metric == "compare":
                # sample with replacement from the pooled sample
                index = torch.randint(
                    low=0,
                    high=self.estimate.shape[0] * 2,
                    size=(self.estimate.shape[0],),
                )
                # with prob 0.5, take weight_new_time from self, else from other
                weight_new_time = self.weight_new_time if torch.rand(1) < 0.5 else other.weight_new_time
                brier_scores.append(
                    self(
                        torch.cat((self.estimate, other.estimate))[index, :],
                        torch.cat((self.event, other.event))[index],
                        torch.cat((self.time, other.time))[index],
                        self.new_time,
                        torch.cat((self.weight, other.weight))[index],
                        weight_new_time,
                        instate=False,
                    )
                )
            elif metric == "p_value":
                # permute predictions across subjects (fancy indexing copies,
                # so self.estimate is left untouched)
                estimate = self.estimate[torch.randperm(self.estimate.shape[0]), :]
                brier_scores.append(
                    self(
                        estimate,
                        self.event,
                        self.time,
                        self.new_time,
                        self.weight,
                        self.weight_new_time,
                        instate=False,
                    )
                )

        brier_scores = torch.stack(brier_scores, dim=0)
        if torch.any(torch.isnan(brier_scores)):
            raise ValueError("The brier score computed using bootstrap should not be NaN.")
        return brier_scores

    @staticmethod
    def _find_torch_unique_indices(inverse_indices: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
        """Return indices mapping a sorted-unique tensor back into the original.

        Returns ``unique_sorted_indices`` such that
        ``sorted_unique_tensor[inverse_indices] = original_tensor`` and
        ``original_tensor[unique_sorted_indices] = sorted_unique_tensor``.

        Usage:
            _, inverse_indices, counts = torch.unique(
                x, sorted=True, return_inverse=True, return_counts=True
            )
            sorted_unique_indices = BrierScore._find_torch_unique_indices(
                inverse_indices, counts
            )
        """
        # first occurrence of each unique value, in sorted order
        _, ind_sorted = torch.sort(inverse_indices, stable=True)
        cum_sum = counts.cumsum(0)
        cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]))
        return ind_sorted[cum_sum]

    @staticmethod
    def _validate_brier_score_inputs(
        estimate: torch.Tensor,
        time: torch.Tensor,
        new_time: torch.Tensor,
        weight: torch.Tensor,
        weight_new_time: torch.Tensor,
    ) -> None:
        """Raise ValueError when the Brier score inputs are malformed."""
        # if new_time and weight are provided, weight_new_time must be too
        if all([new_time is not None, weight is not None, weight_new_time is None]):
            raise ValueError("Please provide 'weight_new_time', the weight evaluated at 'new_time'.")

        # estimate must contain probabilities
        if torch.any(estimate < 0) or torch.any(estimate > 1):
            raise ValueError("The 'estimate' input should contain estimated survival probabilities between 0 and 1.")

        # estimate must be 2-dimensional
        if estimate.ndim != 2:
            raise ValueError("The 'estimate' input should have two dimensions.")

        # without new_time, estimate must be evaluated at the subjects' times
        if new_time is None and len(time) != estimate.shape[1]:
            raise ValueError(
                "Mismatched dimensions: The number of columns in 'estimate' does not match the length of 'time'. "
                "Please provide the times at which 'estimate' is evaluated using the 'new_time' input."
            )

    @staticmethod
    def _update_brier_score_new_time(
        estimate: torch.Tensor,
        time: torch.Tensor,
        new_time: torch.Tensor,
        weight: torch.Tensor,
        weight_new_time: torch.Tensor,
    ) -> tuple:
        """Normalize ``new_time`` (default: unique ``time``) and align estimate/weights."""
        if new_time is not None:
            # ensure new_time is a 1-d float tensor
            if isinstance(new_time, int):
                new_time = torch.tensor([new_time]).float()
            if new_time.ndim == 0:
                new_time = new_time.unsqueeze(0)
        else:
            # default to the sorted unique event/censoring times
            new_time, inverse_indices, counts = torch.unique(
                time, sorted=True, return_inverse=True, return_counts=True
            )
            sorted_unique_indices = BrierScore._find_torch_unique_indices(inverse_indices, counts)

            # keep only the estimate columns (and weights) at the unique times
            estimate = estimate[:, sorted_unique_indices]
            if weight is not None:
                weight_new_time = weight[sorted_unique_indices]

        return estimate, new_time, weight, weight_new_time

    @staticmethod
    def _update_brier_score_weight(
        time: torch.Tensor,
        new_time: torch.Tensor,
        weight: torch.Tensor,
        weight_new_time: torch.Tensor,
    ) -> tuple:
        """Default to unit weights when no IPCW weights were supplied."""
        if weight is None:
            weight = torch.ones_like(time)
            weight_new_time = torch.ones_like(new_time)
        return weight, weight_new_time
if __name__ == "__main__":
    import doctest
    import sys

    # Run the module doctests and exit non-zero on any failure.
    outcome = doctest.testmod()
    if outcome.failed:
        print("Some doctests failed.")
        sys.exit(1)
    print("All tests passed.")