cpa.CPA

class cpa.CPA(adata, split_key=None, train_split='train', valid_split='test', test_split='ood', use_rdkit_embeddings=False, **hyper_params)

Compositional Perturbation Autoencoder (CPA) model.

Parameters:
adata AnnData

AnnData object registered with CPA.setup_anndata().

n_latent int

Number of latent dimensions, shared by the autoencoder and the drug (perturbation) embeddings.

recon_loss str

Reconstruction loss used by the autoencoder: either gauss (Gaussian) or NB (negative binomial).

doser_type str

Type of dose-response (doser) network. One of sigm, logsigm, or mlp.

split_key str, optional

Key in adata.obs used to split the data into training, validation, and test sets. The column must contain the labels given by train_split, valid_split, and test_split (by default ‘train’, ‘test’, and ‘ood’, so the validation set is labelled ‘test’). By default None.

**hyper_params

Additional CPA hyperparameters passed as keyword arguments, e.g. n_latent, recon_loss, and doser_type.
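The sigm and logsigm doser types refer to sigmoid-shaped dose-response curves that scale a perturbation's effect with its dose. The sketch below is a minimal illustration under stated assumptions, not CPA's implementation: it assumes a generalized sigmoid shifted so that a dose of zero yields zero, with logsigm applying the sigmoid to log(1 + dose). The function name `dose_response` and its `beta`/`bias` parameters are hypothetical, and the mlp option (a learned network) is not sketched.

```python
import numpy as np


def dose_response(dose, beta=1.0, bias=0.0, kind="sigm"):
    """Hypothetical doser: maps a dose to a scaling factor for a perturbation embedding.

    Assumption: a generalized sigmoid shifted so that dose == 0 gives 0.
    In a trained model, beta and bias would be learned per perturbation.
    """
    dose = np.asarray(dose, dtype=float)
    if kind == "sigm":
        x = dose
    elif kind == "logsigm":
        # log1p compresses large doses, e.g. when doses span orders of magnitude
        x = np.log1p(dose)
    else:
        raise ValueError(f"unsupported doser kind: {kind!r}")

    def sigmoid(t):
        return 1.0 / (1.0 + np.exp(-t))

    # Subtract the zero-dose value so an untreated cell receives no perturbation effect
    return sigmoid(beta * x + bias) - sigmoid(bias)


doses = np.array([0.0, 0.1, 1.0, 10.0])
print(dose_response(doses, beta=2.0, kind="sigm"))
print(dose_response(doses, beta=2.0, kind="logsigm"))
```

Because log(1 + d) < d for positive doses, the logsigm curve rises more slowly than sigm for the same `beta`.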

Examples

>>> import cpa
>>> import scanpy as sc
>>> adata = sc.read('dataset.h5ad')
>>> cpa.CPA.setup_anndata(adata,
...                       perturbation_key='condition',
...                       dosage_key='dose_val',
...                       control_group='control',
...                       categorical_covariate_keys=['cell_type'],
...                       )
>>> model = cpa.CPA(adata,
...                 n_latent=256,
...                 recon_loss='gauss',
...                 doser_type='logsigm',
...                 split_key='split',
...                 )
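The recon_loss setting selects the likelihood used to score reconstructed expression. As a rough sketch using the standard textbook formulas (the function names are hypothetical and this is not CPA's internal code), the per-value negative log-likelihoods for the two options could be written as follows, with the negative binomial parameterized by mean `mu` and inverse dispersion `theta`:

```python
import math


def gaussian_nll(x, mean, var):
    """Negative log-likelihood of x under a Gaussian N(mean, var)."""
    return 0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def nb_nll(x, mu, theta):
    """Negative log-likelihood of count x under a negative binomial with
    mean mu and inverse dispersion theta (variance = mu + mu**2 / theta)."""
    log_prob = (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )
    return -log_prob


print(gaussian_nll(2.0, mean=2.0, var=1.0))
print(nb_nll(0, mu=3.0, theta=2.0))
```

The NB option suits raw counts with overdispersion, while the Gaussian option is usually paired with normalized, continuous expression values.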

Attributes

adata

Data attached to model instance.

adata_manager

Manager instance associated with self.adata.

covars_encoder

device

The current device that the module's params are on.

history

Metrics computed during training.

is_trained

Whether the model has been trained.

pert_encoder

pert_smiles_map

test_indices

Indices of observations in the test set.

train_indices

Indices of observations in the training set.

validation_indices

Indices of observations in the validation set.
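The train_indices, validation_indices, and test_indices attributes follow from split_key together with the train_split, valid_split, and test_split labels. A minimal sketch of that mapping with pandas, assuming the indices are integer positions into adata.obs (the helper `split_indices` and the toy obs table are hypothetical), might look like:

```python
import numpy as np
import pandas as pd


def split_indices(obs, split_key, train_split="train", valid_split="test", test_split="ood"):
    """Hypothetical helper: integer positions of the rows belonging to each split.

    Defaults mirror the train_split / valid_split / test_split defaults in the
    CPA signature; note that the validation set is labelled 'test' by default.
    """
    labels = obs[split_key].to_numpy()
    return {
        "train": np.flatnonzero(labels == train_split),
        "validation": np.flatnonzero(labels == valid_split),
        "test": np.flatnonzero(labels == test_split),
    }


# Toy stand-in for adata.obs
obs = pd.DataFrame({"split": ["train", "train", "test", "ood", "train", "test"]})
idx = split_indices(obs, "split")
print(idx["train"], idx["validation"], idx["test"])
```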

Methods

convert_legacy_save(dir_path, output_dir_path)

Converts a legacy saved model (<v0.15.0) to the updated save format.

custom_predict([covars_to_add, basal, ...])

Predicts the output of the model on the given input data.

get_anndata_manager(adata[, required])

Retrieves the AnnDataManager for a given AnnData object specific to this model instance.

get_covar_embeddings(covariate[, ...])

Computes covariate embeddings (e.g. cell_type, tissue).

get_from_registry(adata, registry_key)

Returns the object in AnnData associated with the key in the data registry.

get_latent_representation([adata, indices, ...])

Returns all latent representations for the given dataset.

get_pert_embeddings([dosage, pert])

Computes embeddings for all perturbations (e.g. drugs) or for a specific one.

load(dir_path[, adata, use_gpu])

Loads the model from the specified directory.

load_registry(dir_path[, prefix])

Return the full registry saved with the model.

predict([adata, indices, batch_size, ...])

Predicts gene expression, including counterfactual predictions.

register_manager(adata_manager)

Registers an AnnDataManager instance with this model class.

save(dir_path[, overwrite, save_anndata])

Saves the state of the model.

setup_anndata(adata, perturbation_key, ...)

Sets up and registers an AnnData object for use with CPA.

to_device(device)

Move model to device.

train([max_epochs, use_gpu, train_size, ...])

Trains CPA on the given dataset.

view_anndata_setup([adata, ...])

Print summary of the setup for the initial AnnData or a given AnnData object.

view_setup_args(dir_path[, prefix])

Print args used to set up a saved model.
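CPA's counterfactual predictions rest on a compositional latent space: a cell's basal state is combined with dose-scaled perturbation embeddings and covariate embeddings, all of size n_latent, before decoding. The toy NumPy sketch below illustrates how a counterfactual can be formed by keeping the basal state and swapping the perturbation. The random embeddings, the linear `decode` function, and `dose_scale` are hypothetical stand-ins, not CPA's trained networks or the signature of its predict() method.

```python
import numpy as np

rng = np.random.default_rng(0)
n_latent, n_genes = 8, 5

# Hypothetical learned pieces (random here): one embedding per perturbation / covariate value
pert_emb = {"control": np.zeros(n_latent), "drugA": rng.normal(size=n_latent)}
cov_emb = {"cell_type_1": rng.normal(size=n_latent)}
W_dec = rng.normal(size=(n_latent, n_genes))  # stand-in for the decoder weights


def dose_scale(dose):
    # Hypothetical monotone dose-response, exactly zero at zero dose
    return 1.0 - np.exp(-dose)


def compose(z_basal, pert, dose, cell_type):
    """Add the dose-scaled perturbation and covariate embeddings to the basal state."""
    return z_basal + dose_scale(dose) * pert_emb[pert] + cov_emb[cell_type]


def decode(z):
    # Toy linear decoder mapping the latent space to gene expression
    return z @ W_dec


z_basal = rng.normal(size=n_latent)  # in CPA this would come from the encoder
observed = decode(compose(z_basal, "control", 0.0, "cell_type_1"))
counterfactual = decode(compose(z_basal, "drugA", 1.0, "cell_type_1"))
print(counterfactual - observed)
```

With a linear toy decoder, the predicted change is exactly the dose-scaled drug embedding pushed through the decoder; a real, nonlinear decoder would make the effect depend on the basal state and covariates as well.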