cpa.CPA
- class cpa.CPA(adata, split_key=None, train_split='train', valid_split='test', test_split='ood', use_rdkit_embeddings=False, **hyper_params)
CPA model
- Parameters:
- adata Anndata
Registered Annotation Dataset
- n_latent int
Number of latent dimensions used for the drug embeddings and the autoencoder.
- recon_loss str
Autoencoder reconstruction loss function. Either gauss or NB.
- doser_type str
Type of doser network. Either sigm, logsigm or mlp.
- split_key str, optional
Key used to split the data into train, validation, and test sets. This must correspond to an observation key of the adata whose values are 'train', 'test', and 'ood'. By default None.
- **hyper_params
CPA’s hyper-parameters.
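The two recon_loss options above name different likelihoods for the reconstructed expression values. As a rough illustration of the difference, the sketch below evaluates a Gaussian and a negative binomial negative log-likelihood for a single value in plain Python. The function names and the mean/inverse-dispersion parametrization of the negative binomial are assumptions made for this example, not code taken from the cpa package.

```python
import math


def gaussian_nll(x, mean, var):
    """Negative log-likelihood of x under a normal distribution N(mean, var)."""
    return 0.5 * math.log(2.0 * math.pi * var) + (x - mean) ** 2 / (2.0 * var)


def nb_nll(x, mean, theta):
    """Negative log-likelihood of a count x under a negative binomial.

    Assumed parametrization (not confirmed against cpa): `mean` is the
    expected count and `theta` is the inverse dispersion.
    """
    log_prob = (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mean))
        + x * math.log(mean / (theta + mean))
    )
    return -log_prob


# Continuous, normalized expression suits the Gaussian loss;
# raw integer counts suit the negative binomial loss.
print(gaussian_nll(1.5, mean=1.0, var=0.25))
print(nb_nll(3, mean=2.0, theta=3.0))
```

In practice the choice usually follows the data: a Gaussian loss for log-normalized expression and a count likelihood for raw counts.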
Examples
>>> import cpa
>>> import scanpy as sc
>>> adata = sc.read('dataset.h5ad')
>>> adata = cpa.CPA.setup_anndata(adata,
...                               perturbation_keys={'perturbation': 'condition', 'dosage': 'dose_val'},
...                               categorical_covariate_keys=['cell_type'],
...                               control_key='control')
>>> hyperparams = {'autoencoder_depth': 3, 'autoencoder_width': 256}
>>> model = cpa.CPA(adata,
...                 n_latent=256,
...                 loss_ae='gauss',
...                 doser_type='logsigm',
...                 split_key='split')
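The doser_type argument selects how a dosage is turned into a scaling factor for the perturbation embedding. The sketch below is one plausible reading of the sigm and logsigm options as scalar functions; it is not the library's implementation. The names dose_response, beta, and bias are hypothetical stand-ins for learned parameters, and subtracting the zero-dose value so that a dose of 0 has no effect is an assumption of this sketch. The mlp option is a learned network and is not covered here.

```python
import math


def sigmoid(z):
    """Standard logistic function."""
    return 1.0 / (1.0 + math.exp(-z))


def dose_response(dose, kind, beta=1.0, bias=0.0):
    """Map a non-negative dose to a response value.

    `beta` and `bias` are hypothetical stand-ins for learned parameters.
    Assumption of this sketch: the zero-dose value is subtracted so that
    a dose of 0 gives a response of exactly 0.
    """
    if kind == "sigm":
        z = beta * dose + bias
    elif kind == "logsigm":
        # The dose is log-transformed first, which compresses large doses.
        z = beta * math.log1p(dose) + bias
    else:
        raise ValueError(f"unsupported doser type in this sketch: {kind!r}")
    return sigmoid(z) - sigmoid(bias)


for dose in (0.0, 0.1, 1.0, 10.0):
    print(dose, dose_response(dose, "sigm"), dose_response(dose, "logsigm"))
```

Under these assumptions both curves increase monotonically with dose, and the logsigm curve saturates more slowly, which can help when dosages span several orders of magnitude.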
Attributes
- adata
Data attached to model instance.
- adata_manager
Manager instance associated with self.adata.
- device
The current device that the module's params are on.
- history
Returns computed metrics during training.
- is_trained
Whether the model has been trained.
- test_indices
Observations that are in test set.
- train_indices
Observations that are in train set.
- validation_indices
Observations that are in validation set.
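The constructor defaults (train_split='train', valid_split='test', test_split='ood') suggest that rows labelled 'test' under split_key serve as the validation set and rows labelled 'ood' as the test set. The sketch below illustrates that reading with plain Python lists. The helper split_indices is hypothetical and not part of cpa, and raising an error on an unknown label is a choice made for this sketch rather than documented behaviour.

```python
def split_indices(labels, train_split="train", valid_split="test", test_split="ood"):
    """Group observation positions by role, given one split label per observation.

    Hypothetical helper, not part of cpa; it mirrors the constructor defaults.
    """
    role_of = {train_split: "train", valid_split: "validation", test_split: "test"}
    groups = {"train": [], "validation": [], "test": []}
    for position, label in enumerate(labels):
        role = role_of.get(label)
        if role is None:
            # Choice of this sketch: reject labels outside the three split names.
            raise ValueError(f"unexpected split label {label!r} at position {position}")
        groups[role].append(position)
    return groups


# One label per observation, as it might appear in the split_key column.
labels = ["train", "train", "test", "ood", "train", "ood"]
groups = split_indices(labels)
print(groups["train"])
print(groups["validation"])
print(groups["test"])
```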
Methods
- convert_legacy_save(dir_path, output_dir_path)
Converts a legacy saved model (<v0.15.0) to the updated save format.
- custom_predict([covars_to_add, basal, ...])
Predicts the output of the model on the given input data.
- get_anndata_manager(adata[, required])
Retrieves the AnnDataManager for a given AnnData object specific to this model instance.
- get_covar_embeddings(covariate[, ...])
Computes covariate embeddings (e.g. cell_type, tissue, etc.).
- get_from_registry(adata, registry_key)
Returns the object in AnnData associated with the key in the data registry.
- get_latent_representation([adata, indices, ...])
Returns all latent representations for the given dataset.
- get_pert_embeddings([dosage, pert])
Computes all/specific perturbation (e.g. drug) embeddings.
- load(dir_path[, adata, use_gpu])
Loads the model from the specified directory.
- load_registry(dir_path[, prefix])
Returns the full registry saved with the model.
- predict([adata, indices, batch_size, ...])
Counterfactual-friendly gene expression prediction.
- register_manager(adata_manager)
Registers an AnnDataManager instance with this model class.
- save(dir_path[, overwrite, save_anndata])
Saves the state of the model.
- setup_anndata(adata, perturbation_key, ...)
Annotation Data setup function.
- to_device(device)
Moves the model to the given device.
- train([max_epochs, use_gpu, train_size, ...])
Trains CPA on the given dataset.
- view_anndata_setup([adata, ...])
Prints a summary of the setup for the initial AnnData or a given AnnData object.
- view_setup_args(dir_path[, prefix])
Prints the args used to set up a saved model.