Getting started#

In this notebook, we use TorchSurv to train a model that predicts relative risk of breast cancer recurrence. We use a public data set, the German Breast Cancer Study Group 2 (GBSG2). After training the model, we evaluate the predictive performance using evaluation metrics implemented in TorchSurv.

We first load the dataset using the package lifelines. The GBSG2 dataset contains features and recurrence-free survival time (in days) for 686 women with breast cancer, including whether each received hormonal therapy.

Dependencies#

To run this notebook, dependencies must be installed. The recommended method is to use our development conda environment; instructions to install all optional dependencies can be found here. Alternatively, install only the required packages using the commands below:

[1]:
# Install only required packages (optional)
# %pip install lifelines
# %pip install matplotlib
# %pip install scikit-learn
# %pip install pandas
[2]:
import warnings

warnings.filterwarnings("ignore")
[3]:
import lifelines
import pandas as pd
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Our package
from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.loss.weibull import neg_log_likelihood, log_hazard, survival_function
from torchsurv.metrics.brier_score import BrierScore
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc
from torchsurv.stats.kaplan_meier import KaplanMeierEstimator

# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py
from helpers_introduction import Custom_dataset, plot_losses
[4]:
# Optional workaround for torch.compile / dynamo issues
# torch._dynamo.config.suppress_errors = True  # Suppress inductor errors
# torch._dynamo.reset()  # Reset the backend
[5]:
# Constant parameters across models
# Detect an available accelerator; use a smaller batch size if only a CPU is available
if any([torch.cuda.is_available(), torch.backends.mps.is_available()]):
    print("CUDA-enabled GPU/TPU is available.")
    BATCH_SIZE = 128  # batch size for training
else:
    print("No CUDA-enabled GPU found, using CPU.")
    BATCH_SIZE = 32  # batch size for training

EPOCHS = 100
LEARNING_RATE = 1e-2
CUDA-enabled GPU/TPU is available.

Dataset overview#

[6]:
# Load GBSG2 dataset
df = lifelines.datasets.load_gbsg2()
df.head(5)
[6]:
horTh age menostat tsize tgrade pnodes progrec estrec time cens
0 no 70 Post 21 II 3 48 66 1814 1
1 yes 56 Post 12 II 7 61 77 2018 1
2 yes 58 Post 35 II 9 52 271 712 1
3 yes 59 Post 17 II 4 60 29 1807 1
4 no 73 Post 35 II 1 26 65 772 1

The dataset contains the following features:

  • horTh: hormonal therapy, a factor at two levels (yes and no).

  • age: age of the patients in years.

  • menostat: menopausal status, a factor at two levels pre (premenopausal) and post (postmenopausal).

  • tsize: tumor size (in mm).

  • tgrade: tumor grade, an ordered factor at levels I < II < III.

  • pnodes: number of positive nodes.

  • progrec: progesterone receptor (in fmol).

  • estrec: estrogen receptor (in fmol).

Additionally, it contains our survival targets:

  • time: recurrence-free survival time (in days).

  • cens: censoring indicator (0 = censored, 1 = event).

A common approach is to one-hot encode the categorical features into numerical columns. We then separate the dataframe into features X and labels y. The following code also partitions the labels and features into training, validation, and testing cohorts.

Data preparation#

[7]:
df_onehot = pd.get_dummies(df, columns=["horTh", "menostat", "tgrade"]).astype("float")
df_onehot.drop(
    ["horTh_no", "menostat_Post", "tgrade_I"],
    axis=1,
    inplace=True,
)
df_onehot.head(5)
[7]:
age tsize pnodes progrec estrec time cens horTh_yes menostat_Pre tgrade_II tgrade_III
0 70.0 21.0 3.0 48.0 66.0 1814.0 1.0 0.0 0.0 1.0 0.0
1 56.0 12.0 7.0 61.0 77.0 2018.0 1.0 1.0 0.0 1.0 0.0
2 58.0 35.0 9.0 52.0 271.0 712.0 1.0 1.0 0.0 1.0 0.0
3 59.0 17.0 4.0 60.0 29.0 1807.0 1.0 1.0 0.0 1.0 0.0
4 73.0 35.0 1.0 26.0 65.0 772.0 1.0 0.0 0.0 1.0 0.0
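
As a side note, the same reference-level encoding can usually be obtained in a single step with pandas' drop_first argument, which drops the first level of each encoded factor; assuming the factor levels shown above (no, Post, I), this hypothetical sketch (the name df_onehot_alt is illustrative) would produce the same columns as the manual drop:

# Hypothetical one-step alternative: drop_first removes the first level of each
# encoded factor, which here corresponds to horTh_no, menostat_Post and tgrade_I
df_onehot_alt = pd.get_dummies(
    df, columns=["horTh", "menostat", "tgrade"], drop_first=True
).astype("float")
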
[8]:
df_train, df_test = train_test_split(df_onehot, test_size=0.3)
df_train, df_val = train_test_split(df_train, test_size=0.3)
print(
    f"(Sample size) Training:{len(df_train)} | Validation:{len(df_val)} |Testing:{len(df_test)}"
)
(Sample size) Training:336 | Validation:144 | Testing:206

Let us set up the dataloaders for training, validation and testing.

[9]:
# Dataloader
dataloader_train = DataLoader(
    Custom_dataset(df_train), batch_size=BATCH_SIZE, shuffle=True
)
dataloader_val = DataLoader(
    Custom_dataset(df_val), batch_size=len(df_val), shuffle=False
)
dataloader_test = DataLoader(
    Custom_dataset(df_test), batch_size=len(df_test), shuffle=False
)
[10]:
# Sanity check
x, (event, time) = next(iter(dataloader_train))
num_features = x.size(1)

print(f"x (shape)    = {x.shape}")
print(f"num_features = {num_features}")
print(f"event        = {event.shape}")
print(f"time         = {time.shape}")
x (shape)    = torch.Size([128, 9])
num_features = 9
event        = torch.Size([128])
time         = torch.Size([128])
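
The helper Custom_dataset (see the linked helpers_introduction.py) is expected to yield (x, (event, time)) tuples from the one-hot encoded dataframe. A minimal sketch of such a dataset, assuming the column layout above and not reproducing the exact shipped helper, could look like:

from torch.utils.data import Dataset

class GBSG2Dataset(Dataset):  # hypothetical stand-in for Custom_dataset
    """Yield (x, (event, time)) tuples from the one-hot encoded dataframe."""

    def __init__(self, df):
        self.event = torch.tensor(df["cens"].values).bool()
        self.time = torch.tensor(df["time"].values).float()
        self.x = torch.tensor(df.drop(columns=["cens", "time"]).values).float()

    def __len__(self):
        return len(self.time)

    def __getitem__(self, idx):
        return self.x[idx], (self.event[idx], self.time[idx])
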

Section 1: Cox proportional hazards model#

In this section, we use the Cox proportional hazards model. Given covariate \(x_{i}\), the hazard of patient \(i\) has the form

\[\lambda (t|x_{i}) =\lambda_{0}(t)\theta(x_{i})\]

The baseline hazard \(\lambda_{0}(t)\) is identical across subjects (i.e., has no dependency on \(i\)). The subject-specific risk of event occurrence is captured through the relative hazards \(\{\theta(x_{i})\}_{i = 1, \dots, N}\).

We train a multi-layer perceptron (MLP) to model the subject-specific risk of event occurrence, i.e., the log relative hazards \(\log\theta(x_{i})\). Patients with shorter recurrence-free survival times are assumed to have a higher risk of event.
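
The loss used below, neg_partial_log_likelihood, corresponds (up to the handling of tied event times) to the negative of Cox's partial log-likelihood,

\[\log \mathrm{PL} = \sum_{i:\,\delta_{i}=1} \left[\log\theta(x_{i}) - \log \sum_{j:\,t_{j} \geq t_{i}} \theta(x_{j})\right],\]

where \(\delta_{i}\) is the event indicator and the inner sum runs over the risk set of patients still event-free at time \(t_{i}\). The baseline hazard \(\lambda_{0}(t)\) cancels out, which is why only the log relative hazards need to be modeled.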

Section 1.1: MLP model for log relative hazards#

[11]:
# Instantiate Cox model (MLP estimating the log relative hazard)
cox_model = torch.nn.Sequential(
    torch.nn.BatchNorm1d(num_features),  # Batch normalization
    torch.nn.Linear(num_features, 32),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(64, 1),  # Estimating log hazards for Cox models
)

Section 1.2: MLP model training#

[12]:
torch.manual_seed(42)

# Init optimizer for Cox
optimizer = torch.optim.Adam(cox_model.parameters(), lr=LEARNING_RATE)

# Initiate empty list to store the loss on the train and validation sets
train_losses = []
val_losses = []

# training loop
for epoch in range(EPOCHS):
    epoch_loss = torch.tensor(0.0)
    for i, batch in enumerate(dataloader_train):
        x, (event, time) = batch
        optimizer.zero_grad()
        log_hz = cox_model(x)  # shape = (batch_size, 1)
        loss = neg_partial_log_likelihood(log_hz, event, time, reduction="mean")
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()

    if epoch % (EPOCHS // 10) == 0:
        print(f"Epoch: {epoch:03}, Training loss: {epoch_loss:0.2f}")

    # Record mean loss on the training and validation sets
    epoch_loss /= i + 1
    train_losses.append(epoch_loss)
    with torch.no_grad():
        x, (event, time) = next(iter(dataloader_val))
        val_losses.append(
            neg_partial_log_likelihood(cox_model(x), event, time, reduction="mean")
        )
Epoch: 000, Training loss: 13.00
Epoch: 010, Training loss: 12.86
Epoch: 020, Training loss: 12.48
Epoch: 030, Training loss: 12.24
Epoch: 040, Training loss: 12.29
Epoch: 050, Training loss: 12.47
Epoch: 060, Training loss: 12.25
Epoch: 070, Training loss: 11.52
Epoch: 080, Training loss: 11.79
Epoch: 090, Training loss: 12.09

We can visualize the training and validation losses.

[13]:
plot_losses(train_losses, val_losses, "Cox")
../_images/notebooks_introduction_21_0.png

Section 1.3: Cox proportional hazards model evaluation#

We evaluate the predictive performance of the model using

  • the concordance index (C-index), which measures the probability that a model correctly predicts which of two comparable samples will experience an event first based on their estimated risk scores,

  • the Area Under the Receiver Operating Characteristic Curve (AUC), which measures the probability that a model correctly predicts which of two comparable samples will experience an event by time t based on their estimated risk scores.

We cannot use the Brier score because the Cox model does not by itself estimate the survival function (the baseline hazard is left unspecified).
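
In one common formulation, with risk scores \(\theta(x_{i})\), observed times \(t_{i}\) and event indicators \(\delta_{i}\), these two metrics can be written as

\[\text{C-index} = P\big(\theta(x_{i}) > \theta(x_{j}) \mid t_{i} < t_{j},\ \delta_{i}=1\big), \qquad \text{AUC}(t) = P\big(\theta(x_{i}) > \theta(x_{j}) \mid t_{i} \leq t,\ \delta_{i}=1,\ t_{j} > t\big),\]

where the probabilities are taken over comparable pairs of patients.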

We start by evaluating the subject-specific log relative hazards on the test set

[14]:
cox_model.eval()
with torch.no_grad():
    # test event and test time of length n
    x, (event, time) = next(iter(dataloader_test))
    log_hz = cox_model(x)  # log hazard of length n

We obtain the concordance index and its confidence interval

[15]:
# Concordance index
cox_cindex = ConcordanceIndex()
print("Cox model performance:")
print(f"Concordance-index   = {cox_cindex(log_hz, event, time)}")
print(f"Confidence interval = {cox_cindex.confidence_interval()}")
Cox model performance:
Concordance-index   = 0.6596463322639465
Confidence interval = tensor([0.5345, 0.7847])

We can also test whether the observed concordance index is greater than 0.5. The statistical test is specified with H0: c-index = 0.5 and Ha: c-index > 0.5. The p-value of the statistical test is

[16]:
# H0: cindex = 0.5, Ha: cindex > 0.5
print("p-value = {}".format(cox_cindex.p_value(alternative="greater")))
p-value = 0.006189107894897461

For time-dependent prediction (e.g., 5-year mortality), the C-index is not a proper measure. Instead, it is recommended to use the AUC. The probability that the model correctly predicts which of two comparable patients will experience an event by 5 years, based on their estimated risk scores, is the AUC evaluated at 5 years (1825 days), obtained with

[17]:
cox_auc = Auc()

new_time = torch.tensor(1825.0)

# auc evaluated at new time = 1825, 5 year
print(f"AUC 5-yr             = {cox_auc(log_hz, event, time, new_time=new_time)}")
print(f"AUC 5-yr (conf int.) = {cox_auc.confidence_interval()}")
AUC 5-yr             = tensor([0.6188])
AUC 5-yr (conf int.) = tensor([0.5648, 0.6729])

As before, we can test whether the observed AUC at 5 years is greater than 0.5. The statistical test is specified with H0: auc = 0.5 and Ha: auc > 0.5. The p-value of the statistical test is

[18]:
print(f"AUC (p_value) = {cox_auc.p_value()}")
AUC (p_value) = tensor([1.6212e-05])

Section 2: Weibull accelerated failure time (AFT) model#

In this section, we use the Weibull accelerated failure time (AFT) model. Given covariate \(x_{i}\), the hazard of patient \(i\) at time \(t\) has the form

\[\lambda (t|x_{i}) = \frac{\rho(x_{i})}{\lambda(x_{i})} \left(\frac{t}{\lambda(x_{i})}\right)^{\rho(x_{i}) - 1}\]

Given the hazard form, it can be shown that the event density follows a Weibull distribution parametrized by scale \(\lambda(x_{i})\) and shape \(\rho(x_{i})\). The subject-specific risk of event occurrence at time \(t\) is captured through the hazards \(\{\lambda (t|x_{i})\}_{i = 1, \dots, N}\). We train a multi-layer perceptron (MLP) to model the subject-specific log scale, \(\log \lambda(x_{i})\), and the log shape, \(\log\rho(x_{i})\).
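
Under this parametrization, the survival function and the censored-data likelihood used below have closed forms; up to implementation details, survival_function and neg_log_likelihood correspond to

\[S(t|x_{i}) = \exp\left[-\left(\frac{t}{\lambda(x_{i})}\right)^{\rho(x_{i})}\right], \qquad -\log L = -\sum_{i=1}^{N}\Big[\delta_{i}\,\log \lambda(t_{i}|x_{i}) + \log S(t_{i}|x_{i})\Big],\]

where \(\delta_{i}\) is the event indicator, so that censored patients contribute only through their survival probability.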

Section 2.1: MLP model for log scale and log shape#

[19]:
# Same architecture as the Cox model, except for the output dimension
weibull_model = torch.nn.Sequential(
    torch.nn.BatchNorm1d(num_features),  # Batch normalization
    torch.nn.Linear(num_features, 32),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(64, 2),  # Estimating log parameters for Weibull model
)

Section 2.2: MLP model training#

[20]:
torch.manual_seed(42)

# Init optimizer for Weibull
optimizer = torch.optim.Adam(weibull_model.parameters(), lr=LEARNING_RATE)

# Initialize empty list to store loss on train and validation sets
train_losses = []
val_losses = []

# training loop
for epoch in range(EPOCHS):
    epoch_loss = torch.tensor(0.0)
    for i, batch in enumerate(dataloader_train):
        x, (event, time) = batch
        optimizer.zero_grad()
        log_params = weibull_model(x)  # shape = (batch_size, 2)
        loss = neg_log_likelihood(log_params, event, time, reduction="mean")
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()

    if epoch % (EPOCHS // 10) == 0:
        print(f"Epoch: {epoch:03}, Training loss: {epoch_loss:0.2f}")

    # Record mean loss on the training and validation sets
    epoch_loss /= i + 1
    train_losses.append(epoch_loss)
    with torch.no_grad():
        x, (event, time) = next(iter(dataloader_val))
        val_losses.append(
            neg_log_likelihood(weibull_model(x), event, time, reduction="mean")
        )
Epoch: 000, Training loss: 19618.58
Epoch: 010, Training loss: 19.42
Epoch: 020, Training loss: 17.05
Epoch: 030, Training loss: 17.76
Epoch: 040, Training loss: 17.15
Epoch: 050, Training loss: 18.52
Epoch: 060, Training loss: 17.47
Epoch: 070, Training loss: 16.94
Epoch: 080, Training loss: 16.39
Epoch: 090, Training loss: 16.19

We can visualize the training and validation losses.

[21]:
plot_losses(train_losses, val_losses, "Weibull")
../_images/notebooks_introduction_40_0.png

Section 2.3: Weibull AFT model evaluation#

We evaluate the predictive performance of the model using

  • the C-index, which measures the probability that a model correctly predicts which of two comparable samples will experience an event first based on their estimated risk scores,

  • the AUC, which measures the probability that a model correctly predicts which of two comparable samples will experience an event by time t based on their estimated risk scores, and

  • the Brier score, which measures the model’s calibration by calculating the mean square error between the estimated survival function and the empirical (i.e., in-sample) event status.
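
For an uncensored sample, the time-dependent Brier score reduces to the mean squared error between the predicted survival probability and the observed event status,

\[\text{BS}(t) = \frac{1}{N}\sum_{i=1}^{N}\left(\mathbb{1}(t_{i} > t) - \hat{S}(t|x_{i})\right)^{2};\]

with censored data, the terms are typically reweighted (for example by inverse probability of censoring weights) so that censoring does not bias the estimate.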

We start by obtaining the subject-specific log hazard and survival probability at every time \(t\) observed on the test set

[22]:
weibull_model.eval()
with torch.no_grad():
    # event and time of length n
    x, (event, time) = next(iter(dataloader_test))
    log_params = weibull_model(x)  # shape = (n,2)

# Compute the log hazards from weibull log parameters
log_hz = log_hazard(log_params, time)  # shape = (n,n)

# Compute the survival probability from weibull log parameters
surv = survival_function(log_params, time)  # shape = (n,n)

We can evaluate the concordance index, its confidence interval, and the p-value of the test of whether the C-index is greater than 0.5:

[23]:
# Concordance index
weibull_cindex = ConcordanceIndex()
print("Weibull model performance:")
print(f"Concordance-index   = {weibull_cindex(log_hz, event, time)}")
print(f"Confidence interval = {weibull_cindex.confidence_interval()}")

# H0: cindex = 0.5, Ha: cindex >0.5
print(f"p-value             = {weibull_cindex.p_value(alternative = 'greater')}")
Weibull model performance:
Concordance-index   = 0.43771111965179443
Confidence interval = tensor([0.2926, 0.5828])
p-value             = 0.7999081611633301

For time-dependent prediction (e.g., 5-year mortality), the C-index is not a proper measure. Instead, it is recommended to use the AUC. The probability that the model correctly predicts which of two comparable patients will experience an event by 5 years, based on their estimated risk scores, is the AUC evaluated at 5 years (1825 days), obtained with

[24]:
new_time = torch.tensor(1825.0)

# subject-specific log hazard at 5 years (1825 days)
log_hz_t = log_hazard(log_params, time=new_time)  # shape = (n)
weibull_auc = Auc()

# auc evaluated at new time = 1825, 5 year
print(f"AUC 5-yr             = {weibull_auc(log_hz_t, event, time, new_time=new_time)}")
print(f"AUC 5-yr (conf int.) = {weibull_auc.confidence_interval()}")
print(f"AUC 5-yr (p value)   = {weibull_auc.p_value(alternative='greater')}")
AUC 5-yr             = tensor([0.3919])
AUC 5-yr (conf int.) = tensor([0.3488, 0.4350])
AUC 5-yr (p value)   = tensor([1.0000])

Lastly, we can evaluate the time-dependent Brier score and the integrated Brier score

[25]:
brier_score = BrierScore()

# brier score at first 5 times
print(f"Brier score             = {brier_score(surv, event, time)[:5]}")
print(f"Brier score (conf int.) = {brier_score.confidence_interval()[:,:5]}")

# integrated brier score
print(f"Integrated Brier score  = {brier_score.integral()}")
Brier score             = tensor([0.3967, 0.4106, 0.4096, 0.4256, 0.4255])
Brier score (conf int.) = tensor([[0.3928, 0.4046, 0.4023, 0.4164, 0.4153],
        [0.4006, 0.4167, 0.4168, 0.4347, 0.4356]])
Integrated Brier score  = 0.25487515330314636

We can test whether the time-dependent Brier score is smaller than would be expected if the survival model were not providing accurate predictions beyond random chance. We use a bootstrap permutation test and obtain the p-value with:

[26]:
# H0: bs = bs0, Ha: bs < bs0; where bs0 is the expected brier score if the survival model was not providing accurate predictions beyond random chance.

# p-value for brier score at first 5 times
print(f"Brier score (p-val)        = {brier_score.p_value(alternative = 'less')[:5]}")
Brier score (p-val)        = tensor([0.4910, 0.7000, 0.1750, 0.2450, 0.3430])

Section 3: Models comparison#

We can compare the predictive performance of the Cox proportional hazards model against the Weibull AFT model.

Section 3.1: Concordance index#

The statistical test is formulated as follows: H0: cindex Cox = cindex Weibull; Ha: cindex Cox > cindex Weibull.

[27]:
print(f"Cox cindex     = {cox_cindex.cindex}")
print(f"Weibull cindex = {weibull_cindex.cindex}")
print("p-value        = {}".format(cox_cindex.compare(weibull_cindex)))
Cox cindex     = 0.6596463322639465
Weibull cindex = 0.43771111965179443
p-value        = 0.008534319698810577

Section 3.2: AUC at 5-year#

The statistical test is formulated as follows: H0: 5-yr AUC Cox = 5-yr AUC Weibull; Ha: 5-yr AUC Cox > 5-yr AUC Weibull.

[28]:
print(f"Cox 5-yr AUC     = {cox_auc.auc}")
print(f"Weibull 5-yr AUC = {weibull_auc.auc}")
print("p-value          = {}".format(cox_auc.compare(weibull_auc)))
Cox 5-yr AUC     = tensor([0.6188])
Weibull 5-yr AUC = tensor([0.3919])
p-value          = tensor([2.2239e-11])

Section 4: Kaplan-Meier estimator#
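
The Kaplan-Meier estimator is a non-parametric estimate of the marginal survival function based only on the observed event and censoring times,

\[\hat{S}(t) = \prod_{t_{(k)} \leq t}\left(1 - \frac{d_{k}}{n_{k}}\right),\]

where \(d_{k}\) is the number of events and \(n_{k}\) the number of patients still at risk at the \(k\)-th distinct event time. Below, we compute it on the test set and plot the resulting survival curve.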

[29]:
# Create a Kaplan-Meier estimator
km = KaplanMeierEstimator()

# Use our observed testing dataset
event = torch.tensor(df_test["cens"].values).bool()
time = torch.tensor(df_test["time"].values)

# Compute the estimator
km(event, time)
[30]:
# plot estimate
km.plot_km()
../_images/notebooks_introduction_59_0.png
[31]:
# Print the survival values at each time step
km.print_survival_table()
Time    Survival
----------------
8.00    1.0000
17.00   1.0000
18.00   1.0000
42.00   1.0000
46.00   1.0000
57.00   1.0000
63.00   1.0000
114.00  1.0000
120.00  0.9949
160.00  0.9899
168.00  0.9899
186.00  0.9899
195.00  0.9899
229.00  0.9899
233.00  0.9847
275.00  0.9796
276.00  0.9796
285.00  0.9744
288.00  0.9692
310.00  0.9692
343.00  0.9640
358.00  0.9588
372.00  0.9536
374.00  0.9484
377.00  0.9432
385.00  0.9380
420.00  0.9327
426.00  0.9275
429.00  0.9275
432.00  0.9275
455.00  0.9223
456.00  0.9170
471.00  0.9117
481.00  0.9065
486.00  0.9012
515.00  0.8959
525.00  0.8906
529.00  0.8854
533.00  0.8801
537.00  0.8748
541.00  0.8748
545.00  0.8694
548.00  0.8641
550.00  0.8534
563.00  0.8481
564.00  0.8428
567.00  0.8428
570.00  0.8428
571.00  0.8374
578.00  0.8320
622.00  0.8266
624.00  0.8212
629.00  0.8158
648.00  0.8104
657.00  0.8104
663.00  0.8104
675.00  0.8049
687.00  0.7994
695.00  0.7994
722.00  0.7994
730.00  0.7938
734.00  0.7938
740.00  0.7938
747.00  0.7881
766.00  0.7881
768.00  0.7881
770.00  0.7881
772.00  0.7822
792.00  0.7822
799.00  0.7763
806.00  0.7763
838.00  0.7704
842.00  0.7645
857.00  0.7586
861.00  0.7526
865.00  0.7467
867.00  0.7467
877.00  0.7467
891.00  0.7407
892.00  0.7407
918.00  0.7346
933.00  0.7346
936.00  0.7346
940.00  0.7346
956.00  0.7284
973.00  0.7284
986.00  0.7284
1077.00 0.7284
1078.00 0.7284
1080.00 0.7219
1088.00 0.7219
1090.00 0.7219
1091.00 0.7219
1093.00 0.7153
1095.00 0.7153
1100.00 0.7153
1108.00 0.7086
1109.00 0.7086
1117.00 0.7086
1120.00 0.7017
1125.00 0.7017
1146.00 0.6948
1150.00 0.6878
1157.00 0.6809
1170.00 0.6739
1193.00 0.6670
1205.00 0.6670
1207.00 0.6599
1222.00 0.6599
1230.00 0.6599
1246.00 0.6527
1253.00 0.6454
1306.00 0.6382
1331.00 0.6382
1337.00 0.6308
1340.00 0.6308
1341.00 0.6308
1349.00 0.6308
1352.00 0.6232
1356.00 0.6232
1363.00 0.6155
1364.00 0.6155
1371.00 0.6076
1434.00 0.6076
1441.00 0.6076
1443.00 0.6076
1459.00 0.5994
1469.00 0.5994
1490.00 0.5994
1493.00 0.5909
1502.00 0.5825
1514.00 0.5825
1582.00 0.5825
1587.00 0.5736
1598.00 0.5736
1604.00 0.5736
1617.00 0.5736
1624.00 0.5736
1625.00 0.5736
1629.00 0.5736
1632.00 0.5736
1642.00 0.5736
1645.00 0.5736
1653.00 0.5736
1655.00 0.5736
1666.00 0.5736
1675.00 0.5736
1679.00 0.5622
1685.00 0.5622
1693.00 0.5622
1701.00 0.5622
1702.00 0.5622
1722.00 0.5622
1735.00 0.5622
1751.00 0.5622
1756.00 0.5622
1760.00 0.5622
1765.00 0.5622
1806.00 0.5478
1807.00 0.5333
1818.00 0.5333
1820.00 0.5333
1821.00 0.5333
1838.00 0.5333
1842.00 0.5333
1846.00 0.5333
1847.00 0.5333
1854.00 0.5333
1855.00 0.5333
1856.00 0.5333
1878.00 0.5333
1904.00 0.5333
1905.00 0.5333
1922.00 0.5333
1933.00 0.5333
1938.00 0.5333
1965.00 0.5333
1975.00 0.5053
1976.00 0.5053
1977.00 0.5053
2009.00 0.5053
2010.00 0.5053
2014.00 0.5053
2027.00 0.5053
2039.00 0.4632
2049.00 0.4632
2093.00 0.4169
2138.00 0.4169
2156.00 0.4169
2195.00 0.4169
2217.00 0.4169
2239.00 0.4169
2286.00 0.3126
2372.00 0.2084
2380.00 0.2084
2388.00 0.2084