import marshal
import math
import ntpath
import os
import types
from typing import Dict, List, Tuple
from urllib.request import urlretrieve
import anndata
import numpy as np
import pandas as pd
import torch
from anndata.experimental.pytorch import AnnLoader
from torch.utils.data import DataLoader, TensorDataset
# CuPy is an optional dependency: use it only when it imports cleanly and
# reports a usable CUDA runtime.
try:
    import cupy as cp

    CUPY_AVAILABLE = cp.cuda.is_available()
except Exception:
    # Any failure (missing package, broken CUDA installation, ...) falls back
    # to NumPy. `Exception` rather than a bare `except` so that
    # KeyboardInterrupt / SystemExit raised during import still propagate.
    CUPY_AVAILABLE = False
import gc
from ..logger import LoggerManager, main_tqdm, silence_logger
from ..models import *
from ..utils import download, load_clock_metadata, progress
from ._postprocessing import *
from ._preprocessing import *
[docs]
@progress("Load clock")
def load_clock(clock_name: str, device: str, dir: str, logger, indent_level: int = 2) -> pyagingModel:
    """
    Download the specified aging clock and return it ready for inference.

    The clock name is lower-cased, the matching ``.pt`` file is fetched with the
    ``download`` utility into ``dir``, and the stored model object is loaded with
    ``torch.load``. The model is then cast to ``float64``, moved to ``device`` and
    switched to evaluation mode.

    Parameters
    ----------
    clock_name : str
        The name of the aging clock to be loaded (case-insensitive). It is used to
        construct the download URL and the local file name ``<clock_name>.pt``.
    device : str
        Device to move the clock to. Either 'cpu' or 'cuda'.
    dir : str
        The directory to deposit the downloaded file.
    logger : Logger
        A logger object used for logging information during the function execution.
    indent_level : int, optional
        The indentation level for the logger, by default 2. It controls the formatting
        of the log messages.

    Returns
    -------
    pyagingModel
        The loaded clock model, in ``float64``, on ``device`` and in eval mode.

    Raises
    ------
    NameError
        If the download fails. The original exception is attached as ``__cause__``.

    Notes
    -----
    The clock's weights and configuration are assumed to be stored in a .pt (PyTorch)
    file on a remote server. The URL for the clock is constructed from the clock's name.

    Examples
    --------
    >>> clock = load_clock("clock1", "cpu", "pyaging_data", logger)
    """
    clock_name = clock_name.lower()
    url = f"https://pyaging.s3.amazonaws.com/clocks/weights0.1.0/{clock_name}.pt"
    try:
        download(url, dir, logger, indent_level=indent_level)
    except Exception as err:
        # `Exception` (not a bare except) so Ctrl-C during a download is not
        # misreported as an unknown clock name.
        logger.error(
            f"Clock {clock_name} is not available on pyaging. "
            f"Please refer to the clock names in the clock glossary table "
            f"in the package documentation page: pyaging.readthedocs.io",
            indent_level=indent_level + 1,
        )
        raise NameError(f"Clock {clock_name} is not available on pyaging.") from err

    # Define the path to the clock weights file
    weights_path = os.path.join(dir, f"{clock_name}.pt")

    # Load the clock from the file.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects; this is
    # only acceptable because the file comes from the project's own bucket. Do not
    # point this at untrusted .pt files.
    clock = torch.load(weights_path, weights_only=False)

    # Prepare clock for inference
    clock.to(torch.float64)
    clock.to(device)
    clock.eval()
    return clock
