Getting started#

In this notebook, we use TorchSurv to train a model that predicts relative risk of breast cancer recurrence. We use a public data set, the German Breast Cancer Study Group 2 (GBSG2). After training the model, we evaluate the predictive performance using evaluation metrics implemented in TorchSurv.

We first load the dataset using the package lifelines. The GBSG2 dataset contains covariates and recurrence-free survival times (in days) for 686 women with breast cancer; whether a patient received hormonal therapy (horTh) is one of the recorded covariates.

Dependencies#

To run this notebook, dependencies must be installed. The recommended method is to use our development conda environment; instructions to install all optional dependencies can be found here. Alternatively, install only the required packages with the commands below:

[1]:
# Install only required packages (optional)
# %pip install lifelines
# %pip install matplotlib
# %pip install scikit-learn
# %pip install pandas
[2]:
import warnings

warnings.filterwarnings("ignore")
[3]:
import lifelines
import pandas as pd
import torch

# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py
from helpers_introduction import Custom_dataset, plot_losses
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# Our package
from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.loss.weibull import (
    log_hazard,
    neg_log_likelihood,
    survival_function,
)
from torchsurv.metrics.auc import Auc
from torchsurv.metrics.brier_score import BrierScore
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.stats.kaplan_meier import KaplanMeierEstimator
[4]:
# Optional workaround for an issue with eager mode / the inductor backend
# torch._dynamo.config.suppress_errors = True  # Suppress inductor errors
# torch._dynamo.reset()  # Reset the backend
[5]:
# Constant parameters across models
# Detect the available accelerator; downgrade the batch size if only the CPU is available
if any([torch.cuda.is_available(), torch.backends.mps.is_available()]):
    print("CUDA or MPS accelerator is available.")
    BATCH_SIZE = 128  # batch size for training
else:
    print("No CUDA or MPS accelerator found, using CPU.")
    BATCH_SIZE = 32  # batch size for training

EPOCHS = 100
LEARNING_RATE = 1e-2
CUDA or MPS accelerator is available.

Dataset overview#

[6]:
# Load GBSG2 dataset
df = lifelines.datasets.load_gbsg2()
df.head(5)
[6]:
   horTh  age menostat  tsize tgrade  pnodes  progrec  estrec  time  cens
0     no   70     Post     21     II       3       48      66  1814     1
1    yes   56     Post     12     II       7       61      77  2018     1
2    yes   58     Post     35     II       9       52     271   712     1
3    yes   59     Post     17     II       4       60      29  1807     1
4     no   73     Post     35     II       1       26      65   772     1

The dataset contains the following features:

  • horTh: hormonal therapy, a factor at two levels (yes and no).

  • age: age of the patients in years.

  • menostat: menopausal status, a factor at two levels pre (premenopausal) and post (postmenopausal).

  • tsize: tumor size (in mm).

  • tgrade: tumor grade, an ordered factor at levels I < II < III.

  • pnodes: number of positive nodes.

  • progrec: progesterone receptor (in fmol).

  • estrec: estrogen receptor (in fmol).

Additionally, it contains our survival targets:

  • time: recurrence free survival time (in days).

  • cens: censoring indicator (0 = censored, 1 = event).

Three of these features (horTh, menostat, and tgrade) are categorical; one common approach is to one-hot encode them to obtain fully numerical features. The following code does this and then partitions the data into training, validation, and testing cohorts; the features X and the labels y are separated later when building the dataloaders.

Data preparation#

[7]:
df_onehot = pd.get_dummies(df, columns=["horTh", "menostat", "tgrade"]).astype("float")
df_onehot.drop(
    ["horTh_no", "menostat_Post", "tgrade_I"],
    axis=1,
    inplace=True,
)
df_onehot.head(5)
[7]:
    age  tsize  pnodes  progrec  estrec    time  cens  horTh_yes  menostat_Pre  tgrade_II  tgrade_III
0  70.0   21.0     3.0     48.0    66.0  1814.0   1.0        0.0           0.0        1.0         0.0
1  56.0   12.0     7.0     61.0    77.0  2018.0   1.0        1.0           0.0        1.0         0.0
2  58.0   35.0     9.0     52.0   271.0   712.0   1.0        1.0           0.0        1.0         0.0
3  59.0   17.0     4.0     60.0    29.0  1807.0   1.0        1.0           0.0        1.0         0.0
4  73.0   35.0     1.0     26.0    65.0   772.0   1.0        0.0           0.0        1.0         0.0
[8]:
df_train, df_test = train_test_split(df_onehot, test_size=0.3)
df_train, df_val = train_test_split(df_train, test_size=0.3)
print(f"(Sample size) Training:{len(df_train)} | Validation:{len(df_val)} |Testing:{len(df_test)}")
(Sample size) Training:336 | Validation:144 |Testing:206

Let us set up the dataloaders for training, validation, and testing.

[9]:
# Dataloader
dataloader_train = DataLoader(Custom_dataset(df_train), batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(Custom_dataset(df_val), batch_size=len(df_val), shuffle=False)
dataloader_test = DataLoader(Custom_dataset(df_test), batch_size=len(df_test), shuffle=False)
[10]:
# Sanity check
x, (event, time) = next(iter(dataloader_train))
num_features = x.size(1)

print(f"x (shape)    = {x.shape}")
print(f"num_features = {num_features}")
print(f"event        = {event.shape}")
print(f"time         = {time.shape}")
x (shape)    = torch.Size([128, 9])
num_features = 9
event        = torch.Size([128])
time         = torch.Size([128])

Section 1: Cox proportional hazards model#

In this section, we use the Cox proportional hazards model. Given covariate \(x_{i}\), the hazard of patient \(i\) has the form

\[\lambda (t|x_{i}) =\lambda_{0}(t)\theta(x_{i})\]

The baseline hazard \(\lambda_{0}(t)\) is identical across subjects (i.e., has no dependency on \(i\)). The subject-specific risk of event occurrence is captured through the relative hazards \(\{\theta(x_{i})\}_{i = 1, \dots, N}\).

We train a multi-layer perceptron (MLP) to model the subject-specific risk of event occurrence, i.e., the log relative hazard \(\log\theta(x_{i})\). Patients with shorter recurrence-free survival times are assumed to have a higher risk of event.
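The model is trained by minimizing the negative partial log-likelihood of the Cox model, which only requires the predicted log relative hazards together with the event indicators and times. As a minimal sketch with made-up values (toy tensors, not taken from GBSG2), the loss imported above can be called as follows:

[ ]:
# Toy illustration (hypothetical values): three subjects with predicted log relative hazards
toy_log_hz = torch.tensor([[0.5], [-0.2], [1.1]])  # predicted log relative hazard, shape = (3, 1)
toy_event = torch.tensor([True, False, True])  # event indicator (False = censored)
toy_time = torch.tensor([300.0, 800.0, 1500.0])  # recurrence-free survival time (days)
neg_partial_log_likelihood(toy_log_hz, toy_event, toy_time, reduction="mean")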

Section 1.1: MLP model for log relative hazards#

[11]:
# Instantiate the Cox model
cox_model = torch.nn.Sequential(
    torch.nn.BatchNorm1d(num_features),  # Batch normalization
    torch.nn.Linear(num_features, 32),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(64, 1),  # Estimating log hazards for Cox models
)

Section 1.2: MLP model training#

[12]:
torch.manual_seed(42)

# Init optimizer for Cox
optimizer = torch.optim.Adam(cox_model.parameters(), lr=LEARNING_RATE)

# Initiate empty list to store the loss on the train and validation sets
train_losses = []
val_losses = []

# training loop
for epoch in range(EPOCHS):
    epoch_loss = torch.tensor(0.0)
    for _, batch in enumerate(dataloader_train):
        x, (event, time) = batch
        optimizer.zero_grad()
        log_hz = cox_model(x)  # shape = (batch_size, 1)
        loss = neg_partial_log_likelihood(log_hz, event, time, reduction="mean")
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()

    if epoch % (EPOCHS // 10) == 0:
        print(f"Epoch: {epoch:03}, Training loss: {epoch_loss:0.2f}")

    # Record loss on train and test sets
    train_losses.append(epoch_loss)
    with torch.no_grad():
        x, (event, time) = next(iter(dataloader_val))
        val_losses.append(neg_partial_log_likelihood(cox_model(x), event, time, reduction="mean"))
Epoch: 000, Training loss: 12.87
Epoch: 010, Training loss: 11.90
Epoch: 020, Training loss: 11.76
Epoch: 030, Training loss: 11.80
Epoch: 040, Training loss: 11.77
Epoch: 050, Training loss: 11.72
Epoch: 060, Training loss: 11.62
Epoch: 070, Training loss: 11.39
Epoch: 080, Training loss: 11.58
Epoch: 090, Training loss: 11.46

We can visualize the training and validation losses.

[13]:
plot_losses(train_losses, val_losses, "Cox")
[Figure: training and validation losses for the Cox model (../_images/notebooks_introduction_21_0.png)]

Section 1.3: Cox proportional hazards model evaluation#

We evaluate the predictive performance of the model using

  • the concordance index (C-index), which measures the probability that a model correctly predicts which of two comparable samples will experience an event first, based on their estimated risk scores (a toy example follows below),

  • the Area Under the Receiver Operating Characteristic Curve (AUC), which measures the probability that a model correctly predicts which of two comparable samples will experience an event by time t based on their estimated risk scores.

We cannot use the Brier score here because the Cox model, as specified, only estimates relative hazards; without the baseline hazard it cannot produce a survival function.
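To make the definition of the C-index concrete, here is a toy example with made-up values: when the subject with the earlier event always has the higher estimated risk, every comparable pair is concordant and the C-index equals 1.

[ ]:
# Toy illustration (hypothetical values): the risk ranking matches the event ordering exactly
toy_log_hz = torch.tensor([[2.0], [0.0], [-1.0]])  # higher value = higher predicted risk
toy_event = torch.tensor([True, True, False])  # third subject is censored
toy_time = torch.tensor([100.0, 500.0, 1000.0])  # earliest event has the highest risk score
ConcordanceIndex()(toy_log_hz, toy_event, toy_time)  # expected: tensor(1.)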

We start by evaluating the subject-specific log relative hazards on the test set

[14]:
cox_model.eval()
with torch.no_grad():
    # test event and test time of length n
    x, (event, time) = next(iter(dataloader_test))
    log_hz = cox_model(x)  # log hazard of length n

We obtain the concordance index and its confidence interval

[15]:
# Concordance index
cox_cindex = ConcordanceIndex()
print("Cox model performance:")
print(f"Concordance-index   = {cox_cindex(log_hz, event, time)}")
print(f"Confidence interval = {cox_cindex.confidence_interval()}")
Cox model performance:
Concordance-index   = 0.6471903324127197
Confidence interval = tensor([0.5201, 0.7743])

We can also test whether the observed concordance index is greater than 0.5. The statistical test is specified with H0: c-index = 0.5 and Ha: c-index > 0.5. The p-value of the statistical test is

[16]:
# H0: cindex = 0.5, Ha: cindex > 0.5
print("p-value = {}".format(cox_cindex.p_value(alternative="greater")))
p-value = 0.01163017749786377

For time-dependent prediction (e.g., 5-year mortality), the C-index is not a proper measure. Instead, it is recommended to use the AUC. The probability that the model correctly predicts which of two comparable patients will experience an event by 5 years, based on their estimated risk scores, is the AUC evaluated at 5 years (1825 days), obtained with

[17]:
cox_auc = Auc()

new_time = torch.tensor(1825.0)

# auc evaluated at new time = 1825, 5 year
print(f"AUC 5-yr             = {cox_auc(log_hz, event, time, new_time=new_time)}")
print(f"AUC 5-yr (conf int.) = {cox_auc.confidence_interval()}")
AUC 5-yr             = tensor([0.6755])
AUC 5-yr (conf int.) = tensor([0.6167, 0.7343])

As before, we can test whether the observed AUC at 5 years is greater than 0.5. The statistical test is specified with H0: auc = 0.5 and Ha: auc > 0.5. The p-value of the statistical test is

[18]:
print(f"AUC (p_value) = {cox_auc.p_value()}")
AUC (p_value) = tensor([0.])

Section 2: Weibull accelerated failure time (AFT) model#

In this section, we use the Weibull accelerated failure time (AFT) model. Given covariate \(x_{i}\), the hazard of patient \(i\) at time \(t\) has the form

\[\lambda (t|x_{i}) = \frac{\rho(x_{i})}{\lambda(x_{i})} \left(\frac{t}{\lambda(x_{i})}\right)^{\rho(x_{i}) - 1}\]

Given the hazard form, it can be shown that the event density follows a Weibull distribution parametrized by scale \(\lambda(x_{i})\) and shape \(\rho(x_{i})\). The subject-specific risk of event occurrence at time \(t\) is captured through the hazards \(\{\lambda (t|x_{i})\}_{i = 1, \dots, N}\). We train a multi-layer perceptron (MLP) to model the subject-specific log scale, \(\log \lambda(x_{i})\), and the log shape, \(\log\rho(x_{i})\).
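As a quick sanity check of this parameterization (made-up values, using the log_hazard and survival_function helpers imported above): when the log scale and log shape are both zero, i.e. \(\lambda = \rho = 1\), the Weibull model reduces to an exponential distribution with unit rate, so the hazard is constant at 1 and the survival function is \(\exp(-t)\).

[ ]:
# Toy sanity check (hypothetical values): log scale = log shape = 0 for two subjects
toy_log_params = torch.zeros((2, 2))  # columns: [log scale, log shape]
toy_time = torch.tensor([0.5, 1.0])  # evaluation times
log_hazard(toy_log_params, toy_time)  # expected: all zeros, since the hazard equals 1
survival_function(toy_log_params, toy_time)  # expected values: exp(-0.5) ≈ 0.61 and exp(-1.0) ≈ 0.37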

Section 2.1: MLP model for log scale and log shape#

[19]:
# Same architecture as the Cox model, except for the output dimension
weibull_model = torch.nn.Sequential(
    torch.nn.BatchNorm1d(num_features),  # Batch normalization
    torch.nn.Linear(num_features, 32),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(64, 2),  # Estimating log parameters for Weibull model
)

Section 2.2: MLP model training#

[20]:
torch.manual_seed(42)

# Init optimizer for Weibull
optimizer = torch.optim.Adam(weibull_model.parameters(), lr=LEARNING_RATE)

# Initialize empty list to store loss on train and validation sets
train_losses = []
val_losses = []

# training loop
for epoch in range(EPOCHS):
    epoch_loss = torch.tensor(0.0)
    for _, batch in enumerate(dataloader_train):
        x, (event, time) = batch
        optimizer.zero_grad()
        log_params = weibull_model(x)  # shape = (batch_size, 2)
        loss = neg_log_likelihood(log_params, event, time, reduction="mean")
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()

    if epoch % (EPOCHS // 10) == 0:
        print(f"Epoch: {epoch:03}, Training loss: {epoch_loss:0.2f}")

    # Record losses for the following figure
    train_losses.append(epoch_loss)
    with torch.no_grad():
        x, (event, time) = next(iter(dataloader_val))
        val_losses.append(neg_log_likelihood(weibull_model(x), event, time, reduction="mean"))
Epoch: 000, Training loss: 163844.06
Epoch: 010, Training loss: 19.96
Epoch: 020, Training loss: 18.81
Epoch: 030, Training loss: 18.66
Epoch: 040, Training loss: 17.68
Epoch: 050, Training loss: 17.75
Epoch: 060, Training loss: 18.05
Epoch: 070, Training loss: 17.60
Epoch: 080, Training loss: 17.81
Epoch: 090, Training loss: 17.70

We can visualize the training and validation losses.

[21]:
plot_losses(train_losses, val_losses, "Weibull")
[Figure: training and validation losses for the Weibull model (../_images/notebooks_introduction_40_0.png)]

Section 2.3: Weibull AFT model evaluation#

We evaluate the predictive performance of the model using

  • the C-index, which measures the probability that a model correctly predicts which of two comparable samples will experience an event first, based on their estimated risk scores,

  • the AUC, which measures the probability that a model correctly predicts which of two comparable samples will experience an event by time t based on their estimated risk scores, and

  • the Brier score, which measures the model’s calibration by calculating the mean square error between the estimated survival function and the empirical (i.e., in-sample) event status (a toy call sketch follows this list).
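Before applying it to the test set, here is a minimal call sketch of BrierScore with made-up values, mirroring the shapes used below: a matrix of predicted survival probabilities (one row per subject, one column per evaluation time), the event indicators, and the observed times.

[ ]:
# Toy call sketch (hypothetical values): three subjects evaluated at the three observed times
toy_surv = torch.tensor(
    [[0.80, 0.50, 0.30],
     [0.90, 0.70, 0.50],
     [0.95, 0.90, 0.80]]
)  # row i holds the predicted survival probabilities S(t | x_i)
toy_event = torch.tensor([True, False, True])  # second subject is censored
toy_time = torch.tensor([1.0, 2.0, 3.0])  # observed times, also used as evaluation times
BrierScore()(toy_surv, toy_event, toy_time)  # time-dependent Brier score, one value per evaluation time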

We start by obtaining the subject-specific log hazard and survival probability at every time \(t\) observed on the test set

[22]:
weibull_model.eval()
with torch.no_grad():
    # event and time of length n
    x, (event, time) = next(iter(dataloader_test))
    log_params = weibull_model(x)  # shape = (n,2)

# Compute the log hazards from weibull log parameters
log_hz = log_hazard(log_params, time)  # shape = (n,n)

# Compute the survival probability from weibull log parameters
surv = survival_function(log_params, time)  # shape = (n,n)

We can evaluate the concordance index, its confidence interval, and the p-value of the test of whether the c-index is greater than 0.5:

[23]:
# Concordance index
weibull_cindex = ConcordanceIndex()
print("Weibull model performance:")
print(f"Concordance-index   = {weibull_cindex(log_hz, event, time)}")
print(f"Confidence interval = {weibull_cindex.confidence_interval()}")

# H0: cindex = 0.5, Ha: cindex > 0.5
print(f"p-value             = {weibull_cindex.p_value(alternative='greater')}")
Weibull model performance:
Concordance-index   = 0.4526662230491638
Confidence interval = tensor([0.3259, 0.5795])
p-value             = 0.7678300142288208

For time-dependent prediction (e.g., 5-year mortality), the C-index is not a proper measure. Instead, it is recommended to use the AUC. The probability that the model correctly predicts which of two comparable patients will experience an event by 5 years, based on their estimated risk scores, is the AUC evaluated at 5 years (1825 days), obtained with

[24]:
new_time = torch.tensor(1825.0)

# subject-specific log hazard at 5-yr (1825 days)
log_hz_t = log_hazard(log_params, time=new_time)  # shape = (n)
weibull_auc = Auc()

# auc evaluated at new time = 1825, 5 year
print(f"AUC 5-yr             = {weibull_auc(log_hz_t, event, time, new_time=new_time)}")
print(f"AUC 5-yr (conf int.) = {weibull_auc.confidence_interval()}")
print(f"AUC 5-yr (p value)   = {weibull_auc.p_value(alternative='greater')}")
AUC 5-yr             = tensor([0.4180])
AUC 5-yr (conf int.) = tensor([0.3691, 0.4668])
AUC 5-yr (p value)   = tensor([0.9995])

Lastly, we can evaluate the time-dependent Brier score and the integrated Brier score

[25]:
brier_score = BrierScore()

# brier score at first 5 times
print(f"Brier score             = {brier_score(surv, event, time)[:5]}")
print(f"Brier score (conf int.) = {brier_score.confidence_interval()[:, :5]}")

# integrated brier score
print(f"Integrated Brier score  = {brier_score.integral()}")
Brier score             = tensor([0.4156, 0.4506, 0.4603, 0.4584, 0.4573])
Brier score (conf int.) = tensor([[0.4108, 0.4427, 0.4503, 0.4479, 0.4463],
        [0.4203, 0.4586, 0.4702, 0.4690, 0.4684]])
Integrated Brier score  = 0.24310742318630219

We can test whether the time-dependent Brier score is smaller than what would be expected if the survival model were not providing accurate predictions beyond random chance. We use a bootstrap permutation test and obtain the p-value with:

[26]:
# H0: bs = bs0, Ha: bs < bs0; where bs0 is the expected brier score if the survival model was not providing accurate predictions beyond random chance.

# p-value for brier score at first 5 times
print(f"Brier score (p-val)        = {brier_score.p_value(alternative='less')[:5]}")
Brier score (p-val)        = tensor([0.1420, 0.3500, 0.4240, 0.3260, 0.4380])

Section 3: Models comparison#

We can compare the predictive performance of the Cox proportional hazards model against the Weibull AFT model.

Section 3.1: Concordance index#

The statistical test is formulated as follows: H0: cindex cox = cindex weibull; Ha: cindex cox > cindex weibull.

[27]:
print(f"Cox cindex     = {cox_cindex.cindex}")
print(f"Weibull cindex = {weibull_cindex.cindex}")
print(f"p-value        = {cox_cindex.compare(weibull_cindex)}")
Cox cindex     = 0.6471903324127197
Weibull cindex = 0.4526662230491638
p-value        = 0.01628177985548973

Section 3.2: AUC at 5-year#

The statistical test is formulated as follows: H0: 5-yr auc cox = 5-yr auc weibull; Ha: 5-yr auc cox > 5-yr auc weibull.

[28]:
print(f"Cox 5-yr AUC     = {cox_auc.auc}")
print(f"Weibull 5-yr AUC = {weibull_auc.auc}")
print(f"p-value          = {cox_auc.compare(weibull_auc)}")
Cox 5-yr AUC     = tensor([0.6755])
Weibull 5-yr AUC = tensor([0.4180])
p-value          = tensor([8.6601e-11])

Section 4: Kaplan-Meier estimator#
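As a reference point, we can also estimate the marginal survival function on the test set (ignoring covariates) with the non-parametric Kaplan-Meier estimator implemented in TorchSurv.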

[29]:
# Create a Kaplan-Meier estimator
km = KaplanMeierEstimator()

# Use our observed testing dataset
event = torch.tensor(df_test["cens"].values).bool()
time = torch.tensor(df_test["time"].values)

# Compute the estimator
km(event, time)
[30]:
# plot estimate
km.plot_km()
[Figure: Kaplan-Meier survival curve estimated on the test set (../_images/notebooks_introduction_59_0.png)]
[31]:
# Print the survival values at each time step
km.print_survival_table()
Time    Survival
----------------
18.00   1.0000
98.00   0.9951
168.00  0.9951
169.00  0.9902
173.00  0.9853
177.00  0.9804
191.00  0.9755
205.00  0.9706
233.00  0.9657
273.00  0.9657
276.00  0.9657
281.00  0.9608
296.00  0.9608
307.00  0.9558
308.00  0.9508
319.00  0.9508
329.00  0.9458
338.00  0.9408
358.00  0.9358
359.00  0.9308
369.00  0.9258
370.00  0.9208
385.00  0.9158
403.00  0.9108
420.00  0.9058
426.00  0.8958
429.00  0.8958
448.00  0.8907
465.00  0.8857
471.00  0.8807
486.00  0.8756
491.00  0.8706
495.00  0.8656
498.00  0.8605
500.00  0.8555
502.00  0.8505
530.00  0.8454
542.00  0.8404
546.00  0.8404
548.00  0.8353
552.00  0.8303
553.00  0.8303
554.00  0.8252
559.00  0.8201
563.00  0.8150
570.00  0.8150
575.00  0.8099
577.00  0.8047
586.00  0.7996
612.00  0.7945
622.00  0.7894
623.00  0.7894
624.00  0.7842
632.00  0.7790
637.00  0.7790
648.00  0.7739
650.00  0.7739
652.00  0.7739
663.00  0.7739
670.00  0.7686
692.00  0.7686
707.00  0.7632
723.00  0.7632
740.00  0.7632
741.00  0.7632
754.00  0.7577
758.00  0.7577
761.00  0.7577
766.00  0.7577
776.00  0.7521
784.00  0.7465
792.00  0.7465
795.00  0.7408
797.00  0.7352
805.00  0.7295
836.00  0.7239
841.00  0.7239
842.00  0.7182
857.00  0.7182
867.00  0.7182
876.00  0.7124
916.00  0.7124
918.00  0.7065
945.00  0.7065
956.00  0.7006
964.00  0.6946
972.00  0.6946
974.00  0.6946
981.00  0.6885
983.00  0.6824
995.00  0.6824
1002.00 0.6762
1059.00 0.6701
1077.00 0.6701
1078.00 0.6701
1080.00 0.6638
1089.00 0.6638
1090.00 0.6638
1091.00 0.6638
1094.00 0.6638
1095.00 0.6638
1100.00 0.6638
1109.00 0.6638
1117.00 0.6638
1125.00 0.6638
1140.00 0.6570
1170.00 0.6501
1185.00 0.6501
1207.00 0.6432
1219.00 0.6363
1232.00 0.6363
1233.00 0.6363
1240.00 0.6363
1246.00 0.6292
1283.00 0.6292
1296.00 0.6219
1306.00 0.6146
1329.00 0.6073
1337.00 0.5999
1340.00 0.5999
1341.00 0.5999
1342.00 0.5999
1343.00 0.5922
1350.00 0.5922
1352.00 0.5844
1355.00 0.5844
1358.00 0.5844
1363.00 0.5763
1364.00 0.5763
1449.00 0.5681
1469.00 0.5681
1486.00 0.5681
1499.00 0.5681
1502.00 0.5592
1505.00 0.5592
1514.00 0.5592
1528.00 0.5499
1578.00 0.5499
1587.00 0.5404
1589.00 0.5309
1600.00 0.5309
1603.00 0.5309
1617.00 0.5309
1624.00 0.5309
1625.00 0.5309
1632.00 0.5309
1641.00 0.5203
1653.00 0.5203
1666.00 0.5203
1679.00 0.5092
1680.00 0.5092
1717.00 0.5092
1722.00 0.5092
1730.00 0.4974
1735.00 0.4974
1756.00 0.4974
1760.00 0.4974
1765.00 0.4974
1789.00 0.4974
1814.00 0.4839
1826.00 0.4839
1838.00 0.4839
1841.00 0.4839
1842.00 0.4839
1847.00 0.4839
1856.00 0.4839
1869.00 0.4839
1884.00 0.4839
1904.00 0.4839
1926.00 0.4839
1938.00 0.4839
1959.00 0.4839
1965.00 0.4839
1979.00 0.4839
2007.00 0.4839
2011.00 0.4839
2015.00 0.4585
2030.00 0.4585
2034.00 0.4315
2048.00 0.4315
2051.00 0.4315
2057.00 0.4315
2065.00 0.4315
2093.00 0.3955
2128.00 0.3955
2144.00 0.3955
2170.00 0.3955
2175.00 0.3955
2372.00 0.3390
2380.00 0.3390
2388.00 0.3390
2456.00 0.2260
2551.00 0.2260
2556.00 0.2260