
The best DNAm mortality predictor: CpGPTGrimAge3

Table of Contents

  0. Read Quick Setup Tutorial

  1. Setup Environment

  2. Load Data

  3. Load Model and Dependencies

  4. Prepare Data Objects

  5. Compute Protein Proxies

  6. Calculate CpGPTGrimAge3

0. Read Quick Setup Tutorial

Before going through this tutorial, please familiarize yourself with the quick setup tutorial.

1. Setup Environment

CpGPT needs to be installed. The easiest way is with pip:

[ ]:
!pip install CpGPT --quiet

For more detailed installation instructions, see the official CpGPT repository.

Next, we set some important variables that are used throughout the notebook. Pay attention to these, as you may need to adjust them for your specific setup and requirements. We then import the necessary Python packages, a mix of standard data science libraries and CpGPT-specific modules.

[1]:
# Random seed for reproducibility
RANDOM_SEED = 42

# Directory paths
DEPENDENCIES_DIR = "../dependencies"
LLM_DEPENDENCIES_DIR = DEPENDENCIES_DIR + "/human"
DATA_DIR = "../data"
PROCESSED_DIR = "../data/tutorials/processed/predict_mortality"

MODEL_NAME = "proteins" # this is the name of the model checkpoint required for CpGPTGrimAge3
MODEL_CHECKPOINT_PATH = f"../dependencies/model/weights/{MODEL_NAME}.ckpt"
MODEL_CONFIG_PATH = f"../dependencies/model/config/{MODEL_NAME}.yaml"
MODEL_VOCAB_PATH = f"../dependencies/model/vocab/{MODEL_NAME}.json"

BETAS_PATH = "../data/cpgcorpus/raw/GSE237561/GPL13534/betas/QCDPB.arrow"
FILTERED_BETAS_PATH = "../data/cpgcorpus/raw/GSE237561/GPL13534/betas/QCDPB_filtered.arrow"
METADATA_PATH = "../data/cpgcorpus/raw/GSE237561/GPL13534/metadata/metadata.arrow"

# The maximum context length to give to the model
MAX_INPUT_LENGTH = 10_000  # you may want to go higher if your hardware permits

⚠️ Warning

A GPU is recommended for inference, as running on a CPU can be slow.

Reconstructing the methylome for a few hundred samples might take up to one hour on a CPU. ⌛

This might be a great exercise in testing your patience.
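If you are unsure which accelerator your machine exposes, a small check along these lines can tell you before you start. This is a minimal sketch: pick_device is a hypothetical helper (not part of CpGPT), it works whether or not PyTorch is installed, and it falls back to the CPU when no accelerator is found.

```python
def pick_device() -> str:
    """Return "cuda", "mps", or "cpu" depending on what PyTorch can see.

    Hypothetical helper for illustration only; the CpGPT classes log their
    own device choice when they are initialized.
    """
    try:
        import torch
    except ImportError:
        # Without PyTorch there is no accelerator to query
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is present on recent PyTorch builds (Apple silicon)
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


print(f"Suggested device: {pick_device()}")
```

If this prints "cpu", expect the longer runtimes mentioned in the warning above.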

1.1 Import Packages

[ ]:
# Standard library imports
import warnings
import os
import json

warnings.simplefilter(action="ignore", category=FutureWarning)

# Third-party imports
import gdown
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyaging as pya
import seaborn as sns
from tqdm.rich import tqdm

# Lightning imports
from lightning.pytorch import seed_everything

# cpgpt-specific imports
from cpgpt.data.components.cpgpt_datasaver import CpGPTDataSaver
from cpgpt.data.cpgpt_datamodule import CpGPTDataModule
from cpgpt.trainer.cpgpt_trainer import CpGPTTrainer
from cpgpt.data.components.dna_llm_embedder import DNALLMEmbedder
from cpgpt.data.components.illumina_methylation_prober import IlluminaMethylationProber
from cpgpt.infer.cpgpt_inferencer import CpGPTInferencer
from cpgpt.model.cpgpt_module import m_to_beta

# Set random seed for reproducibility
seed_everything(RANDOM_SEED, workers=True)
try:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
except Exception:
    pass
WARNING:torchao.kernel.intmm:Warning: Detected no triton, on systems without Triton certain kernels will not work
Seed set to 42
42

2. Load Data

If you have your own data, feel free to skip the following step, but make sure your data is saved in .arrow format. Here, as an example target dataset, we'll use GSE237561, which contains methylation profiles of 126 peripheral whole blood samples collected from 26 individuals across two independent cohorts. These samples were collected at three timepoints: before clozapine initiation, 4-12 weeks after initiation, and 6 months after initiation.
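If your betas are stored the other way around (CpG sites as rows and samples as columns, as many array pipelines export them), a minimal pandas sketch like the one below reshapes them into the layout this tutorial reads back: one row per sample, a GSM_ID column holding the sample identifiers (the code below calls set_index("GSM_ID") after read_feather), and one column per CpG site. The helper name to_cpgpt_layout and the toy values are hypothetical, and writing the result with DataFrame.to_feather requires pyarrow.

```python
import pandas as pd


def to_cpgpt_layout(betas_cpg_by_sample: pd.DataFrame) -> pd.DataFrame:
    """Transpose a CpG x sample beta matrix into a sample x CpG frame.

    The sample identifiers are moved into a regular "GSM_ID" column,
    matching how this tutorial's .arrow files are structured.
    """
    samples_by_cpg = betas_cpg_by_sample.T.rename_axis("GSM_ID")
    return samples_by_cpg.reset_index()


# Toy example with made-up values: 3 CpG sites x 2 samples
toy = pd.DataFrame(
    {"SAMPLE_A": [0.59, 0.96, 0.90], "SAMPLE_B": [0.66, 0.96, 0.92]},
    index=["cg00000029", "cg00000108", "cg00000109"],
)
layout = to_cpgpt_layout(toy)
print(layout)
# layout.to_feather("my_betas.arrow")  # requires pyarrow
```

The saved file can then be pointed to by BETAS_PATH in place of the GSE237561 example.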

[3]:
# First let's declare the inferencer
inferencer = CpGPTInferencer(dependencies_dir=DEPENDENCIES_DIR, data_dir=DATA_DIR)

inferencer.download_cpgcorpus_dataset("GSE237561")
cpgpt -CpGPTInferencer: Initializing class CpGPTInferencer.
cpgpt -CpGPTInferencer: Using device: cpu.
cpgpt -CpGPTInferencer: Using dependencies directory: ../dependencies
cpgpt -CpGPTInferencer: Using data directory: ../data
cpgpt -CpGPTInferencer: There are 19 CpGPT models available such as age, age_cot, average_adultweight, etc.
cpgpt -CpGPTInferencer: There are 2088 GSE datasets available such as GSE100184, GSE100208, GSE100209, etc.
cpgpt -CpGPTInferencer: Dataset GSE237561 already exists at ../data/cpgcorpus/raw/GSE237561 (skipping download).
[4]:
# Load betas matrix
df = pd.read_feather(BETAS_PATH)
df.set_index("GSM_ID", inplace=True)

df.head()
[4]:
cg00000029 cg00000108 cg00000109 cg00000165 cg00000236 cg00000289 cg00000292 cg00000321 cg00000363 cg00000622 ... rs7746156 rs798149 rs845016 rs877309 rs9292570 rs9363764 rs939290 rs951295 rs966367 rs9839873
GSM_ID
GSM7625568 0.592157 0.964411 0.899373 NaN 0.861353 NaN 0.894893 0.280911 0.386535 0.017273 ... 0.974810 0.022425 0.086804 0.031910 0.535025 0.548239 0.064924 0.536217 0.091576 0.787829
GSM7625569 0.657346 0.962779 0.920897 0.170290 0.868804 NaN 0.945775 0.303094 0.409573 0.015473 ... 0.977200 0.023908 0.093790 0.024449 0.524899 0.554774 0.048559 0.521857 0.081250 0.772327
GSM7625570 0.662022 0.964065 0.903984 0.180436 0.867933 NaN 0.915102 0.241706 0.392485 0.015086 ... 0.979132 0.022763 0.091286 0.028332 0.518550 0.584879 0.054734 0.529968 0.101047 0.765933
GSM7625571 0.599778 0.961087 0.903260 NaN 0.845338 NaN 0.910445 0.277753 0.405914 0.016514 ... 0.978393 0.019485 0.069463 0.024282 0.508195 0.569030 0.052701 0.508199 0.083252 0.787774
GSM7625572 0.556610 0.960655 0.893885 NaN 0.846172 NaN 0.916346 0.285945 0.404618 0.014193 ... 0.978121 0.019678 0.541755 0.982736 0.528156 0.524965 0.965231 0.041688 0.672611 0.924847

5 rows × 485578 columns

[5]:
# Load metadata
metadata = pd.read_feather(METADATA_PATH)
metadata.set_index("GSM_ID", inplace=True)

metadata.head()
[5]:
title geo_accession status submission_date last_update_date type channel_count source_name_ch1 organism_ch1 characteristics_ch1 ... cd8t:ch1 days.on.clozapine:ch1 gran:ch1 institute:ch1 mono:ch1 nk:ch1 participant_id:ch1 Sex:ch1 smokingscore:ch1 visit:ch1
GSM_ID
GSM7625568 genomic DNA from ID 0003 for 'a' visit GSM7625568 Public on Jul 17 2024 Jul 17 2023 Jul 17 2024 genomic 1 peripheral whole blood Homo sapiens participant_id: 0003 ... 0.124088032132122 0 0.623747063903508 KCL 0.0577366406423333 0.0181886611163226 0003 M 0.744140048408035 a
GSM7625569 genomic DNA from ID 0003 for 'b' visit GSM7625569 Public on Jul 17 2024 Jul 17 2023 Jul 17 2024 genomic 1 peripheral whole blood Homo sapiens participant_id: 0003 ... 0.140642508939063 42 0.532222309707589 KCL 0.0794055065754302 0.0165444940880789 0003 M 0.778521295727892 b
GSM7625570 genomic DNA from ID 0003 for 'd' visit GSM7625570 Public on Jul 17 2024 Jul 17 2023 Jul 17 2024 genomic 1 peripheral whole blood Homo sapiens participant_id: 0003 ... 0.112389247455345 84 0.610861991010505 KCL 0.0694956639909667 0.00363827604122652 0003 M 0.591346224352136 d
GSM7625571 genomic DNA from ID 0003 for 'e' visit GSM7625571 Public on Jul 17 2024 Jul 17 2023 Jul 17 2024 genomic 1 peripheral whole blood Homo sapiens participant_id: 0003 ... 0.0927185883637476 168 0.647578326303527 KCL 0.0886719923597006 0.0163329200193279 0003 M -0.771644717383253 e
GSM7625572 genomic DNA from ID 0005 for 'a' visit GSM7625572 Public on Jul 17 2024 Jul 17 2023 Jul 17 2024 genomic 1 peripheral whole blood Homo sapiens participant_id: 0005 ... 0.109718862397854 0 0.505235489278915 KCL 0.0514219148451151 0.0610566697720137 0005 M 1.32031205622569 a

5 rows × 57 columns

3. Load Model and Dependencies

To calculate CpGPTGrimAge3, we first need a set of DNA methylation proxies for plasma protein levels, which are predicted by a fine-tuned model. The checkpoint is called "proteins"; it predicts 322 plasma protein levels, each normalized to mean 0 and variance 1 (μ = 0, σ² = 1).
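As a reminder of what this scale means, the sketch below standardizes a list of values to mean 0 and variance 1 (z-scores). It only illustrates the transformation with made-up numbers; the reference statistics used to standardize the protein levels for the "proteins" model are not given in this tutorial. On such a scale, a predicted value of 0.3 sits 0.3 standard deviations above the reference mean.

```python
from statistics import fmean, pstdev


def standardize(values: list[float]) -> list[float]:
    """Return z-scores: subtract the mean, then divide by the population standard deviation."""
    mu = fmean(values)
    sigma = pstdev(values, mu)
    return [(v - mu) / sigma for v in values]


# Made-up protein levels in arbitrary units
levels = [120.0, 95.0, 150.0, 110.0, 125.0]
z = standardize(levels)
print([round(v, 3) for v in z])
print(round(fmean(z), 6), round(pstdev(z), 6))  # approximately 0 and 1
```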

3.1 Download Checkpoint and Configuration Files

[6]:
# Download the checkpoint and configuration files
inferencer.download_model(MODEL_NAME)
cpgpt -CpGPTInferencer: Model checkpoint already exists at ../dependencies/model/weights/proteins.ckpt (skipping download).
cpgpt -CpGPTInferencer: Model config already exists at ../dependencies/model/config/proteins.yaml (skipping download).
cpgpt -CpGPTInferencer: Model vocabulary already exists at ../dependencies/model/vocab/proteins.json (skipping download).
cpgpt -CpGPTInferencer: Successfully downloaded model 'proteins'.

3.2 Load Model

[7]:
# Load the model configuration
config = inferencer.load_cpgpt_config(MODEL_CONFIG_PATH)

# Load the model weights
model = inferencer.load_cpgpt_model(
    config,
    model_ckpt_path=MODEL_CHECKPOINT_PATH,
    strict_load=True,
)
cpgpt -CpGPTInferencer: Loaded CpGPT model config.
cpgpt -CpGPTInferencer: Instantiated CpGPT model from config.
cpgpt -CpGPTInferencer: Using device: cpu.
cpgpt -CpGPTInferencer: Loading checkpoint from: ../dependencies/model/weights/proteins.ckpt
cpgpt -CpGPTInferencer: Checkpoint loaded into the model.

3.3 Load Vocab

The proteins model was trained on a vocabulary of about 4,689 CpG sites. Ideally, the input data should be filtered to include only those features (or a subset of them).
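Before filtering, it can be useful to check how much of the model vocabulary your dataset actually covers. A minimal sketch, assuming the vocabulary is a list of probe IDs (like vocab['input'] loaded in the next cells) and the betas are a samples x probes DataFrame; vocab_coverage is a hypothetical helper, and unlike a set intersection it keeps the vocabulary's order, so its output is the same on every run.

```python
import pandas as pd


def vocab_coverage(
    model_features: list[str], betas: pd.DataFrame
) -> tuple[list[str], list[str], float]:
    """Split the model vocabulary into probes present in / missing from `betas`.

    Returns (present, missing, fraction_present); both lists keep vocabulary order.
    """
    available = set(betas.columns)
    present = [f for f in model_features if f in available]
    missing = [f for f in model_features if f not in available]
    fraction = len(present) / len(model_features) if model_features else 0.0
    return present, missing, fraction


# Toy example with made-up values
toy_vocab = ["cg21830050", "cg10381813", "cg08067365", "cg09864227"]
toy_betas = pd.DataFrame(
    {"cg21830050": [0.90], "cg08067365": [0.11], "cg00000029": [0.59]},
    index=pd.Index(["SAMPLE_1"], name="GSM_ID"),
)
present, missing, fraction = vocab_coverage(toy_vocab, toy_betas)
print(present, missing, f"{fraction:.0%}")
```

With the tutorial's objects, the equivalent call would be vocab_coverage(model_input_features, df) once the vocabulary has been loaded; a low fraction suggests the array platform shares few probes with the model's training data.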

[8]:
# Load the vocabulary
with open(MODEL_VOCAB_PATH, "r") as f:
    vocab = json.load(f)
[9]:
model_input_features = vocab['input']

model_input_features[:5]
[9]:
['cg21830050', 'cg10381813', 'cg08067365', 'cg09864227', 'cg07213830']
[10]:
model_output_features = vocab['output']

model_output_features[:5]
[10]:
['cpgpt_tnfsf13', 'cpgpt_il33', 'cpgpt_calca', 'cpgpt_npy', 'cpgpt_hla-dra']

3.4 Download Dependencies

[ ]:
inferencer.download_dependencies(species="human")

4. Prepare Data Objects

4.1 Declare Embedder and Prober

To run the model on our samples, we first need to memory-map the data. This is done with the CpGPTDataSaver class. Before that, we define the DNALLMEmbedder and IlluminaMethylationProber classes, which hold, respectively, the information about the DNA LLM embeddings and the mapping from Illumina array probes to genomic locations.
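Memory-mapping means the array lives in a file on disk and only the parts that are accessed get paged into RAM, so a dataset larger than memory can still be read one sample at a time. The generic numpy.memmap sketch below illustrates the idea with a toy matrix; it is not CpGPTDataSaver's actual on-disk format, which this tutorial does not describe.

```python
import os
import tempfile

import numpy as np

# Toy samples x CpGs beta matrix (made-up values)
betas = np.array([[0.59, 0.96, 0.90], [0.66, 0.96, 0.92]], dtype=np.float32)

path = os.path.join(tempfile.mkdtemp(), "betas.mmap")

# Write the matrix to disk through a writable memory map
writer = np.memmap(path, dtype=np.float32, mode="w+", shape=betas.shape)
writer[:] = betas
writer.flush()
del writer  # release the writable map

# Re-open read-only: rows are paged in from disk only when accessed
reader = np.memmap(path, dtype=np.float32, mode="r", shape=betas.shape)
first_sample = np.array(reader[0])  # copy one row into regular memory
print(first_sample)
```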

[11]:
embedder = DNALLMEmbedder(dependencies_dir=LLM_DEPENDENCIES_DIR)
cpgpt -DNALLMEmbedder: Initializing class DNALLMEmbedder.
cpgpt -DNALLMEmbedder: Genome files will be stored under ../dependencies/human/genomes.
cpgpt -DNALLMEmbedder: DNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.
cpgpt -DNALLMEmbedder: Ensembl metadata dictionary loaded successfully
[12]:
prober = IlluminaMethylationProber(dependencies_dir=LLM_DEPENDENCIES_DIR, embedder=embedder)
cpgpt -IlluminaMethylationProber: Initializing class IlluminaMethylationProber.
cpgpt -IlluminaMethylationProber: Illumina methylation manifest files will be stored under ../dependencies/human/manifests.
cpgpt -IlluminaMethylationProber: Illumina metadata dictionary loaded successfully.

4.2 Filter Vocab

[13]:
common_features = list(set(model_input_features) & set(df.columns))
df_filtered = df.loc[:, common_features]
df_filtered.to_feather(FILTERED_BETAS_PATH)

df_filtered.head()
[13]:
cg26259818 cg26866325 cg23263937 cg15779600 cg16075139 cg04820362 cg11782409 cg07870237 cg01802397 cg04634427 ... cg01431830 cg18241647 cg14009688 cg26866482 cg23191950 cg03998338 cg06532212 cg23268677 cg20405584 cg14984434
GSM_ID
GSM7625568 0.900749 0.199913 0.107259 NaN NaN 0.819918 0.903952 0.252276 0.049561 0.781003 ... NaN 0.940283 0.644005 0.047353 0.251154 0.916748 0.488639 0.063751 0.027925 0.946631
GSM7625569 0.884256 0.148837 0.118698 NaN NaN 0.877955 0.885812 0.286435 0.055951 0.736555 ... NaN 0.932127 0.624015 0.044781 0.184450 0.921210 0.510761 0.057729 0.028602 0.928797
GSM7625570 0.869737 0.174457 0.083522 NaN NaN 0.885510 0.885832 0.298032 0.058532 0.709248 ... NaN 0.932670 0.684649 0.044844 0.177397 0.916927 0.503245 0.073991 0.031689 0.933811
GSM7625571 0.902602 0.196606 0.081571 NaN NaN 0.879783 0.891532 0.255523 0.056742 0.780125 ... NaN 0.950026 0.616295 0.047359 0.171478 0.940402 0.496729 0.073860 0.035205 0.942940
GSM7625572 0.899257 0.153354 0.107341 NaN NaN 0.829583 0.864145 0.282356 0.060830 0.691230 ... NaN 0.958417 0.649126 0.043928 0.232675 0.752656 0.506659 0.065188 0.047482 0.937227

5 rows × 4689 columns
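The filtered matrix has 4,689 columns, comfortably below MAX_INPUT_LENGTH (10,000), the maximum context length given to the model. If you use more input features, or lower MAX_INPUT_LENGTH because of hardware limits, a quick check like the sketch below shows which samples carry more measured sites than the context can hold. fits_in_context is a hypothetical helper, and counting only non-missing values assumes that NaN entries do not occupy context; that is an assumption, not something this tutorial documents.

```python
import pandas as pd


def fits_in_context(betas: pd.DataFrame, max_length: int) -> pd.Series:
    """Per sample, True if the number of non-missing CpG values is <= max_length."""
    measured = betas.notna().sum(axis=1)
    return measured <= max_length


# Toy example with made-up values and a deliberately tiny context
toy = pd.DataFrame(
    {"cg1": [0.9, 0.8], "cg2": [0.2, None], "cg3": [0.1, None]},
    index=pd.Index(["SAMPLE_1", "SAMPLE_2"], name="GSM_ID"),
)
print(fits_in_context(toy, max_length=2))
```

On the tutorial data, fits_in_context(df_filtered, MAX_INPUT_LENGTH).all() would be the corresponding check.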

4.3 Memory-Map Data

[14]:
# Define datasaver
datasaver = CpGPTDataSaver(data_paths=FILTERED_BETAS_PATH, processed_dir=PROCESSED_DIR)

# Process the file
datasaver.process_files(prober, embedder)
cpgpt -CpGPTDataSaver: Initializing class CpGPTDataSaver.
cpgpt -CpGPTDataSaver: Dataset folders will be stored under ../data/tutorials/processed/predict_mortality.
cpgpt -CpGPTDataSaver: Loaded existing dataset metrics.
cpgpt -CpGPTDataSaver: Loaded existing genomic locations.
cpgpt -CpGPTDataSaver: Starting file processing.
cpgpt -CpGPTDataSaver: 1 files already processed. Skipping those.

4.4 Declare Data Module

Let's define a data module to use with our model:

[15]:
# Define datamodule
datamodule = CpGPTDataModule(
    predict_dir=PROCESSED_DIR,
    dependencies_dir=LLM_DEPENDENCIES_DIR,
    batch_size=1,
    num_workers=0,
    max_length=MAX_INPUT_LENGTH,
    dna_llm=config.data.dna_llm,
    dna_context_len=config.data.dna_context_len,
    sorting_strategy=config.data.sorting_strategy,
    pin_memory=False,
)
cpgpt -DNALLMEmbedder: Initializing class DNALLMEmbedder.
cpgpt -DNALLMEmbedder: Genome files will be stored under ../dependencies/human/genomes.
cpgpt -DNALLMEmbedder: DNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.
cpgpt -DNALLMEmbedder: Ensembl metadata dictionary loaded successfully

5. Compute Protein Proxies

5.1 Declare Trainer

Because all models were trained with mixed precision, we'll use the precision="16-mixed" argument.

[16]:
trainer = CpGPTTrainer(precision="16-mixed")
Using 16bit Automatic Mixed Precision (AMP)
/Users/lucascamillo/mambaforge/envs/cpgpt/lib/python3.10/site-packages/torch/amp/grad_scaler.py:132: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.
  warnings.warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

5.2 Get Predictions

[17]:
# Run a forward pass to get the predicted protein levels for the target samples
forward_pass = trainer.predict(
    model=model,
    datamodule=datamodule,
    predict_mode="forward",
    return_keys=["pred_conditions"]
)

pred_conditions_df = pd.DataFrame(forward_pass['pred_conditions'], index=df.index, columns=model_output_features)
pred_conditions_df.head()
cpgpt -CpGPTDataset: Initializing class CpGPTDataset.
cpgpt -CpGPTDataset: Loaded existing dataset metrics.
/Users/lucascamillo/mambaforge/envs/cpgpt/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/lucascamillo/mambaforge/envs/cpgpt/lib/python3.10/site-packages/torch/amp/autocast_mode.py:265: UserWarning:
User provided device_type of 'cuda', but CUDA is not available. Disabling
  warnings.warn(
[17]:
cpgpt_tnfsf13 cpgpt_il33 cpgpt_calca cpgpt_npy cpgpt_hla-dra cpgpt_c1qa cpgpt_fth1 cpgpt_s100b cpgpt_ceacam5 cpgpt_mme ... cpgpt_ccl15 cpgpt_ccl14 cpgpt_ccl13 cpgpt_ccl11 cpgpt_ccl1 cpgpt_saa1 cpgpt_s100a9 cpgpt_s100a12 cpgpt_bdnf cpgpt_vgf
GSM_ID
GSM7625568 0.159939 -0.109247 -0.123497 0.268346 0.311079 -0.104481 0.307234 0.013596 -0.029327 0.266559 ... -0.100315 0.267266 0.010796 0.257249 0.012845 -0.144005 -0.199204 -0.288781 0.283817 -0.094640
GSM7625569 0.047051 -0.255467 -0.281614 0.243430 0.189521 -0.217191 0.198763 0.019737 -0.110035 0.173720 ... -0.218816 0.086749 -0.205998 0.042463 -0.109191 -0.249524 -0.164367 -0.338399 0.262606 -0.106812
GSM7625570 0.050562 -0.225324 -0.236831 0.222035 0.186863 -0.197240 0.215832 -0.016688 -0.087977 0.180057 ... -0.197247 0.127577 -0.130641 0.094274 -0.094252 -0.226423 -0.189340 -0.316972 0.234579 -0.125441
GSM7625571 0.194388 -0.103120 -0.060972 0.246503 0.228984 -0.092588 0.311946 -0.041283 -0.011208 0.265450 ... -0.103527 0.264676 0.018594 0.178806 0.027473 -0.150791 -0.140418 -0.208958 0.266402 -0.067963
GSM7625572 0.017105 -0.295679 -0.310758 0.232463 0.167696 -0.245472 0.177904 0.006786 -0.177094 0.163067 ... -0.246631 0.001140 -0.265735 -0.007235 -0.135532 -0.271881 -0.144465 -0.389837 0.226135 -0.125433

5 rows × 322 columns

6. Calculate CpGPTGrimAge3

In the final step, we join all the features needed to calculate CpGPTGrimAge3, namely:

  • age: chronological age of the sample;

  • GrimAge2 proxies: protein and lifestyle proxies from GrimAge version 2;

  • CpGPT protein proxies: protein levels predicted with CpGPT.

6.1 Join All Features

[ ]:
# Get chronological age from the metadata
age = metadata.loc[:, ['age:ch1']].astype(float)
age.columns = ['age']

# Add age to the filtered betas so it is available to the GrimAge2 proxy models
df_filtered['age'] = age['age']

# Get GrimAge2 proxies
grimage2_proxies = [
    "grimage2timp1",
    "grimage2packyrs",
    "grimage2logcrp",
    "grimage2b2m",
    "grimage2adm",
    "grimage2leptin",
    "grimage2gdf15",
    "grimage2pai1",
]
adata_grimage2 = pya.pp.df_to_adata(df_filtered, verbose=False)
pya.pred.predict_age(adata_grimage2, clock_names=grimage2_proxies, verbose=False)

# Combine all features
combined_df = pd.concat([age, adata_grimage2.obs, pred_conditions_df], axis=1)

combined_df.head()
age grimage2timp1 grimage2packyrs grimage2logcrp grimage2b2m grimage2adm grimage2leptin grimage2gdf15 cpgpt_tnfsf13 cpgpt_il33 ... cpgpt_ccl15 cpgpt_ccl14 cpgpt_ccl13 cpgpt_ccl11 cpgpt_ccl1 cpgpt_saa1 cpgpt_s100a9 cpgpt_s100a12 cpgpt_bdnf cpgpt_vgf
GSM_ID
GSM7625568 27.713889 33443.488165 4.889224 -0.207662 1.657720e+06 309.088007 6452.711603 525.537382 0.159939 -0.109247 ... -0.100315 0.267266 0.010796 0.257249 0.012845 -0.144005 -0.199204 -0.288781 0.283817 -0.094640
GSM7625569 27.713889 33603.182758 -0.675848 0.151504 1.679469e+06 319.938681 5657.270706 550.893628 0.047051 -0.255467 ... -0.218816 0.086749 -0.205998 0.042463 -0.109191 -0.249524 -0.164367 -0.338399 0.262606 -0.106812
GSM7625570 27.713889 33807.609589 3.680232 -0.113398 1.682385e+06 312.595083 6654.595126 545.927825 0.050562 -0.225324 ... -0.197247 0.127577 -0.130641 0.094274 -0.094252 -0.226423 -0.189340 -0.316972 0.234579 -0.125441
GSM7625571 27.713889 33880.497852 3.483722 0.168095 1.680143e+06 317.126936 6792.263369 531.686439 0.194388 -0.103120 ... -0.103527 0.264676 0.018594 0.178806 0.027473 -0.150791 -0.140418 -0.208958 0.266402 -0.067963
GSM7625572 23.183333 33192.500811 4.700933 -0.575709 1.590738e+06 317.324722 4798.638056 522.853005 0.017105 -0.295679 ... -0.246631 0.001140 -0.265735 -0.007235 -0.135532 -0.271881 -0.144465 -0.389837 0.226135 -0.125433

5 rows × 330 columns
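pd.concat(..., axis=1) aligns the frames on their index (GSM_ID here) using an outer join by default, so a sample missing from one frame would silently produce a row of NaNs instead of an error. A minimal sanity check is to confirm that all frames share the same sample IDs before joining; concat_aligned is a hypothetical helper name, and the toy values are made up.

```python
import pandas as pd


def concat_aligned(*frames: pd.DataFrame) -> pd.DataFrame:
    """Column-wise concatenation that raises if the frames disagree on sample IDs."""
    reference = set(frames[0].index)
    for i, frame in enumerate(frames[1:], start=1):
        if set(frame.index) != reference:
            raise ValueError(f"frame {i} has different sample IDs than frame 0")
    return pd.concat(frames, axis=1)


idx = pd.Index(["SAMPLE_1", "SAMPLE_2"], name="GSM_ID")
age_toy = pd.DataFrame({"age": [30.0, 45.0]}, index=idx)
proxies_toy = pd.DataFrame({"cpgpt_il33": [-0.1, 0.2]}, index=idx)

combined = concat_aligned(age_toy, proxies_toy)
print(combined)

try:
    concat_aligned(age_toy, proxies_toy.iloc[:1])  # one sample dropped
except ValueError as err:
    print("Caught:", err)
```

In the cell above, concat_aligned(age, adata_grimage2.obs, pred_conditions_df) would be the drop-in equivalent.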

6.2 Calculate CpGPTGrimAge3

[19]:
adata = pya.pp.df_to_adata(combined_df, verbose=False)
pya.pred.predict_age(adata, clock_names=["cpgptgrimage3"], verbose=False)

adata.obs.head()
[19]:
cpgptgrimage3
GSM_ID
GSM7625568 35.866626
GSM7625569 34.805627
GSM7625570 34.625747
GSM7625571 37.634215
GSM7625572 30.980501