[docs]
@progress("Check features in adata")
def check_features_in_adata(
    adata: anndata.AnnData,
    model: pyagingModel,
    logger,
    indent_level: int = 2,
) -> anndata.AnnData:
    """
    Build the clock's input matrix from an AnnData object, filling in missing features.

    A matrix with one column per entry of ``model.features`` is stored in
    ``adata.obsm["X_<clock_name>"]``. Columns whose feature is found in
    ``adata.var_names`` are copied from ``adata.X``; the remaining columns are filled
    with the matching entry of ``model.reference_values`` when that is not ``None``,
    and with 0 otherwise. ``adata.X`` and ``adata.var_names`` are not modified.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object to be checked. It is updated in place.
    model : pyagingModel
        The pyagingModel of the aging clock of interest. Must define ``features``,
        ``reference_values`` and ``metadata['clock_name']``.
    logger : Logger
        A logger object used for logging information about the process, such as the
        number of missing features.
    indent_level : int, optional
        The indentation level for the logger, by default 2. It controls the formatting
        of the log messages.

    Returns
    -------
    anndata.AnnData
        The same ``adata`` object that was passed in, now holding
        ``obsm["X_<clock_name>"]``, ``uns["<clock_name>_percent_na"]`` and
        ``uns["<clock_name>_missing_features"]``.

    Raises
    ------
    NameError
        If none of the model's features is present in ``adata.var_names``.

    Notes
    -----
    Filling missing features with zeros (or reference values) provides completeness
    but may introduce biases if not accounted for in downstream analyses.

    Examples
    --------
    >>> updated_adata = check_features_in_adata(adata, model, logger)
    >>> updated_adata.obsm[f"X_{model.metadata['clock_name']}"].shape
    (n_obs, n_model_features)
    """
    clock_name = model.metadata["clock_name"]
    obsm_key = f"X_{clock_name}"
    num_features = len(model.features)

    # Preallocate the data matrix (on the GPU when CuPy is usable)
    adata.obsm[obsm_key] = (
        cp.empty((adata.n_obs, num_features))
        if CUPY_AVAILABLE
        else np.empty((adata.n_obs, num_features), order="F")
    )

    # Column position in adata for every model feature; -1 marks "not present"
    feature_indices = {feature: i for i, feature in enumerate(adata.var_names)}
    model_feature_indices = np.array([feature_indices.get(feature, -1) for feature in model.features])

    # Identify missing features
    missing_features_mask = model_feature_indices == -1
    missing_features = np.array(model.features)[missing_features_mask].tolist()

    # Assign values for existing features
    existing_features_mask = ~missing_features_mask
    existing_features_indices = model_feature_indices[existing_features_mask]
    adata.obsm[obsm_key][:, existing_features_mask] = adata.X[:, existing_features_indices]

    # Handle missing features: one reference value per column, broadcast over rows
    adata.obsm[obsm_key][:, missing_features_mask] = (
        np.array(model.reference_values)[missing_features_mask] if model.reference_values is not None else 0
    )

    # Calculate missing features statistics
    num_missing_features = len(missing_features)
    percent_missing = 100 * num_missing_features / num_features

    # Record missing features and percent missing values for this clock
    adata.uns[f"{clock_name}_percent_na"] = percent_missing
    adata.uns[f"{clock_name}_missing_features"] = missing_features

    # Raises error if there are no features in the data
    if percent_missing == 100:
        logger.error(
            f"Every single feature out of {num_features} features "
            f"is missing. Please double check the features in the adata object"
            f" actually contain the clock features such as {missing_features[:3]}, etc.",
            indent_level=indent_level + 1,
        )
        raise NameError

    # Log missing features if any
    if num_missing_features > 0:
        logger.warning(
            f"{num_missing_features} out of {num_features} features "
            f"({percent_missing:.2f}%) are missing: {missing_features[:3]}, etc.",
            indent_level=indent_level + 1,
        )
        # If there are reference values provided
        if model.reference_values is not None:
            logger.info(
                f"Using reference feature values for {clock_name}",
                indent_level=indent_level + 1,
            )
        else:
            logger.info(
                "Filling missing features entirely with 0",
                indent_level=indent_level + 1,
            )
    else:
        logger.info(
            "All features are present in adata.var_names.",
            indent_level=indent_level + 1,
        )

    return adata
[docs]
@progress("Predict ages with model")
def predict_ages_with_model(
    adata: anndata.AnnData,
    model: pyagingModel,
    device: str,
    batch_size: int,
    logger,
    indent_level: int = 2,
) -> torch.Tensor:
    """
    Predict ages in batches using a loaded clock model.

    The input matrix is read from ``adata.obsm["X_<clock_name>"]`` (where
    ``<clock_name>`` is ``model.metadata['clock_name']``), fed to the model batch by
    batch through an ``AnnLoader``, and the per-batch outputs are concatenated. A
    dataloader is used because of possible memory constraints for large datasets.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object containing the dataset. It must already hold the
        ``obsm["X_<clock_name>"]`` matrix for this model.
    model : pyagingModel
        The pyagingModel of the aging clock of interest.
    device : str
        Device on which inference runs. Either 'cpu' or 'cuda'. Every batch is moved
        to this device, so it should be the device the model lives on.
    batch_size : int
        Batch size for the AnnLoader object to predict age.
    logger : Logger
        A logger object for logging the progress or any relevant information during
        the prediction process.
    indent_level : int, optional
        The indentation level for logging messages, by default 2.

    Returns
    -------
    predictions : torch.Tensor
        The model outputs for all batches, concatenated along the first dimension.

    Notes
    -----
    Inference runs under ``torch.inference_mode()``. The exact nature of the
    predictions (e.g., age, biological markers) depends on the model being used.

    Examples
    --------
    >>> predictions = predict_ages_with_model(adata, model, "cpu", 1024, logger)
    """
    # If there is a preprocessing step
    if model.preprocess_name is not None:
        logger.info(
            f"The preprocessing method is {model.preprocess_name}",
            indent_level=indent_level + 1,
        )
    else:
        logger.info("There is no preprocessing necessary", indent_level=indent_level + 1)

    # If there is a postprocessing step
    if model.postprocess_name is not None:
        logger.info(
            f"The postprocessing method is {model.postprocess_name}",
            indent_level=indent_level + 1,
        )
    else:
        logger.info("There is no postprocessing necessary", indent_level=indent_level + 1)

    # Honor the requested device instead of whatever hardware happens to be present:
    # otherwise device="cpu" on a CUDA machine would put batches on the GPU.
    target_device = torch.device(device)
    dataloader = AnnLoader(adata, batch_size=batch_size, use_cuda=target_device.type == "cuda")

    obsm_key = f"X_{model.metadata['clock_name']}"

    # Use the AnnLoader for batched prediction
    predictions = []
    with torch.inference_mode():
        for batch in main_tqdm(dataloader, indent_level=indent_level + 1, logger=logger):
            # A no-op when the batch is already float64 on the target device
            batch_input = batch.obsm[obsm_key].to(device=target_device, dtype=torch.float64)
            predictions.append(model(batch_input))

    # Concatenate all batch predictions
    return torch.cat(predictions)
