import json
import logging
import os
from tkinter import N
from typing import Optional, Sequence, Union, List, Dict
from rdkit import Chem
from rdkit.Chem import AllChem
import torch.nn as nn
import numpy as np
import pandas as pd
import torch
from pytorch_lightning.callbacks import EarlyStopping
from scvi.data import AnnDataManager
from scvi.dataloaders import DataSplitter
from torch.nn import functional as F
from scvi.data.fields import (
LayerField,
CategoricalObsField,
NumericalObsField,
ObsmField,
)
from anndata import AnnData
from scvi.model.base import BaseModelClass
from scvi.train import TrainRunner
from scvi.train._callbacks import SaveBestState
from scvi.utils import setup_anndata_dsp
from tqdm import tqdm
from ._module import CPAModule
from ._utils import CPA_REGISTRY_KEYS
from ._task import CPATrainingPlan
from ._data import AnnDataSplitter
logger = logging.getLogger(__name__)
logger.propagate = False
[docs]
class CPA(BaseModelClass):
"""CPA model
Parameters
----------
adata : Anndata
Registered Annotation Dataset
n_latent: int
Number of latent dimensions used for drug and Autoencoder
recon_loss: str
Either `gauss` or `NB`. Autoencoder loss function.
doser_type: str
Type of doser network. Either `sigm`, `logsigm` or `mlp`.
split_key : str, optional
Key used to split the data between train test and validation.
This must correspond to a observation key for the adata, composed of values
'train', 'test', and 'ood'. By default None.
**hyper_params:
CPA's hyper-parameters.
Examples
--------
>>> import cpa
>>> import scanpy as sc
>>> adata = sc.read('dataset.h5ad')
>>> adata = cpa.CPA.setup_anndata(adata,
perturbation_keys={'perturbation': 'condition', 'dosage': 'dose_val'},
categorical_covariate_keys=['cell_type'],
control_key='control'
)
>>> hyperparams = {'autoencoder_depth': 3, 'autoencoder_width': 256}
>>> model = cpa.CPA(adata,
n_latent=256,
loss_ae='gauss',
doser_type='logsigm',
split_key='split',
)
"""
covars_encoder: dict = None
pert_encoder: dict = None
pert_smiles_map: dict = None
def __init__(
self,
adata: AnnData,
split_key: str = None,
train_split: Union[str, List[str]] = "train",
valid_split: Union[str, List[str]] = "test",
test_split: Union[str, List[str]] = "ood",
use_rdkit_embeddings: bool = False,
**hyper_params,
):
super().__init__(adata)
self.split_key = split_key
self.drugs = list(self.pert_encoder.keys())
self.covars = {
covar: list(self.covars_encoder[covar].keys())
for covar in self.covars_encoder.keys()
}
if use_rdkit_embeddings and self.pert_smiles_map is not None:
# get morgan fingerprint vectors for drug embeddings
drug_embeddings = self.__get_rdkit_embeddings()
hyper_params['drug_embeddings'] = drug_embeddings
self.module = CPAModule(
n_genes=adata.n_vars,
n_perts=len(self.pert_encoder),
covars_encoder=self.covars_encoder,
**hyper_params,
).float()
train_indices, valid_indices, test_indices = None, None, None
if split_key is not None:
train_split = (
train_split if isinstance(train_split, list) else [train_split]
)
valid_split = (
valid_split if isinstance(valid_split, list) else [valid_split]
)
test_split = test_split if isinstance(test_split, list) else [test_split]
train_indices = np.where(adata.obs.loc[:, split_key].isin(train_split))[0]
valid_indices = np.where(adata.obs.loc[:, split_key].isin(valid_split))[0]
test_indices = np.where(adata.obs.loc[:, split_key].isin(test_split))[0]
self.train_indices = train_indices
self.valid_indices = valid_indices
self.test_indices = test_indices
self._model_summary_string = f"Compositional Perturbation Autoencoder"
self.init_params_ = self._get_init_params(locals())
self.epoch_history = None
def __get_rdkit_embeddings(
self,
):
assert self.pert_smiles_map not in [None, []]
query_drug_names = list(self.pert_encoder.keys())
query_drug_names.remove('<PAD>')
smiles_list = [self.pert_smiles_map[drug] for drug in list(query_drug_names)]
drug_fps = []
for smiles in smiles_list:
mol = Chem.MolFromSmiles(smiles)
fps = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
drug_fps.append(np.array(fps))
drug_fps = np.vstack(drug_fps)
print(drug_fps.shape)
embeddings = AnnData(X=drug_fps)
embeddings.obs.index = smiles_list
embeddings = embeddings[list(smiles_list), :]
drug_embeddings = nn.Embedding(
len(self.pert_encoder),
embeddings.shape[1],
padding_idx=CPA_REGISTRY_KEYS.PADDING_IDX,
)
pad_X = np.zeros(shape=(1, embeddings.n_vars))
X = np.concatenate((pad_X, embeddings.X), 0)
drug_embeddings.weight.data.copy_(torch.tensor(X))
drug_embeddings.weight.requires_grad = False
return drug_embeddings
[docs]
@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
cls,
adata: AnnData,
perturbation_key: str,
control_group: str,
dosage_key: Optional[str] = None,
batch_key: Optional[str] = None,
layer: Optional[str] = None,
smiles_key: Optional[str] = None,
is_count_data: Optional[bool] = True,
categorical_covariate_keys: Optional[List[str]] = [],
deg_uns_key: Optional[str] = None,
deg_uns_cat_key: Optional[str] = None,
max_comb_len: int = 2,
**kwargs,
):
"""
Annotation Data setup function
Parameters
----------
adata: anndata.AnnData
AnnData object
perturbation_key: str
Key in `adata.obs` containing perturbations
control_group: str
Control group name
dosage_key: str, optional
Key in `adata.obs` containing perturbation dosages, by default None. If None, all dosages are set to 1.0
batch_key: str, optional
Key in `adata.obs` containing batch information, by default None
layer: str, optional
Key in `adata.layers` containing gene expression data, by default None. If None, `adata.X` is used
is_count_data: bool, optional
Whether the data is count data, by default False
categorical_covariate_keys: List[str], optional
List of keys in `adata.obs` containing categorical covariates, by default None
deg_uns_key: str, optional
Key in `adata.uns` containing differentially expressed genes for each combination of covariates and perturbations, by default None
deg_uns_cat_key: str, optional
Key in `adata.obs` containing covariate combinations for each cell, by default None
max_comb_len: int, optional
Maximum number of perturbations in a combination, by default 2
"""
CPA_REGISTRY_KEYS.PERTURBATION_KEY = perturbation_key
CPA_REGISTRY_KEYS.PERTURBATION_DOSAGE_KEY = dosage_key
CPA_REGISTRY_KEYS.CAT_COV_KEYS = categorical_covariate_keys
CPA_REGISTRY_KEYS.MAX_COMB_LENGTH = max_comb_len
CPA_REGISTRY_KEYS.BATCH_KEY = batch_key
if dosage_key is None:
print(f'Warning: dosage_key is not set. Setting it to "1.0" for all cells')
dosage_key = 'CPA_dose_val'
adata.obs[dosage_key] = adata.obs[perturbation_key].apply(lambda x: '+'.join(['1.0' for _ in x.split('+')])).values
CPA_REGISTRY_KEYS.PERTURBATION_DOSAGE_KEY = dosage_key
perturbations = adata.obs[perturbation_key].astype(str).values
dosages = adata.obs[dosage_key].astype(str).values
category_key = f"{cls.__name__}_cat"
keys = categorical_covariate_keys + [perturbation_key]
if batch_key is not None:
keys = [batch_key] + keys
adata.obs[category_key] = adata.obs[keys].apply(lambda x: "_".join(x), axis=1)
CPA_REGISTRY_KEYS.CATEGORY_KEY = category_key
if cls.pert_encoder is None:
# get unique drugs
perts_names_unique = set()
for d in np.unique(perturbations):
[perts_names_unique.add(i) for i in d.split("+") if i != control_group]
perts_names_unique = ["<PAD>", control_group] + sorted(
list(perts_names_unique)
)
CPA_REGISTRY_KEYS.PADDING_IDX = 0
pert_encoder = {pert: i for i, pert in enumerate(perts_names_unique)}
else:
pert_encoder = cls.pert_encoder
perts_names_unique = list(pert_encoder.keys())
if smiles_key is not None:
if cls.pert_smiles_map is None:
pert_smiles_map = {}
for pert in perts_names_unique:
if pert != "<PAD>":
try:
pert_smiles_map[pert] = adata.obs.loc[
adata.obs[perturbation_key] == pert, smiles_key
].values[0]
except:
pert_name = adata.obs.loc[
adata.obs[perturbation_key].str.contains(pert), perturbation_key
].values[0]
smiles = adata.obs.loc[
adata.obs[perturbation_key].str.contains(pert), smiles_key
].values[0]
pert_smiles_map[pert] = smiles.split('..')[pert_name.split('+').index(pert)]
cls.pert_smiles_map = pert_smiles_map
else:
pert_smiles_map = cls.pert_smiles_map
pert_map = {}
for condition in tqdm(perturbations):
perts_list = np.where(np.isin(perts_names_unique, condition.split("+")))[0]
pert_map[condition] = list(perts_list) + [
CPA_REGISTRY_KEYS.PADDING_IDX
for _ in range(max_comb_len - len(perts_list))
]
dose_map = {}
for dosage_str in tqdm(dosages):
dosages_list = [float(i) for i in dosage_str.split("+")]
dose_map[dosage_str] = list(dosages_list) + [
0.0 for _ in range(max_comb_len - len(dosages_list))
]
data_perts = np.vstack(
np.vectorize(lambda x: pert_map[x], otypes=[np.ndarray])(perturbations)
).astype(int)
adata.obsm[CPA_REGISTRY_KEYS.PERTURBATIONS] = data_perts
data_perts_dosages = np.vstack(
np.vectorize(lambda x: dose_map[x], otypes=[np.ndarray])(dosages)
).astype(float)
adata.obsm[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES] = data_perts_dosages
# setup control column
control_key = f"{cls.__name__}_{control_group}"
CPA_REGISTRY_KEYS.CONTROL_KEY = control_key
adata.obs[control_key] = (adata.obs[perturbation_key] == control_group).astype(
int
)
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
LayerField(
registry_key=CPA_REGISTRY_KEYS.X_KEY,
layer=layer,
is_count_data=is_count_data,
),
ObsmField(
CPA_REGISTRY_KEYS.PERTURBATIONS,
CPA_REGISTRY_KEYS.PERTURBATIONS,
is_count_data=True,
correct_data_format=True,
),
ObsmField(
CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES,
CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES,
is_count_data=False,
correct_data_format=True,
),
CategoricalObsField(
registry_key=CPA_REGISTRY_KEYS.PERTURBATION_KEY,
attr_key=perturbation_key,
),
] + [
CategoricalObsField(registry_key=covar, attr_key=covar)
for covar in categorical_covariate_keys
]
anndata_fields.append(
NumericalObsField(registry_key=control_key, attr_key=control_key)
)
anndata_fields.append(
CategoricalObsField(registry_key=category_key, attr_key=category_key)
)
if batch_key is not None:
anndata_fields.append(
CategoricalObsField(registry_key=batch_key, attr_key=batch_key)
)
if deg_uns_key:
n_deg_r2 = kwargs.pop("n_deg_r2", 10)
cov_cond_unique = np.unique(adata.obs[deg_uns_cat_key].astype(str).values)
cov_cond_map = {}
cov_cond_map_r2 = {}
for cov_cond in tqdm(cov_cond_unique):
if cov_cond in adata.uns[deg_uns_key].keys():
mask_hvg = adata.var_names.isin(
adata.uns[deg_uns_key][cov_cond]
).astype(int)
mask_hvg_r2 = adata.var_names.isin(
adata.uns[deg_uns_key][cov_cond][:n_deg_r2]
).astype(int)
cov_cond_map[cov_cond] = list(mask_hvg)
cov_cond_map_r2[cov_cond] = list(mask_hvg_r2)
else:
no_mask = list(np.ones(shape=(adata.n_vars,)))
cov_cond_map[cov_cond] = no_mask
cov_cond_map_r2[cov_cond] = no_mask
mask = np.vstack(
np.vectorize(lambda x: cov_cond_map[x], otypes=[np.ndarray])(
adata.obs[deg_uns_cat_key].astype(str).values
)
)
mask_r2 = np.vstack(
np.vectorize(lambda x: cov_cond_map[x], otypes=[np.ndarray])(
adata.obs[deg_uns_cat_key].astype(str).values
)
)
CPA_REGISTRY_KEYS.DEG_MASK = "deg_mask"
CPA_REGISTRY_KEYS.DEG_MASK_R2 = "deg_mask_r2"
adata.obsm[CPA_REGISTRY_KEYS.DEG_MASK] = np.array(mask)
adata.obsm[CPA_REGISTRY_KEYS.DEG_MASK_R2] = np.array(mask_r2)
anndata_fields.append(
ObsmField(
CPA_REGISTRY_KEYS.DEG_MASK,
CPA_REGISTRY_KEYS.DEG_MASK,
is_count_data=True,
correct_data_format=True,
)
)
anndata_fields.append(
ObsmField(
CPA_REGISTRY_KEYS.DEG_MASK_R2,
CPA_REGISTRY_KEYS.DEG_MASK_R2,
is_count_data=True,
correct_data_format=True,
)
)
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)
keys = categorical_covariate_keys
if batch_key is not None:
keys.append(batch_key)
covars_encoder = {}
for covar in keys:
covars_encoder[covar] = {
c: i
for i, c in enumerate(
adata_manager.registry["field_registries"][covar]["state_registry"][
"categorical_mapping"
]
)
}
if cls.covars_encoder is None:
cls.covars_encoder = covars_encoder
if cls.pert_encoder is None:
cls.pert_encoder = pert_encoder
[docs]
def train(
self,
max_epochs: Optional[int] = None,
use_gpu: Optional[Union[str, int, bool]] = None,
train_size: float = 0.9,
validation_size: Optional[float] = None,
batch_size: int = 128,
plan_kwargs: Optional[dict] = None,
save_path: Optional[str] = None,
check_val_every_n_epoch: int = 10,
early_stopping_patience: int = 10,
**trainer_kwargs,
):
"""
Trains CPA on the given dataset
Parameters
----------
max_epochs: int
Maximum number of epochs for training
use_gpu: bool
Whether to use GPU if available
train_size: float
Fraction of training data in the case of randomly splitting dataset to train/valdiation
if `split_key` is not set in model's constructor
validation_size: float
Fraction of validation data in the case of randomly splitting dataset to train/valdiation
if `split_key` is not set in model's constructor
batch_size: int
Size of mini-batches for training
early_stopping: bool
If `True`, EarlyStopping will be used during training on validation dataset
plan_kwargs: dict
`CPATrainingPlan` parameters
save_path: str
Path to save the model after the end of training
check_val_every_n_epoch: int
How often to check validation metrics
early_stopping_patience: int
Patience for early stopping
**trainer_kwargs:
Additional parameters for `cpa.CPATrainingPlan`
"""
if max_epochs is None:
n_cells = self.adata.n_obs
max_epochs = np.min([round((20000 / n_cells) * 400), 400])
plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
manual_splitting = (
(self.valid_indices is not None)
and (self.train_indices is not None)
and (self.test_indices is not None)
)
if manual_splitting:
data_splitter = AnnDataSplitter(
self.adata_manager,
train_indices=self.train_indices,
valid_indices=self.valid_indices,
test_indices=self.test_indices,
batch_size=batch_size,
use_gpu=use_gpu,
)
else:
data_splitter = DataSplitter(
self.adata_manager,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
use_gpu=use_gpu,
)
perturbation_key = CPA_REGISTRY_KEYS.PERTURBATION_KEY
pert_adv_encoder = {
c: i
for i, c in enumerate(
self.adata_manager.registry["field_registries"][perturbation_key][
"state_registry"
]["categorical_mapping"]
)
}
drug_weights = []
n_adv_perts = len(self.adata.obs[perturbation_key].unique())
for condition in tqdm(list(pert_adv_encoder.keys())):
n_positive = len(self.adata[self.adata.obs[perturbation_key] == condition])
drug_weights.append((self.adata.n_obs / n_positive) - 1.0)
self.training_plan = CPATrainingPlan(
self.module,
self.covars_encoder,
n_adv_perts=n_adv_perts,
**plan_kwargs,
drug_weights=drug_weights,
)
trainer_kwargs["early_stopping"] = False
trainer_kwargs["check_val_every_n_epoch"] = check_val_every_n_epoch
es_callback = EarlyStopping(
monitor="cpa_metric",
patience=early_stopping_patience,
check_on_train_epoch_end=False,
verbose=False,
mode="max",
)
if "callbacks" in trainer_kwargs.keys() and isinstance(
trainer_kwargs.get("callbacks"), list
):
trainer_kwargs["callbacks"] += [es_callback]
else:
trainer_kwargs["callbacks"] = [es_callback]
if save_path is None:
save_path = "./"
checkpoint = SaveBestState(
monitor="cpa_metric", mode="max", period=1, verbose=True
)
trainer_kwargs["callbacks"].append(checkpoint)
self.runner = TrainRunner(
self,
training_plan=self.training_plan,
data_splitter=data_splitter,
max_epochs=max_epochs,
use_gpu=use_gpu,
early_stopping_monitor="cpa_metric",
early_stopping_mode="max",
**trainer_kwargs,
)
self.runner()
self.epoch_history = pd.DataFrame().from_dict(self.training_plan.epoch_history)
if save_path is not False:
self.save(save_path, overwrite=True)
[docs]
@torch.no_grad()
def get_latent_representation(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = 32,
):
"""Returns All latent representations for the given dataset
Parameters
----------
adata : Optional[AnnData], optional
[description], by default None
indices : Optional[Sequence[int]], optional
Optional indices, by default None
batch_size : Optional[int], optional
Batch size to use, by default None
Returns
-------
latent_outputs : Dict[str, anndata.AnnData]
Dictionary of latent representations containing:
- 'latent_corrected': batch-corrected (if batch_key is set) latent representation
- 'latent_basal': basal latent representation
- 'latent_after': final latent representation which can be used for decoding.
"""
if self.is_trained_ is False:
raise RuntimeError("Please train the model first.")
adata = self._validate_anndata(adata)
if indices is None:
indices = np.arange(adata.n_obs)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size, shuffle=False
)
latent_basal = []
latent = []
latent_corrected = []
for tensors in tqdm(scdl):
tensors, _ = self.module.mixup_data(tensors, alpha=0.0)
inference_inputs = self.module._get_inference_input(tensors)
outputs = self.module.inference(**inference_inputs)
latent_basal += [outputs["z_basal"].cpu().numpy()]
latent += [outputs["z"].cpu().numpy()]
latent_corrected += [outputs["z_corrected"].cpu().numpy()]
latent_basal_adata = AnnData(
X=np.concatenate(latent_basal, axis=0), obs=adata.obs.copy()
)
latent_basal_adata.obs_names = adata.obs_names
latent_corrected_adata = AnnData(
X=np.concatenate(latent_corrected, axis=0), obs=adata.obs.copy()
)
latent_corrected_adata.obs_names = adata.obs_names
latent_adata = AnnData(X=np.concatenate(latent, axis=0), obs=adata.obs.copy())
latent_adata.obs_names = adata.obs_names
latent_outputs = {
"latent_corrected": latent_corrected_adata,
"latent_basal": latent_basal_adata,
"latent_after": latent_adata,
}
return latent_outputs
[docs]
@torch.no_grad()
def predict(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = 32,
n_samples: int = 20,
return_mean: bool = True,
):
"""Counterfactual-friendly gene expression prediction
To produce counterfactuals, save `adata.X` to `adata.obsm['X_true']`
and set it to control cells gene expression.
For the case of reconstruction, you can pass original adata
without any further modifications.
Returns
-------
None (predictions are saved to `adata.obsm[f'CPA_pred']`)
"""
assert self.module.recon_loss in ["gauss", "nb", "zinb"]
self.module.eval()
adata = self._validate_anndata(adata)
if indices is None:
indices = np.arange(adata.n_obs)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size, shuffle=False
)
xs = []
for tensors in tqdm(scdl):
x_pred = (
self.module.get_expression(tensors, n_samples=n_samples)['px']
.detach()
.cpu()
.numpy()
)
xs.append(x_pred)
if n_samples > 1 and self.module.variational:
# The -2 axis correspond to cells.
x_pred = np.concatenate(xs, axis=1)
else:
x_pred = np.concatenate(xs, axis=0)
if self.module.variational and n_samples > 1 and return_mean:
x_pred = x_pred.mean(0)
adata.obsm[f"{self.__class__.__name__}_pred"] = x_pred
[docs]
def custom_predict(
self,
covars_to_add: Optional[Sequence[str]] = None,
basal=False,
add_batch: bool = True,
add_pert: bool = True,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = 32,
n_samples: int = 20,
return_mean: bool = True,
) -> AnnData:
"""
Predicts the output of the model on the given input data.
Args:
covars_to_add (Optional[Sequence[str]]): List of covariates to add to the basal latent representation.
basal (bool): Whether to use just the basal latent representation. If True, `add_batch` and `add_pert` are ignored.
add_batch (bool): Whether to add the batch covariate to the latent representation.
add_pert (bool): Whether to add the perturbation covariate to the latent representation.
adata (Optional[AnnData]): The input data to predict on.
indices (Optional[Sequence[int]]): The indices of the cells to predict on.
batch_size (Optional[int]): The batch size to use for prediction.
n_samples (int): The number of samples to use for stochastic prediction.
return_mean (bool): Whether to return the mean of the samples or all the samples.
Returns:
latent_outputs (AnnData): A dictionary of AnnData objects containing the predicted gene expression for the specified
covariates, and latent representations for different covariate combinations.
"""
if covars_to_add is None:
covars_to_add = []
for covar in covars_to_add:
assert covar in self.module.covars_encoder.keys(
), f"covariate {covar} not found in learned covariates"
if basal:
latent_key = "z_basal"
else:
if add_batch and add_pert:
latent_key = "z"
elif add_batch:
latent_key = "z_no_pert"
elif add_pert:
latent_key = "z_corrected"
else:
latent_key = "z_no_pert_corrected"
assert self.module.recon_loss in ["gauss", "nb", "zinb"]
self.module.eval()
adata = self._validate_anndata(adata)
if indices is None:
indices = np.arange(adata.n_obs)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size, shuffle=False
)
xs = []
zs = []
z_correcteds = []
z_no_perts = []
z_no_pert_correcteds = []
z_basals = []
for tensors in tqdm(scdl):
predictions = self.module.get_expression(
tensors, n_samples=n_samples, covars_to_add=covars_to_add, latent=latent_key)
px = predictions['px']
z = predictions['z']
z_corrected = predictions['z_corrected']
z_no_pert = predictions['z_no_pert']
z_no_pert_corrected = predictions['z_no_pert_corrected']
z_basal = predictions['z_basal']
x_pred = (
px.detach().cpu().numpy()
)
xs.append(x_pred)
z = (
z.detach().cpu().numpy()
)
zs.append(z)
z_corrected = (
z_corrected.detach().cpu().numpy()
)
z_correcteds.append(z_corrected)
z_no_pert = (
z_no_pert.detach().cpu().numpy()
)
z_no_perts.append(z_no_pert)
z_no_pert_corrected = (
z_no_pert_corrected.detach().cpu().numpy()
)
z_no_pert_correcteds.append(z_no_pert_corrected)
z_basal = (
z_basal.detach().cpu().numpy()
)
z_basals.append(z_basal)
if n_samples > 1 and self.module.variational:
# The -2 axis correspond to cells.
x_pred = np.concatenate(xs, axis=1)
z = np.concatenate(zs, axis=1)
z_corrected = np.concatenate(z_correcteds, axis=1)
z_no_pert = np.concatenate(z_no_perts, axis=1)
z_no_pert_corrected = np.concatenate(z_no_pert_correcteds, axis=1)
z_basal = np.concatenate(z_basals, axis=1)
else:
x_pred = np.concatenate(xs, axis=0)
z = np.concatenate(zs, axis=0)
z_corrected = np.concatenate(z_correcteds, axis=0)
z_no_pert = np.concatenate(z_no_perts, axis=0)
z_no_pert_corrected = np.concatenate(z_no_pert_correcteds, axis=0)
z_basal = np.concatenate(z_basals, axis=0)
if self.module.variational and n_samples > 1 and return_mean:
x_pred = x_pred.mean(0)
z = z.mean(0)
z_corrected = z_corrected.mean(0)
z_no_pert = z_no_pert.mean(0)
z_no_pert_corrected = z_no_pert_correcteds.mean(0)
z_basal = z_basal.mean(0)
latent_x_pred = AnnData(
X=x_pred, obs=adata.obs.copy()
)
latent_x_pred.obs_names = adata.obs_names
latent_z = AnnData(
X=z, obs=adata.obs.copy()
)
latent_z.obs_names = adata.obs_names
latent_z_corrected = AnnData(
X=z_corrected, obs=adata.obs.copy()
)
latent_z_corrected.obs_names = adata.obs_names
latent_z_no_pert = AnnData(
X=z_no_pert, obs=adata.obs.copy()
)
latent_z_no_pert.obs_names = adata.obs_names
latent_z_no_pert_corrected = AnnData(
X=z_no_pert_corrected, obs=adata.obs.copy()
)
latent_z_no_pert_corrected.obs_names = adata.obs_names
latent_z_basal = AnnData(
X=z_basal, obs=adata.obs.copy()
)
latent_z_basal.obs_names = adata.obs_names
latent_outputs = {
"latent_x_pred": latent_x_pred,
"latent_z": latent_z,
"latent_z_corrected": latent_z_corrected,
"latent_z_no_pert": latent_z_no_pert,
"latent_z_no_pert_corrected": latent_z_no_pert_corrected,
"latent_z_basal": latent_z_basal,
}
return latent_outputs
[docs]
@torch.no_grad()
def get_pert_embeddings(self, dosage=1.0, pert: Optional[str] = None):
"""Computes all/specific perturbation (e.g. drug) embeddings
Parameters
----------
dosage : float
Dosage of interest, by default 1.0
pert: str, optional
Perturbation name if single perturbation embedding is desired
Returns
-------
AnnData with perturbation embeddings in `.X` and perturbation names saved in `.obs['pert_name']`.
"""
self.module.eval()
if isinstance(dosage, float):
if pert is None:
n_drugs = len(self.pert_encoder)
treatments = [torch.arange(n_drugs, device=self.device).long().unsqueeze(1)]
for _ in range(CPA_REGISTRY_KEYS.MAX_COMB_LENGTH - 1):
treatments += [torch.zeros(n_drugs, device=self.device).long().unsqueeze(1) + CPA_REGISTRY_KEYS.PADDING_IDX]
treatments = torch.cat(treatments, dim=1) # (n_drugs, max_comb_len)
treatments_dosages = [torch.tensor([dosage for _ in range(n_drugs)], device=self.device).float().unsqueeze(1)] # (n_drugs, 1)
for _ in range(CPA_REGISTRY_KEYS.MAX_COMB_LENGTH - 1):
treatments_dosages += [torch.zeros(n_drugs, device=self.device).float().unsqueeze(1) + CPA_REGISTRY_KEYS.PADDING_IDX]
treatments_dosages = torch.cat(treatments_dosages, dim=1) # (n_drugs, max_comb_len)
else:
treatments = [self.pert_encoder[pert]] + [CPA_REGISTRY_KEYS.PADDING_IDX for _ in range(CPA_REGISTRY_KEYS.MAX_COMB_LENGTH - 1)]
treatments = torch.LongTensor(treatments).to(self.device).unsqueeze(0)
treatments_dosages = [dosage] + [CPA_REGISTRY_KEYS.PADDING_IDX for _ in range(CPA_REGISTRY_KEYS.MAX_COMB_LENGTH - 1)]
treatments_dosages = torch.FloatTensor(treatments_dosages).to(self.device).unsqueeze(0)
else:
raise NotImplementedError
embeds = self.module.pert_network(treatments, treatments_dosages).detach().cpu().numpy() # (1 or n_drugs, n_latent)
pert_latent_adata = AnnData(X=embeds)
pert_latent_adata.obs['pert_name'] = [pert] if pert is not None else self.pert_encoder.keys()
return pert_latent_adata
[docs]
@torch.no_grad()
def get_covar_embeddings(self, covariate: str, covariate_value: str = None):
"""Computes Covariate embeddings (e.g. cell_type, tissue, etc.)
Parameters
----------
covariate : str
covariate to be computed
covariate_value: str, Optional
Covariate specific value for embedding computation
Returns
-------
AnnData object with covariate embeddings in `.X` and covariate values in `.obs[covariate]`
"""
# assert and print the error
assert covariate in self.covars_encoder.keys(), f"covariate {covariate} not found in learned covariates"
self.module.eval()
if covariate_value is None:
covar_ids = torch.arange(
len(self.covars_encoder[covariate]), device=self.device
).long().unsqueeze(1)
else:
covar_ids = torch.LongTensor(
[self.covars_encoder[covariate][covariate_value]]
).to(self.device).long().unsqueeze(1)
embeddings = self.module.covars_embeddings[covariate](covar_ids).detach().cpu().numpy() # (n_covars, n_latent)
covar_latent_adata = AnnData(X=embeddings)
covar_latent_adata.obs[covariate] = [covariate_value] if covariate_value is not None else self.covars_encoder[covariate].keys()
return covar_latent_adata
[docs]
def save(
self,
dir_path: str,
overwrite: bool = False,
save_anndata: bool = False,
**anndata_write_kwargs,
):
"""
Saves the state of the model.
Parameters
----------
dir_path : `str`
Path to a directory.
overwrite : `bool`, optional (default: `False`)
Whether to overwrite the model/data in `dir_path` if it already exists.
save_anndata : `bool`, optional (default: `False`)
Whether to save the anndata along with the model.
**anndata_write_kwargs : keyword arguments
Keyword arguments to pass to anndata's write function.
"""
os.makedirs(dir_path, exist_ok=True)
# save public dictionaries
total_dict = {
"pert_encoder": self.pert_encoder,
"covars_encoder": self.covars_encoder,
"pert_smiles_map": self.pert_smiles_map,
}
json_dict = json.dumps(total_dict)
with open(os.path.join(dir_path, "CPA_info.json"), "w") as f:
f.write(json_dict)
if isinstance(self.epoch_history, dict):
self.epoch_history = pd.DataFrame().from_dict(
self.training_plan.epoch_history
)
self.epoch_history.to_csv(
os.path.join(dir_path, "history.csv"), index=False
)
elif isinstance(self.epoch_history, pd.DataFrame):
self.epoch_history.to_csv(
os.path.join(dir_path, "history.csv"), index=False
)
return super().save(
dir_path=dir_path,
overwrite=overwrite,
save_anndata=save_anndata,
**anndata_write_kwargs,
)
[docs]
@classmethod
def load(
cls,
dir_path: str,
adata: Optional[AnnData] = None,
use_gpu: Optional[Union[str, int, bool]] = None,
):
"""
Loads the model from the specified directory.
Parameters
----------
dir_path : `str`
Path to saved model.
adata : `~anndata.AnnData`, optional (default: `None`)
Annotated data matrix. Will call `cpa.CPA.setup_anndata` on the data after model restoration.
use_gpu : `bool` or `str` or `int`, optional (default: `None`)
Whether a GPU should be used. If `True`, will use GPU.
Returns
-------
:class:`~scvi.core.models.CPA`
Restored model from the specified directory.
"""
# load public dictionaries
with open(os.path.join(dir_path, "CPA_info.json")) as f:
total_dict = json.load(f)
cls.pert_encoder = total_dict["pert_encoder"]
cls.covars_encoder = total_dict["covars_encoder"]
cls.pert_smiles_map = total_dict.get("pert_smiles_map", None)
model = super().load(dir_path, adata, use_gpu)
try:
model.epoch_history = pd.read_csv(os.path.join(dir_path, "history.csv"))
except:
print("WARNING: The history was not found.")
return model