cpa.CPA.custom_predict

CPA.custom_predict(covars_to_add=None, basal=False, add_batch=True, add_pert=True, adata=None, indices=None, batch_size=32, n_samples=20, return_mean=True)

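Predicts gene expression for the given cells from latent representations built by combining the basal latent representation with the selected covariates, and returns the predictions together with those latent representations.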

Parameters:
covars_to_add Optional[Sequence[str]]

List of covariates to add to the basal latent representation.

basal bool

Whether to use just the basal latent representation. If True, add_batch and add_pert are ignored.

add_batch bool

Whether to add the batch covariate to the latent representation.

add_pert bool

Whether to add the perturbation covariate to the latent representation.

adata Optional[AnnData]

The input data to predict on.

indices Optional[Sequence[int]]

The indices of the cells in adata to predict on.

batch_size Optional[int]

The batch size to use for prediction.

n_samples int

The number of samples to use for stochastic prediction.

return_mean bool

If True, return the mean over the n_samples samples; if False, return all samples.
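The covariate flags above can be illustrated with a toy, pure-Python model of latent composition. It assumes, as the parameter descriptions suggest, that covariate contributions are added to the basal latent representation. The function `compose_latent`, its plain-list vector inputs, and the reading of `basal=True` as "return the basal vector unmodified" are hypothetical and are not part of the `cpa` API.

```python
from typing import Dict, List, Optional, Sequence

Vector = List[float]


def add_vectors(a: Vector, b: Vector) -> Vector:
    """Element-wise sum of two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return [x + y for x, y in zip(a, b)]


def compose_latent(
    z_basal: Vector,
    batch_emb: Vector,
    pert_emb: Vector,
    covar_embs: Dict[str, Vector],
    covars_to_add: Optional[Sequence[str]] = None,
    basal: bool = False,
    add_batch: bool = True,
    add_pert: bool = True,
) -> Vector:
    """Hypothetical additive composition mirroring the custom_predict flags."""
    z = list(z_basal)  # copy so the caller's basal vector is not mutated
    if basal:
        # Assumption: basal=True means "basal only", so add_batch and
        # add_pert (and, in this toy, covars_to_add) are ignored.
        return z
    if add_batch:
        z = add_vectors(z, batch_emb)
    if add_pert:
        z = add_vectors(z, pert_emb)
    # Add the embedding of each requested covariate, looked up by name.
    for name in covars_to_add or []:
        z = add_vectors(z, covar_embs[name])
    return z


if __name__ == "__main__":
    z = compose_latent(
        [1.0, 0.0], [0.5, 0.5], [0.0, 2.0],
        {"cell_type": [1.0, 1.0]}, covars_to_add=["cell_type"],
    )
    print(z)  # [2.5, 3.5]
```

Toggling `add_pert=False` in this toy drops the perturbation term, which is the kind of counterfactual comparison these flags are meant to enable.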

Returns:

A dictionary of AnnData objects containing the predicted gene expression for the specified covariates and the latent representations for the different covariate combinations.

Return type:

latent_outputs (Dict[str, AnnData])
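To illustrate what n_samples and return_mean control, the sketch below draws several stochastic predictions and either averages them per gene or returns them all. The Gaussian output distribution, the helper `sample_predictions`, and its `mu`, `sigma`, and `seed` arguments are assumptions made for illustration only; the real method's output distribution depends on the trained model, and it returns AnnData objects rather than plain lists.

```python
import random
from typing import List, Union


def sample_predictions(
    mu: List[float],
    sigma: List[float],
    n_samples: int = 20,
    return_mean: bool = True,
    seed: int = 0,
) -> Union[List[float], List[List[float]]]:
    """Hypothetical sampler: n_samples Gaussian draws per gene, optionally averaged."""
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")
    rng = random.Random(seed)  # seeded for reproducible draws
    # One row per sample, one column per gene: shape (n_samples, n_genes).
    samples = [
        [rng.gauss(m, s) for m, s in zip(mu, sigma)]
        for _ in range(n_samples)
    ]
    if not return_mean:
        return samples
    # Average over the sample axis, giving one value per gene.
    n_genes = len(mu)
    return [sum(row[g] for row in samples) / n_samples for g in range(n_genes)]


if __name__ == "__main__":
    print(sample_predictions([1.0, 2.0], [0.0, 0.0], n_samples=5))  # [1.0, 2.0]
    print(len(sample_predictions([1.0, 2.0], [1.0, 1.0], return_mean=False)))  # 20
```

Setting return_mean=False keeps the spread across samples, which can help gauge prediction uncertainty; the mean collapses it to a single point estimate per gene.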