Getting started
In this notebook, we use TorchSurv to train a model that predicts the relative risk of breast cancer recurrence. We use a public dataset, the German Breast Cancer Study Group 2 (GBSG2). After training the model, we evaluate its predictive performance using evaluation metrics implemented in TorchSurv.
We first load the dataset using the lifelines package. The GBSG2 dataset contains features and recurrence-free survival times (in days) for 686 women undergoing hormonal treatment.
Dependencies
To run this notebook, the dependencies must be installed. The recommended method is to use our development conda environment; instructions for installing all optional dependencies can be found here. Alternatively, you can install only the required packages using the commands below:
[1]:
# Install only required packages (optional)
# %pip install lifelines
# %pip install matplotlib
# %pip install scikit-learn
# %pip install pandas
[2]:
import warnings
warnings.filterwarnings("ignore")
[3]:
import lifelines
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
# Our package
from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.loss.weibull import neg_log_likelihood, log_hazard, survival_function
from torchsurv.metrics.brier_score import BrierScore
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc
from torchsurv.stats.kaplan_meier import KaplanMeierEstimator
# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py
from helpers_introduction import Custom_dataset, plot_losses
[4]:
# Constant parameters across models
# Detect an available accelerator; reduce the batch size if only a CPU is available
# (note: this check only sets the batch size, tensors and models below stay on the CPU)
if torch.cuda.is_available() or torch.backends.mps.is_available():
    print("CUDA-enabled GPU/TPU is available.")
    BATCH_SIZE = 128  # batch size for training
else:
    print("No CUDA-enabled GPU found, using CPU.")
    BATCH_SIZE = 32  # batch size for training
EPOCHS = 100
LEARNING_RATE = 1e-2
CUDA-enabled GPU/TPU is available.
Dataset overview
[5]:
# Load GBSG2 dataset
df = lifelines.datasets.load_gbsg2()
df.head(5)
[5]:
 | horTh | age | menostat | tsize | tgrade | pnodes | progrec | estrec | time | cens |
---|---|---|---|---|---|---|---|---|---|---|
0 | no | 70 | Post | 21 | II | 3 | 48 | 66 | 1814 | 1 |
1 | yes | 56 | Post | 12 | II | 7 | 61 | 77 | 2018 | 1 |
2 | yes | 58 | Post | 35 | II | 9 | 52 | 271 | 712 | 1 |
3 | yes | 59 | Post | 17 | II | 4 | 60 | 29 | 1807 | 1 |
4 | no | 73 | Post | 35 | II | 1 | 26 | 65 | 772 | 1 |
The dataset contains the following features:
horTh: hormonal therapy, a factor at two levels (yes and no).
age: age of the patients in years.
menostat: menopausal status, a factor at two levels, pre (premenopausal) and post (postmenopausal).
tsize: tumor size (in mm).
tgrade: tumor grade, an ordered factor at levels I < II < III.
pnodes: number of positive nodes.
progrec: progesterone receptor (in fmol).
estrec: estrogen receptor (in fmol).
Additionally, it contains our survival targets:
time: recurrence-free survival time (in days).
cens: censoring indicator (0 = censored, 1 = event).
The categorical features (horTh, menostat and tgrade) must be converted into numerical features; one common approach is one-hot encoding. We then separate the data frames into features X and labels y. The following code also partitions the labels and features into training, validation and testing cohorts.
Data preparation
[6]:
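# One-hot encode the categorical features
df_onehot = pd.get_dummies(df, columns=["horTh", "menostat", "tgrade"]).astype("float")
# Drop one reference level per factor to avoid collinear indicator columns
df_onehot.drop(
    ["horTh_no", "menostat_Post", "tgrade_I"],
    axis=1,
    inplace=True,
)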
df_onehot.head(5)
[6]:
 | age | tsize | pnodes | progrec | estrec | time | cens | horTh_yes | menostat_Pre | tgrade_II | tgrade_III |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 70.0 | 21.0 | 3.0 | 48.0 | 66.0 | 1814.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 |
1 | 56.0 | 12.0 | 7.0 | 61.0 | 77.0 | 2018.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 |
2 | 58.0 | 35.0 | 9.0 | 52.0 | 271.0 | 712.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 |
3 | 59.0 | 17.0 | 4.0 | 60.0 | 29.0 | 1807.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 |
4 | 73.0 | 35.0 | 1.0 | 26.0 | 65.0 | 772.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 |
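Dropping one indicator per factor (the reference level) keeps the encoded columns from being perfectly collinear. Here is a small sketch on a made-up frame (not GBSG2 data). It suggests that pandas' drop_first=True option of get_dummies would select the same reference levels here, because it drops the first level of each factor in sorted order ("no", "Post" and "I"):

```python
import pandas as pd

# Made-up toy frame with the same categorical columns as GBSG2;
# every level of each factor appears at least once.
toy = pd.DataFrame(
    {
        "age": [70.0, 56.0, 58.0],
        "horTh": ["no", "yes", "yes"],
        "menostat": ["Post", "Pre", "Post"],
        "tgrade": ["I", "II", "III"],
    }
)

# drop_first=True removes the first (sorted) level of each encoded factor
toy_onehot = pd.get_dummies(
    toy, columns=["horTh", "menostat", "tgrade"], drop_first=True
).astype("float")

print(list(toy_onehot.columns))
# ['age', 'horTh_yes', 'menostat_Pre', 'tgrade_II', 'tgrade_III']
```

Dropping the columns by name, as done above, makes the choice of reference level explicit and does not rely on the sort order.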
[7]:
df_train, df_test = train_test_split(df_onehot, test_size=0.3)
df_train, df_val = train_test_split(df_train, test_size=0.3)
print(
f"(Sample size) Training:{len(df_train)} | Validation:{len(df_val)} |Testing:{len(df_test)}"
)
(Sample size) Training:336 | Validation:144 |Testing:206
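The printed sizes follow from applying two 30% splits in sequence. Below is a quick arithmetic check in plain Python. It assumes scikit-learn's rounding rule for a float test_size: the test count is rounded up and the training part gets the remainder.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test), assuming n_test = ceil(test_size * n_samples)."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


n_total = 686  # women in GBSG2
n_trainval, n_test = split_sizes(n_total, 0.3)  # first split: train+val vs test
n_train, n_val = split_sizes(n_trainval, 0.3)   # second split: train vs val
print(n_train, n_val, n_test)  # 336 144 206
```

Because train_test_split is called without a random_state, the sizes are fixed but the individual patients in each cohort change from run to run.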
Let us set up the data loaders for training, validation and testing.
[8]:
# Dataloader
dataloader_train = DataLoader(
Custom_dataset(df_train), batch_size=BATCH_SIZE, shuffle=True
)
dataloader_val = DataLoader(
Custom_dataset(df_val), batch_size=len(df_val), shuffle=False
)
dataloader_test = DataLoader(
Custom_dataset(df_test), batch_size=len(df_test), shuffle=False
)
[9]:
# Sanity check
x, (event, time) = next(iter(dataloader_train))
num_features = x.size(1)
print(f"x (shape) = {x.shape}")
print(f"num_features = {num_features}")
print(f"event = {event.shape}")
print(f"time = {time.shape}")
x (shape) = torch.Size([128, 9])
num_features = 9
event = torch.Size([128])
time = torch.Size([128])
Section 1: Cox proportional hazards model
In this section, we use the Cox proportional hazards model. Given covariate \(x_{i}\), the hazard of patient \(i\) at time \(t\) has the form
\[
\lambda(t|x_{i}) = \lambda_{0}(t)\,\theta(x_{i})
\]
The baseline hazard \(\lambda_{0}(t)\) is identical across subjects (i.e., has no dependency on \(i\)). The subject-specific risk of event occurrence is captured through the relative hazards \(\{\theta(x_{i})\}_{i = 1, \dots, N}\).
We train a multi-layer perceptron (MLP) to model the subject-specific risk of event occurrence, i.e., the log relative hazards \(\log\theta(x_{i})\). Patients with shorter recurrence times are assumed to have a higher risk of event.
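The baseline hazard \(\lambda_{0}(t)\) cancels in each ratio \(\theta(x_{i}) / \sum_{j: t_{j} \ge t_{i}} \theta(x_{j})\). As a result, the partial likelihood depends only on the log relative hazards. Below is a minimal plain-Python sketch of the negative partial log-likelihood. It assumes no tied event times. TorchSurv's neg_partial_log_likelihood also handles ties, and its reduction may normalize differently.

```python
import math


def cox_neg_partial_log_likelihood(log_hz, event, time):
    """Negative Cox partial log-likelihood, summed over observed events.

    Simplified sketch: assumes no tied event times, so no Breslow or Efron
    tie correction is applied. log_hz[i] is the predicted log relative
    hazard log(theta(x_i)); event[i] is 1 for an event and 0 if censored.
    """
    n = len(time)
    nll = 0.0
    for i in range(n):
        if not event[i]:
            continue  # censored subjects contribute only through risk sets
        # Risk set: subjects still under observation at time[i]
        risk_set = [log_hz[j] for j in range(n) if time[j] >= time[i]]
        # log(sum(exp(.))) computed stably by factoring out the maximum
        m = max(risk_set)
        log_denominator = m + math.log(sum(math.exp(v - m) for v in risk_set))
        nll -= log_hz[i] - log_denominator
    return nll
```

For example, with all log hazards equal to zero and three distinct event times, the loss is log 3 + log 2 + log 1 = log 6. Assigning higher log hazards to the patients with earlier events lowers the loss.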
Section 1.1: MLP model for log relative hazards
[10]:
cox_model = torch.nn.Sequential(
torch.nn.BatchNorm1d(num_features), # Batch normalization
torch.nn.Linear(num_features, 32),
torch.nn.ReLU(),
torch.nn.Dropout(),
torch.nn.Linear(32, 64),
torch.nn.ReLU(),
torch.nn.Dropout(),
torch.nn.Linear(64, 1), # Estimating log hazards for Cox models
)
Section 1.2: MLP model training
[11]:
torch.manual_seed(42)
# Initialize the optimizer for the Cox model
optimizer = torch.optim.Adam(cox_model.parameters(), lr=LEARNING_RATE)
# Initialize empty lists to store the losses on the training and validation sets
train_losses = []
val_losses = []
# Training loop
for epoch in range(EPOCHS):
    cox_model.train()  # enable dropout and batch statistics for training
    epoch_loss = torch.tensor(0.0)
    for i, batch in enumerate(dataloader_train):
        x, (event, time) = batch
        optimizer.zero_grad()
        log_hz = cox_model(x)  # shape = (batch_size, 1)
        loss = neg_partial_log_likelihood(log_hz, event, time, reduction="mean")
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()
    if epoch % (EPOCHS // 10) == 0:
        # Printed value is the sum of the batch losses of this epoch
        print(f"Epoch: {epoch:03}, Training loss: {epoch_loss:0.2f}")
    # Record the mean batch loss on the training set and the validation loss
    epoch_loss /= i + 1
    train_losses.append(epoch_loss)
    cox_model.eval()  # disable dropout; use running batch-norm statistics
    with torch.no_grad():
        x, (event, time) = next(iter(dataloader_val))
        val_losses.append(
            neg_partial_log_likelihood(cox_model(x), event, time, reduction="mean")
        )
Epoch: 000, Training loss: 12.55
Epoch: 010, Training loss: 12.40
Epoch: 020, Training loss: 12.02
Epoch: 030, Training loss: 11.90
Epoch: 040, Training loss: 11.97
Epoch: 050, Training loss: 11.85
Epoch: 060, Training loss: 11.68
Epoch: 070, Training loss: 11.85
Epoch: 080, Training loss: 11.74
Epoch: 090, Training loss: 11.88
We can visualize the training and validation losses.
[12]:
plot_losses(train_losses, val_losses, "Cox")
Section 1.3: Cox proportional hazards model evaluation
We evaluate the predictive performance of the model using
the concordance index (C-index), which measures the probability that a model correctly predicts which of two comparable samples will experience an event first, based on their estimated risk scores, and
the area under the receiver operating characteristic curve (AUC), which measures the probability that a model correctly predicts which of two comparable samples will experience an event by time \(t\), based on their estimated risk scores.
We cannot use the Brier score because this model only estimates relative hazards: without an estimate of the baseline hazard \(\lambda_{0}(t)\), it cannot produce the survival function.
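To make the pair definition of the C-index concrete, here is a plain-Python sketch of Harrell's C-index. It is an illustration without censoring weights, not TorchSurv's ConcordanceIndex implementation. A higher score means a higher predicted risk, which matches how the log relative hazards of the Cox model are interpreted.

```python
def harrell_cindex(risk, event, time):
    """Harrell's concordance index (unweighted sketch).

    A pair (i, j) is comparable when subject i has an observed event and
    time[i] < time[j]; it is concordant when risk[i] > risk[j].
    Ties in predicted risk count as half-concordant.
    """
    concordant, comparable = 0.0, 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # a censored subject cannot be the earlier event
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable
```

A model that ranks risks perfectly scores 1, a perfectly reversed ranking scores 0, and a constant prediction scores 0.5.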
We start by evaluating the subject-specific log relative hazards on the test set:
[13]:
cox_model.eval()
with torch.no_grad():
# test event and test time of length n
x, (event, time) = next(iter(dataloader_test))
log_hz = cox_model(x) # log hazard of length n
We obtain the concordance index and its confidence interval:
[14]:
# Concordance index
cox_cindex = ConcordanceIndex()
print("Cox model performance:")
print(f"Concordance-index = {cox_cindex(log_hz, event, time)}")
print(f"Confidence interval = {cox_cindex.confidence_interval()}")
Cox model performance:
Concordance-index = 0.639417290687561
Confidence interval = tensor([0.5143, 0.7645])
We can also test whether the observed concordance index is greater than 0.5. The statistical test is specified with H0: c-index = 0.5 and Ha: c-index > 0.5. The p-value of the statistical test is
[15]:
# H0: cindex = 0.5, Ha: cindex > 0.5
print("p-value = {}".format(cox_cindex.p_value(alternative="greater")))
p-value = 0.014473557472229004
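This is a rough consistency check, not a description of TorchSurv's internals. Assume that both the confidence interval and the p-value above come from the same normal approximation of the C-index. Under that assumption, the standard error can be backed out from the interval width, and the one-sided p-value recomputed from it:

```python
from statistics import NormalDist

# Values reported above for the Cox model (rounded)
cindex = 0.6394
ci_low, ci_high = 0.5143, 0.7645

# Assumption: a symmetric 95% normal-approximation interval,
# i.e. its half-width equals 1.96 standard errors
z_975 = NormalDist().inv_cdf(0.975)
se = (ci_high - ci_low) / (2 * z_975)

# One-sided test of H0: c-index = 0.5 against Ha: c-index > 0.5
z = (cindex - 0.5) / se
p_value = 1 - NormalDist().cdf(z)
print(f"se = {se:.4f}, z = {z:.2f}, p = {p_value:.4f}")
```

This gives a p-value of about 0.014, in line with the value reported above.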
For time-dependent prediction (e.g., 5-year mortality), the C-index is not an appropriate measure; the AUC is recommended instead. The AUC evaluated at 5 years (1825 days) is the probability that the model correctly predicts which of two comparable patients will experience an event within 5 years, based on their estimated risk scores. It is obtained with:
[16]:
cox_auc = Auc()
new_time = torch.tensor(1825.0)
# auc evaluated at new time = 1825, 5 year
print(f"AUC 5-yr = {cox_auc(log_hz, event, time, new_time=new_time)}")
print(f"AUC 5-yr (conf int.) = {cox_auc.confidence_interval()}")
AUC 5-yr = tensor([0.6342])
AUC 5-yr (conf int.) = tensor([0.5767, 0.6917])
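The cumulative/dynamic definition of the AUC at a fixed time can be sketched in the same way. This is an unweighted illustration; TorchSurv may additionally apply censoring weights. Cases are subjects with an observed event by new_time, controls are subjects still event-free after it, and subjects censored before new_time are left out:

```python
def cumulative_dynamic_auc(risk, event, time, new_time):
    """Unweighted cumulative/dynamic AUC at new_time (sketch, no IPCW).

    Assumes at least one case and one control exist at new_time.
    """
    n = len(time)
    cases = [i for i in range(n) if event[i] and time[i] <= new_time]
    controls = [j for j in range(n) if time[j] > new_time]
    score = 0.0
    for i in cases:
        for j in controls:
            if risk[i] > risk[j]:
                score += 1.0
            elif risk[i] == risk[j]:
                score += 0.5  # ties in predicted risk count as half
    return score / (len(cases) * len(controls))
```

For example, with event times 100 and 500 days (cases) and follow-up beyond 1825 days for two other patients (controls), ranking both cases above both controls gives an AUC of 1.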
As before, we can test whether the observed AUC at 5 years is greater than 0.5. The statistical test is specified with H0: AUC = 0.5 and Ha: AUC > 0.5. The p-value of the statistical test is:
[17]:
print(f"AUC (p_value) = {cox_auc.p_value()}")
AUC (p_value) = tensor([4.7684e-06])
Section 2: Weibull accelerated failure time (AFT) model
In this section, we use the Weibull accelerated failure time (AFT) model. Given covariate \(x_{i}\), the hazard of patient \(i\) at time \(t\) has the form
\[
\lambda(t|x_{i}) = \frac{\rho(x_{i})}{\lambda(x_{i})}\left(\frac{t}{\lambda(x_{i})}\right)^{\rho(x_{i}) - 1}
\]
Given this hazard, it can be shown that the event time follows a Weibull distribution parametrized by scale \(\lambda(x_{i})\) and shape \(\rho(x_{i})\). The subject-specific risk of event occurrence at time \(t\) is captured through the hazards \(\{\lambda (t|x_{i})\}_{i = 1, \dots, N}\). We train a multi-layer perceptron (MLP) to model the subject-specific log scale, \(\log \lambda(x_{i})\), and log shape, \(\log\rho(x_{i})\).
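The Weibull hazard integrates to the cumulative hazard \((t/\lambda(x_{i}))^{\rho(x_{i})}\), so \(S(t|x_{i}) = \exp\{-(t/\lambda(x_{i}))^{\rho(x_{i})}\}\). A right-censored observation then contributes \(\delta_{i}\log\lambda(t_{i}|x_{i}) + \log S(t_{i}|x_{i})\) to the log-likelihood, where \(\delta_{i}\) is the event indicator. Below is a minimal plain-Python sketch of these quantities, written in terms of the log scale and log shape. It does not assume which output column of the network holds which parameter in TorchSurv.

```python
import math


def weibull_log_hazard(t, log_scale, log_shape):
    """log of the hazard (rho / lam) * (t / lam) ** (rho - 1)."""
    rho = math.exp(log_shape)
    return log_shape - log_scale + (rho - 1.0) * (math.log(t) - log_scale)


def weibull_survival(t, log_scale, log_shape):
    """Survival function S(t) = exp(-(t / lam) ** rho)."""
    lam, rho = math.exp(log_scale), math.exp(log_shape)
    return math.exp(-((t / lam) ** rho))


def weibull_log_likelihood(t, event, log_scale, log_shape):
    """Right-censored log-likelihood of one subject.

    An event contributes the log density log h(t) + log S(t); a censored
    subject contributes only log S(t).
    """
    lam, rho = math.exp(log_scale), math.exp(log_shape)
    log_surv = -((t / lam) ** rho)
    return event * weibull_log_hazard(t, log_scale, log_shape) + log_surv
```

With shape 1 the hazard is constant (an exponential model): for a scale of 2, the hazard is 0.5 and S(2) = exp(-1).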
Section 2.1: MLP model for log scale and log shape
[18]:
# Same architecture as the Cox model, except for the output dimension
weibull_model = torch.nn.Sequential(
torch.nn.BatchNorm1d(num_features), # Batch normalization
torch.nn.Linear(num_features, 32),
torch.nn.ReLU(),
torch.nn.Dropout(),
torch.nn.Linear(32, 64),
torch.nn.ReLU(),
torch.nn.Dropout(),
torch.nn.Linear(64, 2), # Estimating log parameters for Weibull model
)
Section 2.2: MLP model training
[19]:
torch.manual_seed(42)
# Initialize the optimizer for the Weibull model
optimizer = torch.optim.Adam(weibull_model.parameters(), lr=LEARNING_RATE)
# Initialize empty lists to store the losses on the training and validation sets
train_losses = []
val_losses = []
# Training loop
for epoch in range(EPOCHS):
    weibull_model.train()  # enable dropout and batch statistics for training
    epoch_loss = torch.tensor(0.0)
    for i, batch in enumerate(dataloader_train):
        x, (event, time) = batch
        optimizer.zero_grad()
        log_params = weibull_model(x)  # shape = (batch_size, 2)
        loss = neg_log_likelihood(log_params, event, time, reduction="mean")
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()
    if epoch % (EPOCHS // 10) == 0:
        # Printed value is the sum of the batch losses of this epoch
        print(f"Epoch: {epoch:03}, Training loss: {epoch_loss:0.2f}")
    # Record the mean batch loss (as for the Cox model) and the validation loss
    epoch_loss /= i + 1
    train_losses.append(epoch_loss)
    weibull_model.eval()  # disable dropout; use running batch-norm statistics
    with torch.no_grad():
        x, (event, time) = next(iter(dataloader_val))
        val_losses.append(
            neg_log_likelihood(weibull_model(x), event, time, reduction="mean")
        )
Epoch: 000, Training loss: 100312.30
Epoch: 010, Training loss: 19.79
Epoch: 020, Training loss: 20.48
Epoch: 030, Training loss: 16.80
Epoch: 040, Training loss: 17.24
Epoch: 050, Training loss: 17.22
Epoch: 060, Training loss: 16.75
Epoch: 070, Training loss: 16.55
Epoch: 080, Training loss: 16.76
Epoch: 090, Training loss: 16.69
We can visualize the training and validation losses.
[20]:
plot_losses(train_losses, val_losses, "Weibull")
Section 2.3: Weibull AFT model evaluation
We evaluate the predictive performance of the model using
the C-index, which measures the probability that a model correctly predicts which of two comparable samples will experience an event first, based on their estimated risk scores,
the AUC, which measures the probability that a model correctly predicts which of two comparable samples will experience an event by time \(t\), based on their estimated risk scores, and
the Brier score, which measures the model's calibration as the mean squared error between the estimated survival probability and the observed event status.
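For a single time \(t\), the Brier score is the mean of \((y_{i}(t) - S_{i}(t))^{2}\), where \(y_{i}(t)\) is 1 if subject \(i\) is still event-free at \(t\) and 0 otherwise. The sketch below is unweighted: it simply skips subjects censored before \(t\). A full implementation would typically reweight the remaining subjects with inverse probability of censoring weights instead. Here the integrated score is approximated with the trapezoidal rule and normalized by the time range, which is one common convention and an assumption of this sketch.

```python
def brier_score_at(surv_at_t, event, time, t):
    """Unweighted Brier score at time t (sketch without IPCW weights).

    surv_at_t[i] is the predicted survival probability S_i(t). The observed
    status is 1 if subject i is still event-free after t and 0 if the event
    occurred by t; subjects censored by t have unknown status and are skipped.
    """
    squared_errors = []
    for s, e, ti in zip(surv_at_t, event, time):
        if ti > t:
            status = 1.0
        elif e:
            status = 0.0
        else:
            continue  # censored at or before t
        squared_errors.append((status - s) ** 2)
    return sum(squared_errors) / len(squared_errors)


def integrated_brier_score(times, scores):
    """Trapezoidal integral of BS(t) over [times[0], times[-1]], divided by its length."""
    area = 0.0
    for k in range(1, len(times)):
        area += 0.5 * (scores[k] + scores[k - 1]) * (times[k] - times[k - 1])
    return area / (times[-1] - times[0])
```

Lower is better: a model that predicts S = 0.5 for everyone scores 0.25 at every time point, whatever the outcomes.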
We start by obtaining the subject-specific log hazards and survival probabilities at every time \(t\) observed on the test set:
[21]:
weibull_model.eval()
with torch.no_grad():
# event and time of length n
x, (event, time) = next(iter(dataloader_test))
log_params = weibull_model(x) # shape = (n,2)
# Compute the log hazards from weibull log parameters
log_hz = log_hazard(log_params, time) # shape = (n,n)
# Compute the survival probability from weibull log parameters
surv = survival_function(log_params, time) # shape = (n,n)
We can evaluate the concordance index, its confidence interval and the p-value of the statistical test testing whether the c-index is greater than 0.5:
[22]:
# Concordance index
weibull_cindex = ConcordanceIndex()
print("Weibull model performance:")
print(f"Concordance-index = {weibull_cindex(log_hz, event, time)}")
print(f"Confidence interval = {weibull_cindex.confidence_interval()}")
# H0: cindex = 0.5, Ha: cindex > 0.5
print(f"p-value = {weibull_cindex.p_value(alternative='greater')}")
Weibull model performance:
Concordance-index = 0.4327795207500458
Confidence interval = tensor([0.3040, 0.5615])
p-value = 0.8468805551528931
For time-dependent prediction (e.g., 5-year mortality), the C-index is not an appropriate measure; the AUC is recommended instead. The AUC evaluated at 5 years (1825 days) is the probability that the model correctly predicts which of two comparable patients will experience an event within 5 years, based on their estimated risk scores. It is obtained with:
[23]:
new_time = torch.tensor(1825.0)
# Subject-specific log hazard at 5 years (1825 days)
log_hz_t = log_hazard(log_params, time=new_time) # shape = (n)
weibull_auc = Auc()
# auc evaluated at new time = 1825, 5 year
print(f"AUC 5-yr = {weibull_auc(log_hz_t, event, time, new_time=new_time)}")
print(f"AUC 5-yr (conf int.) = {weibull_auc.confidence_interval()}")
print(f"AUC 5-yr (p value) = {weibull_auc.p_value(alternative='greater')}")
AUC 5-yr = tensor([0.4158])
AUC 5-yr (conf int.) = tensor([0.3684, 0.4632])
AUC 5-yr (p value) = tensor([0.9998])
Lastly, we can evaluate the time-dependent Brier score and the integrated Brier score
[24]:
brier_score = BrierScore()
# brier score at first 5 times
print(f"Brier score = {brier_score(surv, event, time)[:5]}")
print(f"Brier score (conf int.) = {brier_score.confidence_interval()[:,:5]}")
# integrated brier score
print(f"Integrated Brier score = {brier_score.integral()}")
Brier score = tensor([0.4088, 0.4081, 0.4071, 0.4390, 0.4405])
Brier score (conf int.) = tensor([[0.4044, 0.4021, 0.4000, 0.4294, 0.4303],
[0.4133, 0.4140, 0.4143, 0.4485, 0.4507]])
Integrated Brier score = 0.24420785903930664
We can test whether the time-dependent Brier score is smaller than would be expected if the survival model's predictions were no better than random chance. We use a bootstrap permutation test and obtain the p-value with:
[25]:
# H0: bs = bs0, Ha: bs < bs0; where bs0 is the expected brier score if the survival model was not providing accurate predictions beyond random chance.
# p-value for brier score at first 5 times
print(f"Brier score (p-val) = {brier_score.p_value(alternative='less')[:5]}")
Brier score (p-val) = tensor([0.7680, 0.9690, 0.6890, 0.7730, 0.9120])
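The idea behind such a test can be sketched with a generic permutation test at a single time point. Shuffling the predictions across subjects breaks any link to the outcomes. The p-value is then the share of shuffles that score at least as well as the real predictions. This is an illustrative simplification (no censoring, a fixed observed status), not TorchSurv's bootstrap procedure.

```python
import random


def brier(pred, status):
    """Mean squared difference between predicted survival and observed status."""
    return sum((y - p) ** 2 for p, y in zip(pred, status)) / len(pred)


def permutation_p_value(pred, status, n_permutations=999, seed=0):
    """One-sided permutation p-value for H0: predictions are unrelated to outcomes.

    Small p-values mean the observed Brier score is lower than what
    randomly re-assigned predictions typically achieve.
    """
    rng = random.Random(seed)
    observed = brier(pred, status)
    shuffled = list(pred)
    n_as_good = 0
    for _ in range(n_permutations):
        rng.shuffle(shuffled)  # break the link between predictions and outcomes
        if brier(shuffled, status) <= observed:
            n_as_good += 1
    # +1 in numerator and denominator counts the observed arrangement itself
    return (1 + n_as_good) / (1 + n_permutations)
```

Informative predictions give a small p-value. Identical predictions for every subject give a p-value of 1, because shuffling cannot change their score, which mirrors the large Brier-score p-values observed for the Weibull model above.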
Section 3: Model comparison
We can compare the predictive performance of the Cox proportional hazards model against the Weibull AFT model.
Section 3.1: Concordance index
The statistical test is formulated as follows: H0: cindex cox = cindex weibull, Ha: cindex cox > cindex weibull.
[26]:
print(f"Cox cindex = {cox_cindex.cindex}")
print(f"Weibull cindex = {weibull_cindex.cindex}")
print("p-value = {}".format(cox_cindex.compare(weibull_cindex)))
Cox cindex = 0.639417290687561
Weibull cindex = 0.4327795207500458
p-value = 0.01457300502806902
Section 3.2: AUC at 5-year
The statistical test is formulated as follows: H0: 5-yr auc cox = 5-yr auc weibull, Ha: 5-yr auc cox > 5-yr auc weibull.
[27]:
print(f"Cox 5-yr AUC = {cox_auc.auc}")
print(f"Weibull 5-yr AUC = {weibull_auc.auc}")
print("p-value = {}".format(cox_auc.compare(weibull_auc)))
Cox 5-yr AUC = tensor([0.6342])
Weibull 5-yr AUC = tensor([0.4158])
p-value = tensor([4.9211e-08])
Section 4: Kaplan-Meier
[28]:
# Create a Kaplan-Meier estimator
km = KaplanMeierEstimator()
# Use our observed testing dataset
event = torch.tensor(df_test['cens'].values).bool()
time = torch.tensor(df_test['time'].values)
# Compute the estimator
km(event, time)
[29]:
# Plot the estimate
km.plot_km()
[30]:
# Print the survival values at each time step
km.print_survival_table()
Time Survival
----------------
16.00 1.0000
17.00 1.0000
18.00 1.0000
98.00 0.9951
113.00 0.9901
160.00 0.9852
177.00 0.9803
180.00 0.9753
181.00 0.9704
186.00 0.9704
191.00 0.9654
195.00 0.9604
223.00 0.9555
241.00 0.9505
247.00 0.9455
251.00 0.9405
273.00 0.9405
276.00 0.9405
286.00 0.9355
308.00 0.9305
316.00 0.9254
338.00 0.9204
343.00 0.9154
348.00 0.9104
350.00 0.9053
353.00 0.9003
358.00 0.8953
371.00 0.8902
377.00 0.8852
415.00 0.8802
420.00 0.8752
424.00 0.8752
426.00 0.8701
448.00 0.8650
460.00 0.8600
463.00 0.8600
475.00 0.8549
476.00 0.8498
481.00 0.8447
490.00 0.8396
515.00 0.8345
526.00 0.8345
536.00 0.8294
541.00 0.8294
544.00 0.8243
545.00 0.8191
547.00 0.8140
548.00 0.8088
552.00 0.8037
554.00 0.7985
559.00 0.7934
570.00 0.7934
573.00 0.7882
575.00 0.7830
578.00 0.7778
594.00 0.7726
596.00 0.7726
623.00 0.7726
624.00 0.7674
632.00 0.7621
637.00 0.7621
650.00 0.7568
657.00 0.7568
662.00 0.7515
687.00 0.7461
730.00 0.7408
733.00 0.7408
734.00 0.7408
737.00 0.7408
740.00 0.7408
745.00 0.7353
762.00 0.7298
784.00 0.7242
799.00 0.7187
805.00 0.7132
819.00 0.7076
827.00 0.7021
841.00 0.7021
855.00 0.7021
859.00 0.6965
883.00 0.6908
889.00 0.6852
890.00 0.6795
891.00 0.6738
933.00 0.6738
936.00 0.6738
940.00 0.6738
945.00 0.6680
964.00 0.6621
967.00 0.6621
969.00 0.6621
972.00 0.6621
974.00 0.6621
983.00 0.6559
991.00 0.6498
1059.00 0.6437
1062.00 0.6437
1078.00 0.6437
1090.00 0.6374
1093.00 0.6312
1100.00 0.6312
1105.00 0.6249
1108.00 0.6186
1152.00 0.6186
1162.00 0.6122
1170.00 0.6058
1174.00 0.5994
1212.00 0.5994
1218.00 0.5930
1219.00 0.5930
1222.00 0.5930
1253.00 0.5864
1264.00 0.5864
1280.00 0.5797
1283.00 0.5797
1296.00 0.5797
1323.00 0.5797
1329.00 0.5728
1343.00 0.5728
1351.00 0.5728
1357.00 0.5728
1358.00 0.5728
1363.00 0.5655
1364.00 0.5655
1371.00 0.5580
1420.00 0.5506
1427.00 0.5506
1434.00 0.5506
1441.00 0.5506
1460.00 0.5429
1469.00 0.5429
1472.00 0.5429
1481.00 0.5349
1486.00 0.5349
1490.00 0.5349
1499.00 0.5349
1502.00 0.5349
1525.00 0.5262
1527.00 0.5262
1570.00 0.5262
1604.00 0.5262
1617.00 0.5262
1629.00 0.5262
1645.00 0.5262
1653.00 0.5262
1666.00 0.5262
1680.00 0.5262
1684.00 0.5161
1693.00 0.5161
1703.00 0.5161
1707.00 0.5161
1730.00 0.5054
1751.00 0.5054
1756.00 0.5054
1767.00 0.5054
1771.00 0.5054
1791.00 0.5054
1818.00 0.5054
1826.00 0.5054
1838.00 0.5054
1858.00 0.5054
1884.00 0.5054
1897.00 0.5054
1904.00 0.5054
1918.00 0.4901
1926.00 0.4901
1938.00 0.4901
1956.00 0.4901
1959.00 0.4901
1965.00 0.4901
1975.00 0.4712
1977.00 0.4712
1981.00 0.4712
2007.00 0.4712
2010.00 0.4712
2014.00 0.4712
2027.00 0.4712
2030.00 0.4464
2048.00 0.4464
2065.00 0.4464
2093.00 0.4185
2132.00 0.4185
2144.00 0.4185
2192.00 0.4185
2227.00 0.4185
2233.00 0.4185
2237.00 0.4185
2297.00 0.4185
2372.00 0.3662
2388.00 0.3662
2449.00 0.3662
2456.00 0.3662
2467.00 0.3662
2539.00 0.3662
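The product-limit formula behind this table, \(\hat S(t) = \prod_{t_{k} \le t} (1 - d_{k}/n_{k})\), can be sketched in plain Python. Here \(d_{k}\) is the number of events and \(n_{k}\) the number at risk at time \(t_{k}\). This is a minimal illustration on made-up data, not the KaplanMeierEstimator implementation. As in the table above, times at which only censoring occurs keep the previous survival value.

```python
def kaplan_meier(event, time):
    """Product-limit estimate as a list of (time, survival) pairs.

    One entry per distinct observed time; times with only censored
    subjects leave the survival estimate unchanged.
    """
    survival = 1.0
    table = []
    for t in sorted(set(time)):
        n_at_risk = sum(1 for ti in time if ti >= t)
        n_events = sum(1 for ti, ei in zip(time, event) if ti == t and ei)
        survival *= 1.0 - n_events / n_at_risk
        table.append((t, survival))
    return table


for t, s in kaplan_meier([1, 1, 0, 1, 0], [1, 2, 2, 3, 4]):
    print(f"{t:6.2f} {s:.4f}")
```

For these five made-up subjects, the estimate drops to 0.8, 0.6 and 0.3 at the three event times and stays at 0.3 at the final, censored time.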