{
"cells": [
{
"cell_type": "markdown",
"id": "ca2213c2-6abc-4340-853a-7ab1e06e68d3",
"metadata": {},
"source": [
"# Getting started\n",
"\n",
"In this notebook, we use `TorchSurv` to train a model that predicts relative risk of breast cancer recurrence. We use a public data set, the [German Breast Cancer Study Group 2 (GBSG2)](https://paperswithcode.com/dataset/gbsg2). After training the model, we evaluate the predictive performance using evaluation metrics implemented in `TorchSurv`.\n",
"\n",
"\n",
"We first load the dataset using the package [lifelines](https://lifelines.readthedocs.io/en/latest/). The GBSG2 dataset contains features and recurrence free survival time (in days) for 686 women undergoing hormonal treatment. \n",
"\n",
"### Dependencies\n",
"\n",
"To run this notebook, dependencies must be installed. the recommended method is to use our developpment conda environment (**preffered**). Instruction can be found [here](https://opensource.nibr.com/torchsurv/devnotes.html#set-up-a-development-environment-via-conda) to install all optional dependancies. The other method is to install only required packages using the command line below:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "160c8e19",
"metadata": {},
"outputs": [],
"source": [
"# Install only required packages (optional)\n",
"# %pip install lifelines\n",
"# %pip install matplotlib\n",
"# %pip install sklearn\n",
"# %pip install pandas\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "013dbcb4",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2601dd00-7bd2-49d5-9bdf-a84205872890",
"metadata": {},
"outputs": [],
"source": [
"import lifelines\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Our package\n",
"from torchsurv.loss.cox import neg_partial_log_likelihood\n",
"from torchsurv.loss.weibull import neg_log_likelihood, log_hazard, survival_function\n",
"from torchsurv.metrics.brier_score import BrierScore\n",
"from torchsurv.metrics.cindex import ConcordanceIndex\n",
"from torchsurv.metrics.auc import Auc\n",
"\n",
"# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py\n",
"from helpers_introduction import Custom_dataset, plot_losses"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d7a98ea2-100f-43ef-8c45-c786ddcd313e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA-enabled GPU/TPU is available.\n"
]
}
],
"source": [
"# Constant parameters accross models\n",
"# Detect available accelerator; Downgrade batch size if only CPU available\n",
"if any([torch.cuda.is_available(), torch.backends.mps.is_available()]):\n",
" print(\"CUDA-enabled GPU/TPU is available.\")\n",
" BATCH_SIZE = 128 # batch size for training\n",
"else:\n",
" print(\"No CUDA-enabled GPU found, using CPU.\")\n",
" BATCH_SIZE = 32 # batch size for training\n",
"\n",
"EPOCHS = 100\n",
"LEARNING_RATE = 1e-2"
]
},
{
"cell_type": "markdown",
"id": "38dd4c6e-2934-44f5-88fa-1d9d02032fc3",
"metadata": {},
"source": [
"## Dataset overview"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1df49737-dc02-4d6b-acd7-d03b79f18a29",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
"
\n",
"
\n",
"
horTh
\n",
"
age
\n",
"
menostat
\n",
"
tsize
\n",
"
tgrade
\n",
"
pnodes
\n",
"
progrec
\n",
"
estrec
\n",
"
time
\n",
"
cens
\n",
"
\n",
" \n",
" \n",
"
\n",
"
0
\n",
"
no
\n",
"
70
\n",
"
Post
\n",
"
21
\n",
"
II
\n",
"
3
\n",
"
48
\n",
"
66
\n",
"
1814
\n",
"
1
\n",
"
\n",
"
\n",
"
1
\n",
"
yes
\n",
"
56
\n",
"
Post
\n",
"
12
\n",
"
II
\n",
"
7
\n",
"
61
\n",
"
77
\n",
"
2018
\n",
"
1
\n",
"
\n",
"
\n",
"
2
\n",
"
yes
\n",
"
58
\n",
"
Post
\n",
"
35
\n",
"
II
\n",
"
9
\n",
"
52
\n",
"
271
\n",
"
712
\n",
"
1
\n",
"
\n",
"
\n",
"
3
\n",
"
yes
\n",
"
59
\n",
"
Post
\n",
"
17
\n",
"
II
\n",
"
4
\n",
"
60
\n",
"
29
\n",
"
1807
\n",
"
1
\n",
"
\n",
"
\n",
"
4
\n",
"
no
\n",
"
73
\n",
"
Post
\n",
"
35
\n",
"
II
\n",
"
1
\n",
"
26
\n",
"
65
\n",
"
772
\n",
"
1
\n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" horTh age menostat tsize tgrade pnodes progrec estrec time cens\n",
"0 no 70 Post 21 II 3 48 66 1814 1\n",
"1 yes 56 Post 12 II 7 61 77 2018 1\n",
"2 yes 58 Post 35 II 9 52 271 712 1\n",
"3 yes 59 Post 17 II 4 60 29 1807 1\n",
"4 no 73 Post 35 II 1 26 65 772 1"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load GBSG2 dataset\n",
"df = lifelines.datasets.load_gbsg2()\n",
"df.head(5)"
]
},
{
"cell_type": "markdown",
"id": "8f23ce41-c0eb-4c30-83f3-2a2d45dcf097",
"metadata": {},
"source": [
"The dataset contains the categorical features: \n",
"\n",
"- `horTh`: hormonal therapy, a factor at two levels (yes and no).\n",
"- `age`: age of the patients in years.\n",
"- `menostat`: menopausal status, a factor at two levels pre (premenopausal) and post (postmenopausal).\n",
"- `tsize`: tumor size (in mm).\n",
"- `tgrade`: tumor grade, a ordered factor at levels I < II < III.\n",
"- `pnodes`: number of positive nodes.\n",
"- `progrec`: progesterone receptor (in fmol).\n",
"- `estrec`: estrogen receptor (in fmol).\n",
"\n",
"Additionally, it contains our survival targets:\n",
"\n",
"- `time`: recurrence free survival time (in days).\n",
"- `cens`: censoring indicator (0- censored, 1- event).\n",
"\n",
"One common approach is to use a [one hot encoder](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.get_dummies.html) to convert them into numerical features. We then seperate the dataframes into features `X` and labels `y`. The following code also partitions the labels and features into training and testing cohorts."
]
},
{
"cell_type": "markdown",
"id": "34132fea-daa6-46a5-8429-16df73886a51",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7a5fd9ef-2643-46b7-9c98-05ff919026ea",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
"
\n",
"
\n",
"
age
\n",
"
tsize
\n",
"
pnodes
\n",
"
progrec
\n",
"
estrec
\n",
"
time
\n",
"
cens
\n",
"
horTh_yes
\n",
"
menostat_Pre
\n",
"
tgrade_II
\n",
"
tgrade_III
\n",
"
\n",
" \n",
" \n",
"
\n",
"
0
\n",
"
70.0
\n",
"
21.0
\n",
"
3.0
\n",
"
48.0
\n",
"
66.0
\n",
"
1814.0
\n",
"
1.0
\n",
"
0.0
\n",
"
0.0
\n",
"
1.0
\n",
"
0.0
\n",
"
\n",
"
\n",
"
1
\n",
"
56.0
\n",
"
12.0
\n",
"
7.0
\n",
"
61.0
\n",
"
77.0
\n",
"
2018.0
\n",
"
1.0
\n",
"
1.0
\n",
"
0.0
\n",
"
1.0
\n",
"
0.0
\n",
"
\n",
"
\n",
"
2
\n",
"
58.0
\n",
"
35.0
\n",
"
9.0
\n",
"
52.0
\n",
"
271.0
\n",
"
712.0
\n",
"
1.0
\n",
"
1.0
\n",
"
0.0
\n",
"
1.0
\n",
"
0.0
\n",
"
\n",
"
\n",
"
3
\n",
"
59.0
\n",
"
17.0
\n",
"
4.0
\n",
"
60.0
\n",
"
29.0
\n",
"
1807.0
\n",
"
1.0
\n",
"
1.0
\n",
"
0.0
\n",
"
1.0
\n",
"
0.0
\n",
"
\n",
"
\n",
"
4
\n",
"
73.0
\n",
"
35.0
\n",
"
1.0
\n",
"
26.0
\n",
"
65.0
\n",
"
772.0
\n",
"
1.0
\n",
"
0.0
\n",
"
0.0
\n",
"
1.0
\n",
"
0.0
\n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age tsize pnodes progrec estrec time cens horTh_yes \\\n",
"0 70.0 21.0 3.0 48.0 66.0 1814.0 1.0 0.0 \n",
"1 56.0 12.0 7.0 61.0 77.0 2018.0 1.0 1.0 \n",
"2 58.0 35.0 9.0 52.0 271.0 712.0 1.0 1.0 \n",
"3 59.0 17.0 4.0 60.0 29.0 1807.0 1.0 1.0 \n",
"4 73.0 35.0 1.0 26.0 65.0 772.0 1.0 0.0 \n",
"\n",
" menostat_Pre tgrade_II tgrade_III \n",
"0 0.0 1.0 0.0 \n",
"1 0.0 1.0 0.0 \n",
"2 0.0 1.0 0.0 \n",
"3 0.0 1.0 0.0 \n",
"4 0.0 1.0 0.0 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_onehot = pd.get_dummies(df, columns=[\"horTh\", \"menostat\", \"tgrade\"]).astype(\"float\")\n",
"df_onehot.drop(\n",
" [\"horTh_no\", \"menostat_Post\", \"tgrade_I\"],\n",
" axis=1,\n",
" inplace=True,\n",
")\n",
"df_onehot.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0f8b7f3b-fb2a-4d74-ac99-8f6390b2f5eb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(Sample size) Training:336 | Validation:144 |Testing:206\n"
]
}
],
"source": [
"df_train, df_test = train_test_split(df_onehot, test_size=0.3)\n",
"df_train, df_val = train_test_split(df_train, test_size=0.3)\n",
"print(\n",
" f\"(Sample size) Training:{len(df_train)} | Validation:{len(df_val)} |Testing:{len(df_test)}\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "00ad6603-0dff-4991-992a-081ba9a4fafa",
"metadata": {},
"source": [
"Let us setup the dataloaders for training, validation and testing."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "326c03fc-91f1-493b-a9ba-820de17fb2f8",
"metadata": {},
"outputs": [],
"source": [
"# Dataloader\n",
"dataloader_train = DataLoader(\n",
" Custom_dataset(df_train), batch_size=BATCH_SIZE, shuffle=True\n",
")\n",
"dataloader_val = DataLoader(\n",
" Custom_dataset(df_val), batch_size=len(df_val), shuffle=False\n",
")\n",
"dataloader_test = DataLoader(\n",
" Custom_dataset(df_test), batch_size=len(df_test), shuffle=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "570386fb-f0ea-4061-bae2-11b274e7f851",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x (shape) = torch.Size([128, 9])\n",
"num_features = 9\n",
"event = torch.Size([128])\n",
"time = torch.Size([128])\n"
]
}
],
"source": [
"# Sanity check\n",
"x, (event, time) = next(iter(dataloader_train))\n",
"num_features = x.size(1)\n",
"\n",
"print(f\"x (shape) = {x.shape}\")\n",
"print(f\"num_features = {num_features}\")\n",
"print(f\"event = {event.shape}\")\n",
"print(f\"time = {time.shape}\")"
]
},
{
"cell_type": "markdown",
"id": "6b53d40d-d2c4-4dd7-bb85-97d4e946c356",
"metadata": {},
"source": [
"## Section 1: Cox proportional hazards model\n",
"\n",
"In this section, we use the [Cox proportional hazards model](../_autosummary/torchsurv.loss.cox.html). Given covariate $x_{i}$, the hazard of patient $i$ has the form\n",
"$$\n",
"\\lambda (t|x_{i}) =\\lambda_{0}(t)\\theta(x_{i})\n",
"$$\n",
"The baseline hazard $\\lambda_{0}(t)$ is identical across subjects (i.e., has no dependency on $i$). The subject-specific risk of event occurrence is captured through the relative hazards $\\{\\theta(x_{i})\\}_{i = 1, \\dots, N}$.\n",
"\n",
"We train a multi-layer perceptron (MLP) to model the subject-specific risk of event occurrence, i.e., the log relative hazards $\\log\\theta(x_{i})$. Patients with lower recurrence time are assumed to have higher risk of event. "
]
},
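{
"cell_type": "markdown",
"id": "cox-partial-likelihood-sketch",
"metadata": {},
"source": [
"For reference, the loss minimized in this section is derived from the Cox partial likelihood. The expression below is a sketch rather than a description of the `TorchSurv` implementation: it assumes no tied event times, and the symbols $\\delta_{i}$ (event indicator) and $T_{i}$ (observed time) are notation introduced here for illustration.\n",
"$$\n",
"\\log \\mathcal{L}(\\theta) = \\sum_{i: \\delta_{i} = 1} \\left( \\log\\theta(x_{i}) - \\log \\sum_{j: T_{j} \\geq T_{i}} \\theta(x_{j}) \\right)\n",
"$$\n",
"The baseline hazard $\\lambda_{0}(t)$ cancels in every term, which is why the network only needs to output $\\log\\theta(x_{i})$. The training loss is the negative of this quantity, up to the chosen reduction."
]
},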
{
"cell_type": "markdown",
"id": "46343fe0",
"metadata": {},
"source": [
"### Section 1.1: MLP model for log relative hazards"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9c2bd89a-c90a-4795-aab5-b5c21906a0de",
"metadata": {},
"outputs": [],
"source": [
"cox_model = torch.nn.Sequential(\n",
" torch.nn.BatchNorm1d(num_features), # Batch normalization\n",
" torch.nn.Linear(num_features, 32),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Dropout(),\n",
" torch.nn.Linear(32, 64),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Dropout(),\n",
" torch.nn.Linear(64, 1), # Estimating log hazards for Cox models\n",
")"
]
},
{
"cell_type": "markdown",
"id": "97c90244",
"metadata": {},
"source": [
"### Section 1.2: MLP model training"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d7889dc1-1cfa-424e-a586-481cbc789581",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 000, Training loss: 12.75\n",
"Epoch: 010, Training loss: 12.02\n",
"Epoch: 020, Training loss: 11.79\n",
"Epoch: 030, Training loss: 11.84\n",
"Epoch: 040, Training loss: 11.61\n",
"Epoch: 050, Training loss: 11.61\n",
"Epoch: 060, Training loss: 11.46\n",
"Epoch: 070, Training loss: 11.57\n",
"Epoch: 080, Training loss: 11.56\n",
"Epoch: 090, Training loss: 11.20\n"
]
}
],
"source": [
"torch.manual_seed(42)\n",
"\n",
"# Init optimizer for Cox\n",
"optimizer = torch.optim.Adam(cox_model.parameters(), lr=LEARNING_RATE)\n",
"\n",
"# Initiate empty list to store the loss on the train and validation sets\n",
"train_losses = []\n",
"val_losses = []\n",
"\n",
"# training loop\n",
"for epoch in range(EPOCHS):\n",
" epoch_loss = torch.tensor(0.0)\n",
" for i, batch in enumerate(dataloader_train):\n",
" x, (event, time) = batch\n",
" optimizer.zero_grad()\n",
" log_hz = cox_model(x) # shape = (16, 1)\n",
" loss = neg_partial_log_likelihood(log_hz, event, time, reduction=\"mean\")\n",
" loss.backward()\n",
" optimizer.step()\n",
" epoch_loss += loss.detach()\n",
"\n",
" if epoch % (EPOCHS // 10) == 0:\n",
" print(f\"Epoch: {epoch:03}, Training loss: {epoch_loss:0.2f}\")\n",
"\n",
" # Reccord loss on train and test sets\n",
" epoch_loss /= i + 1\n",
" train_losses.append(epoch_loss)\n",
" with torch.no_grad():\n",
" x, (event, time) = next(iter(dataloader_val))\n",
" val_losses.append(\n",
" neg_partial_log_likelihood(cox_model(x), event, time, reduction=\"mean\")\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "0e2bdd8c-f84c-4003-98f4-220ddab518d1",
"metadata": {},
"source": [
"We can visualize the training and validation losses."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "21afc248-303a-4156-8d9c-b97be3e0a56b",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_losses(train_losses, val_losses, \"Cox\")"
]
},
{
"cell_type": "markdown",
"id": "bd881d14-9646-48e0-bcc3-f29be358161f",
"metadata": {},
"source": [
"### Section 1.3: Cox proportional hazards model evaluation\n",
"\n",
"We evaluate the predictive performance of the model using \n",
"\n",
"* the [concordance index](../_autosummary/torchsurv.metrics.cindex.html) (C-index), which measures the the probability that a model correctly predicts which of two comparable samples will experience an event first based on their estimated risk scores,\n",
"* the [Area Under the Receiver Operating Characteristic Curve](../_autosummary/torchsurv.metrics.auc.html) (AUC), which measures the probability that a model correctly predicts which of two comparable samples will experience an event by time t based on their estimated risk scores.\n",
"\n",
"We cannot use the Brier score because this model is not able to estimate the survival function."
]
},
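{
"cell_type": "markdown",
"id": "cindex-estimator-sketch",
"metadata": {},
"source": [
"As a point of reference, a common unweighted estimator of the concordance index is sketched below. It ignores ties and sample weights, so it may differ from what `ConcordanceIndex` computes. The symbols $q_{i}$ (estimated risk score), $\\delta_{i}$ (event indicator), $T_{i}$ (observed time) and the indicator function $I(\\cdot)$ are notation introduced here for illustration.\n",
"$$\n",
"\\hat{C} = \\frac{\\sum_{i}\\sum_{j} \\delta_{i} \\, I(T_{i} < T_{j}) \\, I(q_{i} > q_{j})}{\\sum_{i}\\sum_{j} \\delta_{i} \\, I(T_{i} < T_{j})}\n",
"$$\n",
"In this sketch, a pair is comparable when the patient with the shorter observed time experienced the event. A value of 0.5 corresponds to a random ranking and a value of 1 to a perfect ranking."
]
},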
{
"cell_type": "markdown",
"id": "0d2e7996",
"metadata": {},
"source": [
"We start by evaluating the subject-specific relative hazards on the test set "
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "272a997d-a978-4e9b-bb0b-d90e4f03a530",
"metadata": {},
"outputs": [],
"source": [
"cox_model.eval()\n",
"with torch.no_grad():\n",
" # test event and test time of length n\n",
" x, (event, time) = next(iter(dataloader_test))\n",
" log_hz = cox_model(x) # log hazard of length n"
]
},
{
"cell_type": "markdown",
"id": "77bd0fe9",
"metadata": {},
"source": [
"We obtain the concordance index, and its confidence interval"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3c9ad489-9e53-40ac-8931-8941597760a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cox model performance:\n",
"Concordance-index = 0.6610040068626404\n",
"Confidence interval = tensor([0.5505, 0.7716])\n"
]
}
],
"source": [
"# Concordance index\n",
"cox_cindex = ConcordanceIndex()\n",
"print(\"Cox model performance:\")\n",
"print(f\"Concordance-index = {cox_cindex(log_hz, event, time)}\")\n",
"print(f\"Confidence interval = {cox_cindex.confidence_interval()}\")"
]
},
{
"cell_type": "markdown",
"id": "507b410a",
"metadata": {},
"source": [
"We can also test whether the observed concordance index is greater than 0.5. The statistical test is specified with H0: c-index = 0.5 and Ha: c-index > 0.5. The p-value of the statistical test is"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "7d34ba82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"p-value = 0.002155780792236328\n"
]
}
],
"source": [
"# H0: cindex = 0.5, Ha: cindex > 0.5\n",
"print(\"p-value = {}\".format(cox_cindex.p_value(alternative=\"greater\")))"
]
},
{
"cell_type": "markdown",
"id": "a60919a9",
"metadata": {},
"source": [
"For time-dependent prediction (e.g., 5-year mortality), the C-index is not a proper measure. Instead, it is recommended to use the AUC. The probability to correctly predicts which of two comparable patients will experience an event by 5-year based on their estimated risk scores is the AUC evaluated at 5-year (1825 days) obtained with"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "907312f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC 5-yr = tensor([0.6954])\n",
"AUC 5-yr (conf int.) = tensor([0.6360, 0.7548])\n"
]
}
],
"source": [
"cox_auc = Auc()\n",
"\n",
"new_time = torch.tensor(1825.0)\n",
"\n",
"# auc evaluated at new time = 1825, 5 year\n",
"print(f\"AUC 5-yr = {cox_auc(log_hz, event, time, new_time=new_time)}\")\n",
"print(f\"AUC 5-yr (conf int.) = {cox_auc.confidence_interval()}\")"
]
},
{
"cell_type": "markdown",
"id": "41e7e69f",
"metadata": {},
"source": [
"As before, we can test whether the observed Auc at 5-year is greater than 0.5. The statistical test is specified with H0: auc = 0.5 and Ha: auc > 0.5. The p-value of the statistical test is"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "702e5a74",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC (p_value) = tensor([0.])\n"
]
}
],
"source": [
"print(f\"AUC (p_value) = {cox_auc.p_value()}\")"
]
},
{
"cell_type": "markdown",
"id": "b8f517e6-b0a4-4fbc-aac5-b500b4aca169",
"metadata": {},
"source": [
"## Section 2: Weibull accelerated failure time (AFT) model"
]
},
{
"cell_type": "markdown",
"id": "769ddcf5",
"metadata": {},
"source": [
"In this section, we use the [Weibull accelerated failure (AFT) model](../_autosummary/torchsurv.loss.weibull.html). Given covariate $x_{i}$, the hazard of patient $i$ at time $t$ has the form\n",
"$$\n",
"\\lambda (t|x_{i}) = \\frac{\\rho(x_{i}) } {\\lambda(x_{i}) } + \\left(\\frac{t}{\\lambda(x_{i})}\\right)^{\\rho(x_{i}) - 1}\n",
"$$\n",
"\n",
"Given the hazard form, it can be shown that the event density follows a Weibull distribution parametrized by scale $\\lambda(x_{i})$ and shape $\\rho(x_{i})$. The subject-specific risk of event occurrence at time $t$ is captured through the hazards $\\{\\lambda (t|x_{i})\\}_{i = 1, \\dots, N}$. We train a multi-layer perceptron (MLP) to model the subject-specific log scale, $\\log \\lambda(x_{i})$, and the log shape, $\\log\\rho(x_{i})$. "
]
},
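{
"cell_type": "markdown",
"id": "weibull-survival-sketch",
"metadata": {},
"source": [
"For a Weibull distribution with scale $\\lambda(x_{i})$ and shape $\\rho(x_{i})$, the cumulative hazard and the survival function have closed forms, which is what makes a survival probability available for this model. The following is a sketch in the notation of this section; the symbols $\\Lambda$, $S$, $\\delta_{i}$ (event indicator) and $T_{i}$ (observed time) are introduced here for illustration.\n",
"$$\n",
"\\Lambda(t|x_{i}) = \\left(\\frac{t}{\\lambda(x_{i})}\\right)^{\\rho(x_{i})}, \\qquad S(t|x_{i}) = \\exp\\left(-\\left(\\frac{t}{\\lambda(x_{i})}\\right)^{\\rho(x_{i})}\\right)\n",
"$$\n",
"Under right-censoring, the log-likelihood of the sample is commonly written as\n",
"$$\n",
"\\sum_{i=1}^{N} \\left[ \\delta_{i} \\log \\lambda(T_{i}|x_{i}) - \\Lambda(T_{i}|x_{i}) \\right]\n",
"$$\n",
"and the training loss is its negative, up to the chosen reduction."
]
},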
{
"cell_type": "markdown",
"id": "a580702e",
"metadata": {},
"source": [
"### Section 2.1: MLP model for log scale and log shape"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "35b92c10-e5fb-491d-9e27-743bcffdced2",
"metadata": {},
"outputs": [],
"source": [
"# Same architecture than Cox model, beside outputs dimension\n",
"weibull_model = torch.nn.Sequential(\n",
" torch.nn.BatchNorm1d(num_features), # Batch normalization\n",
" torch.nn.Linear(num_features, 32),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Dropout(),\n",
" torch.nn.Linear(32, 64),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Dropout(),\n",
" torch.nn.Linear(64, 2), # Estimating log parameters for Weibull model\n",
")"
]
},
{
"cell_type": "markdown",
"id": "e96c6985",
"metadata": {},
"source": [
"### Section 2.2: MLP model training"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3d5c6f77-6245-42b0-ae48-33b57789b651",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 000, Training loss: 4903.93\n",
"Epoch: 010, Training loss: 20.11\n",
"Epoch: 020, Training loss: 19.83\n",
"Epoch: 030, Training loss: 19.17\n",
"Epoch: 040, Training loss: 17.98\n",
"Epoch: 050, Training loss: 18.01\n",
"Epoch: 060, Training loss: 18.66\n",
"Epoch: 070, Training loss: 17.93\n",
"Epoch: 080, Training loss: 18.28\n",
"Epoch: 090, Training loss: 17.48\n"
]
}
],
"source": [
"torch.manual_seed(42)\n",
"\n",
"# Init optimizer for Weibull\n",
"optimizer = torch.optim.Adam(weibull_model.parameters(), lr=LEARNING_RATE)\n",
"\n",
"# Initialize empty list to store loss on train and validation sets\n",
"train_losses = []\n",
"val_losses = []\n",
"\n",
"# training loop\n",
"for epoch in range(EPOCHS):\n",
" epoch_loss = torch.tensor(0.0)\n",
" for i, batch in enumerate(dataloader_train):\n",
" x, (event, time) = batch\n",
" optimizer.zero_grad()\n",
" log_params = weibull_model(x) # shape = (16, 2)\n",
" loss = neg_log_likelihood(log_params, event, time, reduction=\"mean\")\n",
" loss.backward()\n",
" optimizer.step()\n",
" epoch_loss += loss.detach()\n",
"\n",
" if epoch % (EPOCHS // 10) == 0:\n",
" print(f\"Epoch: {epoch:03}, Training loss: {epoch_loss:0.2f}\")\n",
"\n",
" # Reccord losses for the following figure\n",
" train_losses.append(epoch_loss)\n",
" with torch.no_grad():\n",
" x, (event, time) = next(iter(dataloader_val))\n",
" val_losses.append(\n",
" neg_log_likelihood(weibull_model(x), event, time, reduction=\"mean\")\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "4aba21b6",
"metadata": {},
"source": [
"We can visualize the training and validation losses."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "243a4fa9-f751-46e7-83f3-e623bfd3518e",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_losses(train_losses, val_losses, \"Weibull\")"
]
},
{
"cell_type": "markdown",
"id": "86139132-d337-47b7-a8ad-eac1e255f91d",
"metadata": {},
"source": [
"### Section 2.3: Weibull AFT model evaluation\n",
"\n",
"We evaluate the predictive performance of the model using \n",
"\n",
"* the [C-index](../_autosummary/torchsurv.metrics.cindex.html), which measures the the probability that a model correctly predicts which of two comparable samples will experience an event first based on their estimated risk scores,\n",
"* the [AUC](../_autosummary/torchsurv.metrics.auc.html), which measures the probability that a model correctly predicts which of two comparable samples will experience an event by time t based on their estimated risk scores, and\n",
"* the [Brier score](../_autosummary/torchsurv.metrics.brier_score.html), which measures the model's calibration by calculating the mean square error between the estimated survival function and the empirical (i.e., in-sample) event status."
]
},
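{
"cell_type": "markdown",
"id": "brier-score-sketch",
"metadata": {},
"source": [
"For intuition, the time-dependent Brier score without any correction for censoring can be sketched as follows. This simplified form is an assumption made for illustration and is not a description of the `BrierScore` implementation; estimators used in practice typically add inverse-probability-of-censoring weights. Here $\\hat{S}(t|x_{i})$ denotes the estimated survival probability, $T_{i}$ the observed time and $I(\\cdot)$ the indicator function.\n",
"$$\n",
"BS(t) = \\frac{1}{N} \\sum_{i=1}^{N} \\left( I(T_{i} > t) - \\hat{S}(t|x_{i}) \\right)^{2}\n",
"$$\n",
"Lower values are better; a model that predicts $\\hat{S}(t|x_{i}) = 0.5$ for everyone scores 0.25. The integrated Brier score summarizes $BS(t)$ over a time range, commonly as\n",
"$$\n",
"IBS = \\frac{1}{t_{\\max} - t_{\\min}} \\int_{t_{\\min}}^{t_{\\max}} BS(t) \\, dt\n",
"$$"
]
},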
{
"cell_type": "markdown",
"id": "1cb226f5",
"metadata": {},
"source": [
"We start by obtaining the subject-specific log hazard and survival probability at every time $t$ observed on the test set"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "11599a1f-597b-4ebf-8a15-d3f9db1ebcca",
"metadata": {},
"outputs": [],
"source": [
"weibull_model.eval()\n",
"with torch.no_grad():\n",
" # event and time of length n\n",
" x, (event, time) = next(iter(dataloader_test))\n",
" log_params = weibull_model(x) # shape = (n,2)\n",
"\n",
"# Compute the log hazards from weibull log parameters\n",
"log_hz = log_hazard(log_params, time) # shape = (n,n)\n",
"\n",
"# Compute the survival probability from weibull log parameters\n",
"surv = survival_function(log_params, time) # shape = (n,n)"
]
},
{
"cell_type": "markdown",
"id": "7e309515",
"metadata": {},
"source": [
"We can evaluate the concordance index, its confidence interval and the p-value of the statistical test testing whether the c-index is greater than 0.5:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "afd7e7a5-c909-41eb-a48f-a9c9832eb33b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Weibull model performance:\n",
"Concordance-index = 0.4752066135406494\n",
"Confidence interval = tensor([0.3505, 0.5999])\n",
"p-value = 0.6516208648681641\n"
]
}
],
"source": [
"# Concordance index\n",
"weibull_cindex = ConcordanceIndex()\n",
"print(\"Weibull model performance:\")\n",
"print(f\"Concordance-index = {weibull_cindex(log_hz, event, time)}\")\n",
"print(f\"Confidence interval = {weibull_cindex.confidence_interval()}\")\n",
"\n",
"# H0: cindex = 0.5, Ha: cindex >0.5\n",
"print(f\"p-value = {weibull_cindex.p_value(alternative = 'greater')}\")"
]
},
{
"cell_type": "markdown",
"id": "d985e48c",
"metadata": {},
"source": [
"For time-dependent prediction (e.g., 5-year mortality), the C-index is not a proper measure. Instead, it is recommended to use the AUC. The probability to correctly predicts which of two comparable patients will experience an event by 5-year based on their estimated risk scores is the AUC evaluated at 5-year (1825 days) obtained with"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ca4e6c56",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC 5-yr = tensor([0.5116])\n",
"AUC 5-yr (conf int.) = tensor([0.4587, 0.5645])\n",
"AUC 5-yr (p value) = tensor([0.3335])\n"
]
}
],
"source": [
"new_time = torch.tensor(1825.0)\n",
"\n",
"# subject-specific log hazard at \\5-yr\n",
"log_hz_t = log_hazard(log_params, time=new_time) # shape = (n)\n",
"weibull_auc = Auc()\n",
"\n",
"# auc evaluated at new time = 1825, 5 year\n",
"print(f\"AUC 5-yr = {weibull_auc(log_hz_t, event, time, new_time=new_time)}\")\n",
"print(f\"AUC 5-yr (conf int.) = {weibull_auc.confidence_interval()}\")\n",
"print(f\"AUC 5-yr (p value) = {weibull_auc.p_value(alternative='greater')}\")"
]
},
{
"cell_type": "markdown",
"id": "66b00e9f",
"metadata": {},
"source": [
"Lastly, we can evaluate the time-dependent Brier score and the integrated Brier score"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "b1d99480-b643-4836-acd3-7614fa903543",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Brier score = tensor([0.4384, 0.4445, 0.4472, 0.4525, 0.4509])\n",
"Brier score (conf int.) = tensor([[0.4324, 0.4372, 0.4391, 0.4433, 0.4411],\n",
" [0.4444, 0.4517, 0.4553, 0.4617, 0.4607]])\n",
"Integrated Brier score = 0.24550026655197144\n"
]
}
],
"source": [
"brier_score = BrierScore()\n",
"\n",
"# brier score at first 5 times\n",
"print(f\"Brier score = {brier_score(surv, event, time)[:5]}\")\n",
"print(f\"Brier score (conf int.) = {brier_score.confidence_interval()[:,:5]}\")\n",
"\n",
"# integrated brier score\n",
"print(f\"Integrated Brier score = {brier_score.integral()}\")"
]
},
{
"cell_type": "markdown",
"id": "0ca1d08c",
"metadata": {},
"source": [
"We can test whether the time-dependent Brier score is smaller than what would be expected if the survival model was not providing accurate predictions beyond random chance. We use a bootstrap permutation test and obtain the p-value with:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "9754fdcd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Brier score (p-val) = tensor([0.0190, 0.1500, 0.4940, 0.5490, 0.4680])\n"
]
}
],
"source": [
"# H0: bs = bs0, Ha: bs < bs0; where bs0 is the expected brier score if the survival model was not providing accurate predictions beyond random chance.\n",
"\n",
"# p-value for brier score at first 5 times\n",
"print(f\"Brier score (p-val) = {brier_score.p_value(alternative = 'less')[:5]}\")"
]
},
{
"cell_type": "markdown",
"id": "31f7e7f6-8f07-4f82-8653-8d0d2d1ed84f",
"metadata": {},
"source": [
"## Section 3: Models comparison\n",
"\n",
"We can compare the predictive performance of the Cox proportional hazards model against the Weibull AFT model."
]
},
{
"cell_type": "markdown",
"id": "ed057468-ce75-4d3e-a825-71b55effcec8",
"metadata": {},
"source": [
"### Section 3.1: Concordance index\n",
"The statistical test is formulated as follows, H0: cindex cox = cindex weibull, Ha: cindex cox > cindex weibull"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "ea66963f-2537-4390-bb65-c773275b292b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cox cindex = 0.6610040068626404\n",
"Weibull cindex = 0.4752066135406494\n",
"p-value = 0.018631745129823685\n"
]
}
],
"source": [
"print(f\"Cox cindex = {cox_cindex.cindex}\")\n",
"print(f\"Weibull cindex = {weibull_cindex.cindex}\")\n",
"print(\"p-value = {}\".format(cox_cindex.compare(weibull_cindex)))"
]
},
{
"cell_type": "markdown",
"id": "f478e8df",
"metadata": {},
"source": [
"### Section 3.2: AUC at 5-year\n",
"\n",
"The statistical test is formulated as follows, H0: 5-yr auc cox = 5-yr auc weibull, Ha: 5-yr auc cox > 5-yr auc weibull"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "0c4e1651",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cox 5-yr AUC = tensor([0.6954])\n",
"Weibull 5-yr AUC = tensor([0.5116])\n",
"p-value = tensor([1.5964e-05])\n"
]
}
],
"source": [
"print(f\"Cox 5-yr AUC = {cox_auc.auc}\")\n",
"print(f\"Weibull 5-yr AUC = {weibull_auc.auc}\")\n",
"print(\"p-value = {}\".format(cox_auc.compare(weibull_auc)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torchsurv_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}