[docs]
@progress("Set PyTorch device")
def set_torch_device(logger, indent_level: int = 1) -> torch.device:
    """
    Pick the PyTorch device based on the availability of CUDA and return it.

    The device is 'cuda' when ``torch.cuda.is_available()`` is true and 'cpu'
    otherwise. The chosen device is logged for user reference.

    Parameters
    ----------
    logger : Logger
        A logger object for logging the selected device.
    indent_level : int, optional
        The indentation level for logging messages, by default 1. The device
        message is logged one level deeper.

    Returns
    -------
    torch.device
        The PyTorch device object set to 'cuda' if CUDA is available, or 'cpu' otherwise.

    Notes
    -----
    The function only selects and reports a device; callers must pass the returned
    device to the operations that should run on it.

    Examples
    --------
    >>> logger = pyaging.logger.LoggerManager.gen_logger("example")
    >>> device = set_torch_device(logger)
    >>> print(device)
    cuda  # or cpu if CUDA is not available
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # indent_level + 1 equals the previously hard-coded 2 for the default indent_level
    logger.info(f"Using device: {device}", indent_level=indent_level + 1)
    return device
[docs]
def cleanup_clock_memory(model=None, clock_name=None, dir=None, **kwargs) -> None:
    """
    Release memory and disk space held by a loaded clock model.

    This is a best-effort cleanup intended for tests or for processing many clocks
    sequentially. It drops this function's references to the given objects, removes
    the downloaded .pt file, forces garbage collection and clears the CUDA cache.

    Parameters
    ----------
    model : pyagingModel, optional
        The loaded clock model to release.
    clock_name : str, optional
        The name of the clock whose .pt file should be deleted from disk. Both the
        name as given and its lower-cased form are tried.
    dir : str, optional
        The directory containing the .pt file to delete. Required if clock_name is provided.
    **kwargs : dict
        Additional objects to release.

    Notes
    -----
    Python cannot delete a caller's variables from inside a function: the objects
    passed in are only freed once the caller also drops its own references
    (e.g. ``del model`` after this call).

    The function performs the following cleanup steps:

    1. Drops its own references to ``model`` and to the objects in ``kwargs``
    2. Removes the downloaded .pt file from disk if clock_name and dir are provided,
       silently ignoring any file-system error
    3. Forces Python garbage collection
    4. Clears the PyTorch CUDA cache if CUDA is available

    Examples
    --------
    >>> model = load_clock("horvath2013", "cpu", "pyaging_data", logger)
    >>> # ... use model ...
    >>> cleanup_clock_memory(model=model, clock_name="horvath2013", dir="pyaging_data")
    >>> del model
    """
    # Drop the references held by this frame so they cannot keep the objects alive
    del model
    kwargs.clear()

    # Delete the .pt file from disk if specified
    if clock_name is not None and dir is not None:
        # load_clock stores weights under the lower-cased name, so try that spelling
        # as well as the one given; dict.fromkeys removes the duplicate when equal.
        file_names = dict.fromkeys((f"{clock_name}.pt", f"{str(clock_name).lower()}.pt"))
        for file_name in file_names:
            try:
                os.remove(os.path.join(dir, file_name))
            except OSError:
                # Missing file or deletion failure: ignore to avoid disrupting tests
                pass

    # Force garbage collection
    gc.collect()

    # Clear PyTorch CUDA cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()