Source code for cpa._api

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import copy
import itertools

import numpy as np
import pandas as pd
import scanpy as sc
import torch
from anndata import AnnData
from sklearn.metrics import r2_score
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances

from ._model import CPA
from ._utils import CPA_REGISTRY_KEYS


[docs] class ComPertAPI: """ API for CPA model to make it compatible with scanpy. """ def __init__(self, adata: AnnData, model: CPA, de_genes_uns_key: str = 'rank_genes_groups_cov', pert_category_key: str = 'cov_drug_dose_name', control_group: str = 'ctrl', experiment: str = 'drug'): """ Parameters ---------- adata : AnnData Annotated data matrix. model : CPA Pre-trained CPA model. de_genes_uns_key : str, optional Key for the DE genes in `adata.uns`, by default 'rank_genes_groups_cov' pert_category_key : str, optional Key for the perturbation category in `adata.obs`, by default 'cov_drug_dose_name' control_group : str, optional Name of the control group, by default 'ctrl' experiment : str, optional Type of experiment, by default 'drug'. Can be 'drug' or 'gene'. """ self.perturbation_key = CPA_REGISTRY_KEYS.PERTURBATION_KEY self.dose_key = CPA_REGISTRY_KEYS.PERTURBATION_DOSAGE_KEY self.covars_key = CPA_REGISTRY_KEYS.CAT_COV_KEYS adata.obs['control'] = (adata.obs[self.perturbation_key] == control_group).astype(int).values self.control_group = control_group self.control_key = 'control' self.model = model self.adata = adata self.var_names = adata.var_names self.experiment = experiment.lower() if de_genes_uns_key in adata.uns.keys(): self.de_genes = adata.uns[de_genes_uns_key] else: self.de_genes = None self.split_key = model.split_key data_types = list(np.unique(adata.obs[self.split_key])) + ['all'] self.unique_perts = list(model.pert_encoder.keys()) self.unique_covars = {} for covar in model.covars_encoder.keys(): self.unique_covars[covar] = list(model.covars_encoder[covar].keys()) self.num_drugs = len(model.pert_encoder) self.perts_dict = model.pert_encoder.copy() self.covars_dict = model.covars_encoder.copy() self.emb_covars = None self.emb_perts = None self.seen_covars_perts = None self.comb_emb = None self.control_cat = None self.seen_covars_perts = {} self.adatas = {} self.pert_categories_key = pert_category_key for k in data_types: if k == 'all': self.adatas[k] = adata.copy() self.seen_covars_perts[k] = np.unique(self.adatas[k].obs[self.pert_categories_key]) else: self.adatas[k] = adata[adata.obs[self.split_key] == k] self.seen_covars_perts[k] = np.unique(self.adatas[k].obs[self.pert_categories_key]) self.measured_points = {} self.num_measured_points = {} for k in data_types: self.measured_points[k] = {} self.num_measured_points[k] = {} for covar_cat in self.seen_covars_perts[k]: num_points = len(np.where(self.adatas[k].obs[self.pert_categories_key] == covar_cat)[0]) self.num_measured_points[k][covar_cat] = num_points if self.experiment == 'drug': cov, pert, dose = covar_cat.split('_') else: cov, pert = covar_cat.split('_') dose = '+'.join(['1.0' for _ in pert.split('+')]) if not ('+' in dose): dose = float(dose) if cov in self.measured_points[k].keys(): if pert in self.measured_points[k][cov].keys(): self.measured_points[k][cov][pert].append(dose) else: self.measured_points[k][cov][pert] = [dose] else: self.measured_points[k][cov] = {pert: [dose]} self.measured_points['all'] = copy.deepcopy(self.measured_points['train']) for cov in self.measured_points['ood'].keys(): for covar_cat in self.measured_points['ood'][cov].keys(): if covar_cat in self.measured_points['train'][cov].keys(): self.measured_points['all'][cov][covar_cat] = \ self.measured_points['train'][cov][covar_cat].copy() + \ self.measured_points['ood'][cov][covar_cat].copy() else: self.measured_points['all'][cov][covar_cat] = \ self.measured_points['ood'][cov][covar_cat].copy()
[docs] @torch.no_grad() def get_pert_embeddings(self, dose=1.0): """ Parameters ---------- dose : int (default: 1.0) Dose at which to evaluate latent embedding vector. Returns ------- If return_anndata is True, returns anndata object. Otherwise, doesn't return anything. Always saves embeddding in self.emb_perts. """ return self.model.get_pert_embeddings(dose)
[docs] def get_covars_embeddings(self, covariate: str): """ Parameters ---------- covariate: str covariate column name in adata.obs dataframe Returns ------- If return_anndata is True, returns anndata object. Otherwise, doesn't return anything. Always saves embeddding in self.emb_covars. """ return self.model.get_covar_embeddings(covariate)
[docs] def get_drug_encoding_(self, drugs, doses=None): """ Parameters ---------- drugs : str Drugs combination as a string, where individual drugs are separated with a plus. doses : str, optional (default: None) Doses corresponding to the drugs combination as a string. Individual drugs are separated with a plus. Returns ------- One hot encodding for a mixture of drugs. """ cell_drugs = np.isin(self.unique_perts, drugs.split('+')) if doses is not None: doses = np.array(doses.split("+")).astype(float) else: doses = np.ones([1, len(drugs.split('+'))]).astype(float) drug_mix = np.zeros([1, self.num_drugs]).astype(float) drug_mix[0, cell_drugs] = doses # drug_mix = np.zeros([1, self.num_drugs]) # atomic_drugs = drugs.split('+') # doses = str(doses) # # if doses is None: # doses_list = [1.0] * len(atomic_drugs) # else: # doses_list = [float(d) for d in str(doses).split('+')] # for j, drug in enumerate(atomic_drugs): # drug_mix += doses_list[j] * self.perts_dict[drug] # return drug_mix
[docs] def mix_drugs(self, drugs_list, doses_list=None): """ Gets a list of drugs combinations to mix, e.g. ['A+B', 'B+C'] and corresponding doses. Parameters ---------- drugs_list : list List of drug combinations, where each drug combination is a string. Individual drugs in the combination are separated with a plus. doses_list : str, optional (default: None) List of corresponding doses, where each dose combination is a string. Individual doses in the combination are separated with a plus. Returns ------- If return_anndata is True, returns anndata structure of the combinations, otherwise returns a np.array of corresponding drug. """ drug_mix = np.zeros([len(drugs_list), self.num_drugs]) for i, drug_combo in enumerate(drugs_list): drug_mix[i] = self.get_drug_encoding_(drug_combo, doses=doses_list[i]) emb = self.model.get_pert_embeddings(torch.Tensor(drug_mix).to( self.model.device)).cpu().clone().detach().numpy() adata = sc.AnnData(emb) adata.obs[self.perturbation_key] = drugs_list adata.obs[self.dose_key] = doses_list return adata
[docs] def latent_dose_response(self, perturbations=None, dose=None, contvar_min=0, contvar_max=1, n_points=100): """ Parameters ---------- perturbations : list List containing two names for which to return complete pairwise dose-response. doses : np.array (default: None) Doses values. If None, default values will be generated on a grid: n_points in range [contvar_min, contvar_max]. contvar_min : float (default: 0) Minimum dose value to generate for default option. contvar_max : float (default: 0) Maximum dose value to generate for default option. n_points : int (default: 100) Number of dose points to generate for default option. Returns ------- pd.DataFrame """ # dosers work only for atomic drugs. TODO add drug combinations if perturbations is None: perturbations = self.unique_perts if dose is None: dose = np.linspace(contvar_min, contvar_max, n_points) n_points = len(dose) df = pd.DataFrame(columns=[self.perturbation_key, self.dose_key, 'response']) for drug in perturbations: d = self.perts_dict[drug] this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1) if self.model.module.doser_type == 'mlp': response = (self.model.module.pert_network.dosers[d](this_drug).sigmoid() * this_drug.gt( 0)).cpu().clone().detach().numpy().reshape(-1) else: response = self.model.module.pert_network.dosers.one_drug(this_drug.view(-1), d).cpu().clone().detach().numpy().reshape(-1) df_drug = pd.DataFrame(list(zip([drug] * n_points, dose, list(response))), columns=[self.perturbation_key, self.dose_key, 'response']) df = pd.concat([df, df_drug], ignore_index=True) return df
[docs] def latent_dose_response2D(self, perturbations, dose=None, contvar_min=0, contvar_max=1, n_points=100, ): """ Parameters ---------- perturbations : list, optional (default: None) List of atomic drugs for which to return latent dose response. Currently drug combinations are not supported. doses : np.array (default: None) Doses values. If None, default values will be generated on a grid: n_points in range [contvar_min, contvar_max]. contvar_min : float (default: 0) Minimum dose value to generate for default option. contvar_max : float (default: 0) Maximum dose value to generate for default option. n_points : int (default: 100) Number of dose points to generate for default option. Returns ------- pd.DataFrame """ # dosers work only for atomic drugs. TODO add drug combinations assert len(perturbations) == 2, "You should provide a list of 2 perturbations." if dose is None: dose = np.linspace(contvar_min, contvar_max, n_points) n_points = len(dose) df = pd.DataFrame(columns=perturbations + ['response']) response = {} for drug in perturbations: d = self.perts_dict[drug] this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1) if self.model.module.doser_type == 'mlp': response[drug] = (self.model.module.pert_network.dosers[d](this_drug).sigmoid() * this_drug.gt( 0)).cpu().clone().detach().numpy().reshape(-1) else: response[drug] = self.model.module.pert_network.dosers.one_drug(this_drug.view(-1), d).cpu().clone().detach().numpy().reshape( -1) l = 0 for i in range(len(dose)): for j in range(len(dose)): df.loc[l] = [dose[i], dose[j], response[perturbations[0]][i] + response[perturbations[1]][j]] l += 1 return df
[docs] def compute_comb_emb(self, thrh=30): """ Generates an AnnData object containing all the latent vectors of the cov+dose*pert combinations seen during training. Called in api.compute_uncertainty(), stores the AnnData in self.comb_emb. Parameters ---------- Returns ------- """ if self.seen_covars_perts['train'] is None: raise ValueError('Need to run parse_training_conditions() first!') emb_covars = None for cov in self.unique_covars.keys(): emb_covars = self.get_covars_embeddings(cov) if emb_covars is None else emb_covars.concatenate( self.get_covars_embeddings(cov)) # Generate adata with all cov+pert latent vect combinations tmp_ad_list = [] for cov_pert in self.seen_covars_perts['train']: if self.num_measured_points['train'][cov_pert] > thrh: *covs_loop, pert_loop, dose_loop = cov_pert.split('_') if len(pert_loop.split('+')) > 1: # combination of drugs drugs_doses = self.get_drug_encoding_(pert_loop, dose_loop) emb_perts = self.model.get_pert_embeddings(drugs_doses) emb_covs = emb_perts else: emb_perts = self.get_pert_embeddings(dose=float(dose_loop)) emb_covs = emb_perts.X[emb_perts.obs.condition == pert_loop] for cov_value in covs_loop: for cov in self.unique_covars.keys(): if cov_value in self.unique_covars[cov]: emb = emb_covars[emb_covars.obs[cov] == cov_value].X emb_covs = emb_covs + emb if emb_covs is not None else emb X = emb_covs tmp_ad = sc.AnnData( X=X ) tmp_ad.obs['cov_pert'] = '_'.join([*covs_loop, pert_loop, dose_loop]) tmp_ad_list.append(tmp_ad) self.comb_emb = tmp_ad_list[0].concatenate(tmp_ad_list[1:])
[docs] def compute_uncertainty( self, covs, pert, dose, thrh=30 ): """ Compute uncertainties for the queried covariate+perturbation combination. The distance from the closest condition in the training set is used as a proxy for uncertainty. Parameters ---------- covs: list of strings Covariates (eg. cell_type) for the queried uncertainty pert: string Perturbation for the queried uncertainty. In case of combinations the format has to be 'pertA+pertB' dose: string String which contains the dose of the perturbation queried. In case of combinations the format has to be 'doseA+doseB' Returns ------- min_cos_dist: float Minimum cosine distance with the training set. min_eucl_dist: float Minimum euclidean distance with the training set. closest_cond_cos: string Closest training condition wrt cosine distances. closest_cond_eucl: string Closest training condition wrt euclidean distances. """ if self.comb_emb is None: self.compute_comb_emb(thrh=30) if len(str(dose).split('+')) > 1: drug_encoded = self.get_drug_encoding_(pert, dose) drug_emb = self.model.get_pert_embeddings(drug_encoded) else: drug_emb = self.model.get_pert_embeddings(dosage=float(dose), pert=pert) cond_emb = drug_emb for cov in covs: for cov_col, cov_col_values in self.unique_covars.items(): if cov in cov_col_values: cond_emb += self.model.get_covar_embeddings(cov_col, cov).reshape(-1, ) break cos_dist = cosine_distances(cond_emb, self.comb_emb.X)[0] min_cos_dist = np.min(cos_dist) cos_idx = np.argmin(cos_dist) closest_cond_cos = self.comb_emb.obs.cov_pert[cos_idx] eucl_dist = euclidean_distances(cond_emb, self.comb_emb.X)[0] min_eucl_dist = np.min(eucl_dist) eucl_idx = np.argmin(eucl_dist) closest_cond_eucl = self.comb_emb.obs.cov_pert[eucl_idx] return min_cos_dist, min_eucl_dist, closest_cond_cos, closest_cond_eucl
[docs] def predict( self, genes, df, uncertainty=True, sample=False, n_samples=10 ): """Predict values of control 'genes' conditions specified in df. Parameters ---------- genes : np.array Control cells. df : pd.DataFrame Values for perturbations and covariates to generate. uncertainty: bool (default: True) Compute uncertainties for the generated cells. sample : bool (default: False) If sample is True, returns samples from gausssian distribution with mean and variance estimated by the model. Otherwise, returns just means and variances estimated by the model. n_samples : int (default: 10) Number of samples to sample if sampling is True. Returns ------- If return_anndata is True, returns anndata structure. Otherwise, returns np.arrays for gene_means, gene_vars and a data frame for the corresponding conditions df_obs. """ num = genes.shape[0] dim = genes.shape[1] if sample: print('Careful! These are sampled values! Better use means and \ variances for downstream tasks!') gene_means_list = [] gene_vars_list = [] df_list = [] for i in range(len(df)): comb_name = df.loc[i][self.perturbation_key] dose_name = df.loc[i][self.dose_key] covars_name = list(df.loc[i][self.covars_key]) feed_adata = AnnData(X=genes, obs={self.perturbation_key: [comb_name] * num, self.dose_key: [dose_name] * num, self.control_key: [0] * num, }) feed_adata.obsm['drugs_doses'] = self.get_drug_encoding_(comb_name, dose_name).repeat(num, axis=0) for idx, covar in enumerate(covars_name): feed_adata.obs[self.covars_key[idx]] = [covar] * num pred_adata_mean, pred_adata_var = self.model.predict(feed_adata) gene_means_list.append(pred_adata_mean.X) gene_vars_list.append(pred_adata_var.X) if sample: df_list.append( pd.DataFrame( [df.loc[i].values] * num * n_samples, columns=df.columns ) ) dist = torch.distributions.normal.Normal( torch.Tensor(pred_adata_mean.X), torch.Tensor(pred_adata_var.X), ) gene_means_list.append( dist.sample(torch.Size([n_samples])) .cpu() .detach() .numpy() .reshape(-1, dim) ) else: df_list.append( pd.DataFrame( [df.loc[i].values] * num, columns=df.columns ) ) if uncertainty: cos_dist, eucl_dist, closest_cond_cos, closest_cond_eucl = \ self.compute_uncertainty( covs=covars_name, pert=comb_name, dose=dose_name ) df_list[-1] = df_list[-1].assign( uncertainty_cosine=cos_dist, uncertainty_euclidean=eucl_dist, closest_cond_cosine=closest_cond_cos, closest_cond_euclidean=closest_cond_eucl ) gene_means = np.concatenate(gene_means_list) gene_vars = np.concatenate(gene_vars_list) df_obs = pd.concat(df_list) del df_list, gene_means_list, gene_vars_list adata = sc.AnnData(gene_means) adata.var_names = self.var_names adata.obs = df_obs if not sample: adata.layers["variance"] = gene_vars adata.obs.index = adata.obs.index.astype(str) # type fix return adata
[docs] def get_response( self, doses=None, contvar_min=None, contvar_max=None, n_points=50, ncells_max=100, perturbations=None, control_name='test_control' ): """Decoded dose response data frame. Parameters ---------- dataset : CompPertDataset The file location of the spreadsheet doses : np.array (default: None) Doses values. If None, default values will be generated on a grid: n_points in range [contvar_min, contvar_max]. contvar_min : float (default: 0) Minimum dose value to generate for default option. contvar_max : float (default: 0) Maximum dose value to generate for default option. n_points : int (default: 100) Number of dose points to generate for default option. perturbations : list (default: None) List of perturbations for dose response Returns ------- pd.DataFrame of decoded response values of genes and average response. """ if contvar_min is None: contvar_min = 0.0 if contvar_max is None: contvar_max = 1.0 # doses = torch.Tensor(np.linspace(contvar_min, contvar_max, n_points)) if doses is None: doses = np.linspace(contvar_min, contvar_max, n_points) if perturbations is None: perturbations = self.unique_perts response = pd.DataFrame(columns=[*self.covars_key, self.perturbation_key, self.dose_key, 'response'] + list(self.var_names)) split, control_value = control_name.split("_") i = 0 for ict, ct in enumerate(self.unique_covars['cell_type']): genes_control = self.adatas[split][ self.adatas[split].obs[self.control_key] == 1 if control_value == 'control' else 0] genes_control = genes_control[genes_control.obs[self.covars_key[0]] == ct] if len(genes_control) < 1: print('Warning! Not enought control cells for this covariate.\ Taking control cells from all covariates.') # genes_control = datasets[control_name].genes if ncells_max < len(genes_control): ncells_max = min(ncells_max, len(genes_control)) idx = np.random.choice(range(len(genes_control)), ncells_max, replace=False) genes_control = genes_control[idx] num, dim = genes_control.shape[0], genes_control.shape[1] if not isinstance(genes_control.X, np.ndarray): control_avg = genes_control.X.toarray().mean(axis=0).reshape(-1) else: control_avg = genes_control.X.mean(axis=0).reshape(-1) for idr, drug in enumerate(perturbations): # if not (drug in genes_control.obs[self.perturbation_key].unique()): for dose in doses: df = pd.DataFrame(data={self.covars_key[0]: [ct], self.perturbation_key: [drug], self.dose_key: [str(dose)]}) gene_means_adata = self.predict(genes_control.X, df) predicted_data = np.mean(gene_means_adata.X, axis=0).reshape(-1) response.loc[i] = [ct, drug, dose, np.linalg.norm(predicted_data - control_avg)] + \ list(predicted_data - control_avg) i += 1 return response
[docs] def get_response_reference( self, perturbations=None ): """Computes reference values of the response. Parameters ---------- perturbations : list (default: None) List of perturbations for dose response Returns ------- pd.DataFrame of decoded response values of genes and average response. """ if perturbations is None: perturbations = self.unique_perts reference_response_curve = pd.DataFrame(columns=[*self.covars_key, self.perturbation_key, self.dose_key, 'split', 'num_cells', 'response'] + \ list(self.var_names)) dataset_ctr = self.adatas['train'][self.adatas['train'].obs[self.control_key] == 1] dataset_trt = self.adatas['train'][self.adatas['train'].obs[self.control_key] == 0] self.adatas['training_control'] = dataset_ctr self.adatas['training_treated'] = dataset_trt self.seen_covars_perts['training_control'] = np.unique( self.adatas['training_control'].obs[self.pert_categories_key]) self.seen_covars_perts['training_treated'] = np.unique( self.adatas['training_treated'].obs[self.pert_categories_key]) i = 0 for split in ['training_treated', 'ood']: dataset = self.adatas[split] for pert in self.seen_covars_perts[split]: ct, drug, dose_val = pert.split('_') if drug in perturbations: if not ('+' in dose_val): dose = float(dose_val) else: dose = dose_val genes_control = dataset_ctr[dataset_ctr.obs[self.covars_key[0]] == ct].X if not isinstance(genes_control, np.ndarray): genes_control = genes_control.toarray() if len(genes_control) < 1: print('Warning! Not enough control cells for this covariate. \ Taking control cells from all covariates.') genes_control = dataset_ctr.X num, dim = genes_control.shape[0], genes_control.shape[1] control_avg = genes_control.mean(axis=0).reshape(-1) idx = np.where((dataset.obs[self.pert_categories_key] == pert))[0] if len(idx): if not isinstance(dataset.X, np.ndarray): y_true = dataset.X.toarray()[idx, :].mean(axis=0) else: y_true = dataset.X[idx, :].mean(axis=0) new_row = [ct, drug, dose, split, len(idx), np.linalg.norm(y_true - control_avg)] \ + list(y_true - control_avg) reference_response_curve.loc[i] = new_row i += 1 return reference_response_curve
[docs] def get_response2D( self, perturbations, covar, doses=None, contvar_min=None, contvar_max=None, n_points=10, ncells_max=100, fixed_drugs='', fixed_doses='' ): """Decoded dose response data frame. Parameters ---------- perturbations : list List of length 2 of perturbations for dose response. covar : str Name of a covariate for which to compute dose-response. doses : np.array (default: None) Doses values. If None, default values will be generated on a grid: n_points in range [contvar_min, contvar_max]. contvar_min : float (default: 0) Minimum dose value to generate for default option. contvar_max : float (default: 0) Maximum dose value to generate for default option. n_points : int (default: 100) Number of dose points to generate for default option. Returns ------- pd.DataFrame of decoded response values of genes and average response. """ assert len(perturbations) == 2, "You should provide a list of 2 perturbations." if contvar_min is None: contvar_min = 0.0 if contvar_max is None: contvar_max = 1.0 # doses = torch.Tensor(np.linspace(contvar_min, contvar_max, n_points)) if doses is None: doses = np.linspace(contvar_min, contvar_max, n_points) test_control_adata = self.adatas['test'][self.adatas['test'].obs[self.control_key] == 1] genes_control = test_control_adata[test_control_adata.obs[self.covars_key[0]] == covar] if len(genes_control) < 1: print('Warning! Not enought control cells for this covariate. \ Taking control cells from all covariates.') ncells_max = min(ncells_max, len(genes_control)) idx = np.random.choice(range(len(genes_control)), ncells_max) genes_control = genes_control[idx] num, dim = genes_control.shape[0], genes_control.shape[1] if not isinstance(genes_control.X, np.ndarray): control_avg = genes_control.X.toarray().mean(0).reshape(-1) else: control_avg = genes_control.X.mean(0).reshape(-1) response = pd.DataFrame(columns=perturbations + ['response'] + list(self.var_names)) drug = perturbations[0] + '+' + perturbations[1] dose_vals = [f"{d[0]}+{d[1]}" for d in itertools.product(*[doses, doses])] dose_comb = [list(d) for d in itertools.product(*[doses, doses])] i = 0 if not (drug in ['Vehicle', 'EGF', 'unst', 'control', 'ctrl']): for dose in dose_vals: df = pd.DataFrame(data={self.perturbation_key: [drug + fixed_drugs], self.dose_key: [dose + fixed_doses], self.covars_key[0]: [covar]}) gene_means_adata = self.predict(genes_control.X, df) predicted_data = np.mean(gene_means_adata.X, axis=0).reshape(-1) response.loc[i] = [*dose_comb[i], np.linalg.norm(control_avg - predicted_data)] + \ list(predicted_data - control_avg) i += 1 return response
[docs] def get_cycle_uncertainty( self, genes_from, df_from, df_to, ncells_max=100, direction='forward' ): """Uncertainty for a single condition. Parameters ---------- genes_from: torch.Tensor Genes for comparison. df_from: pd.DataFrame Full description of the condition. df_to: pd.DataFrame Full description of the control condition. ncells_max: int, optional (defaul: 100) Max number of cells to use. Returns ------- tuple with uncertainty estimations: (MSE, 1-R2). """ self.model.eval() genes_control = genes_from.clone().detach() if ncells_max < len(genes_control): idx = torch.LongTensor(np.random.choice(range(len(genes_control)), \ ncells_max, replace=False)) genes_control = genes_control[idx] gene_condition, _, _ = self.predict(genes_control, df_to, \ return_anndata=False, sample=False) gene_condition = torch.Tensor(gene_condition).clone().detach() gene_return, _, _ = self.predict(gene_condition, df_from, \ return_anndata=False, sample=False) if direction == 'forward': # control -> condition -> control' genes_control = genes_control.numpy() ctr = np.mean(genes_control, axis=0) ret = np.mean(gene_return, axis=0) return np.mean((genes_control - gene_return) ** 2), 1 - r2_score(ctr, ret) else: # control -> condition -> control' -> condition' gene_return = torch.Tensor(gene_return).clone().detach() gene_condition_return, _, _ = self.predict(gene_return, df_to, \ return_anndata=False, sample=False) gene_condition = gene_condition.numpy() ctr = np.mean(gene_condition, axis=0) ret = np.mean(gene_condition_return, axis=0) return np.mean((gene_condition - gene_condition_return) ** 2), \ 1 - r2_score(ctr, ret)
[docs] def print_complete_cycle_uncertainty( self, datasets, datasets_ctr, ncells_max=1000, split_list=['test', 'ood'], direction='forward' ): uncert = pd.DataFrame(columns=[self.covars_key, self.perturbation_key, self.dose_key, 'split', 'MSE', '1-R2']) ctr_covar, ctrl_name, ctr_dose = datasets_ctr.pert_categories[0].split('_') df_ctrl = pd.DataFrame({self.perturbation_key: [ctrl_name], self.dose_key: [ctr_dose], self.covars_key: [ctr_covar]}) i = 0 for split in split_list: dataset = datasets[split] print(split) for pert_cat in np.unique(dataset.pert_categories): idx = np.where(dataset.pert_categories == pert_cat)[0] genes = dataset.genes[idx, :] covar, pert, dose = pert_cat.split('_') df_cond = pd.DataFrame({self.perturbation_key: [pert], self.dose_key: [dose], self.covars_key: [covar]}) if direction == 'back': # condition -> control -> condition uncert.loc[i] = [covar, pert, dose, split] + \ list(self.get_cycle_uncertainty(genes, df_cond, \ df_ctrl, ncells_max=ncells_max)) else: # control -> condition -> control uncert.loc[i] = [covar, pert, dose, split] + \ list(self.get_cycle_uncertainty(datasets_ctr.genes, \ df_ctrl, df_cond, ncells_max=ncells_max, \ direction=direction)) i += 1 return uncert
[docs] def evaluate_r2( self, perturbations=None, control_adata_key: str = 'test', ): """ Measures different quality metrics about an ComPert `autoencoder`, when tasked to translate some `genes_control` into each of the drug/cell_type combinations described in `dataset`. Considered metrics are R2 score about means and variances for all genes, as well as R2 score about means and variances about differentially expressed (_de) genes. """ if perturbations is None: perturbations = self.unique_perts scores = pd.DataFrame(columns=[*self.covars_key, self.perturbation_key, self.dose_key, 'R2_mean', 'R2_mean_DE', 'R2_var', 'R2_var_DE', 'num_cells']) control_adata = self.adatas[control_adata_key].copy() control_adata = control_adata[control_adata.obs[self.control_key] == 1] icond = 0 for pert_category in np.unique(self.adata.obs[self.pert_categories_key].value_counts().index): # pert_category category contains: 'celltype_perturbation_dose' info *covs, drug, dose = pert_category.split('_') if drug in perturbations: if self.de_genes: de_genes = self.de_genes[pert_category] else: de_genes = list(self.adata.var_names) true_adata = self.adata[self.adata.obs[self.pert_categories_key] == pert_category] control_adata_ct = control_adata[control_adata.obs[self.covars_key[0]] == covs[0]] feed_adata = sc.AnnData(X=control_adata_ct.X) feed_adata.var_names = control_adata_ct.var_names feed_adata.obs_names = control_adata_ct.obs_names feed_adata.obs = control_adata_ct.obs.copy() feed_adata.obs[self.perturbation_key] = drug feed_adata.obs[self.dose_key] = dose feed_adata.obsm['drugs_doses'] = self.get_drug_encoding_(drug, dose).repeat(feed_adata.n_obs, axis=0) de_idx = np.where(self.adata.var_names.isin(np.array(de_genes)))[0] if len(true_adata) > 0: pred_mean_adata, pred_var_adata = self.model.predict(feed_adata, batch_size=512) # estimate metrics only for reasonably-sized drug/cell-type combos y_true = true_adata.X if not isinstance(y_true, np.ndarray): y_true = y_true.toarray() # true means and variances yt_m = y_true.mean(axis=0) yt_v = y_true.var(axis=0) # predicted means and variances yp_m = pred_mean_adata.X.mean(0) yp_v = pred_var_adata.X.mean(0) mean_score = r2_score(yt_m, yp_m) var_score = r2_score(yt_v, yp_v) mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx]) var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx]) scores.loc[icond] = pert_category.split('_') + \ [mean_score, mean_score_de, var_score, var_score_de, y_true.shape[0]] icond += 1 return scores
def get_reference_from_combo( perturbations_list, datasets, splits=['training', 'ood'] ): """ A simple function that produces a pd.DataFrame of individual drugs-doses combinations used among the splits (for a fixed covariate). """ df_list = [] for split_name in splits: full_dataset = datasets[split_name] ref = {'num_cells': []} for pp in perturbations_list: ref[pp] = [] ndrugs = len(perturbations_list) for pert_cat in np.unique(full_dataset.pert_categories): _, pert, dose = pert_cat.split('_') pert_list = pert.split('+') if set(pert_list) == set(perturbations_list): dose_list = dose.split('+') ncells = len(full_dataset.pert_categories[ full_dataset.pert_categories == pert_cat]) for j in range(ndrugs): ref[pert_list[j]].append(float(dose_list[j])) ref['num_cells'].append(ncells) print(pert, dose, ncells) df = pd.DataFrame.from_dict(ref) df['split'] = split_name df_list.append(df) return pd.concat(df_list) def linear_interp(y1, y2, x1, x2, x): a = (y1 - y2) / (x1 - x2) b = y1 - a * x1 y = a * x + b return y def evaluate_r2_benchmark( compert_api, datasets, pert_category, pert_category_list ): scores = pd.DataFrame(columns=[compert_api.covars_key, compert_api.perturbation_key, compert_api.dose_key, 'R2_mean', 'R2_mean_DE', 'R2_var', 'R2_var_DE', 'num_cells', 'benchmark', 'method']) de_idx = np.where( datasets['ood'].var_names.isin( np.array(datasets['ood'].de_genes[pert_category])))[0] idx = np.where(datasets['ood'].pert_categories == pert_category)[0] y_true = datasets['ood'].genes[idx, :].numpy() # true means and variances yt_m = y_true.mean(axis=0) yt_v = y_true.var(axis=0) icond = 0 if len(idx) > 0: for pert_category_predict in pert_category_list: if '+' in pert_category_predict: pert1, pert2 = pert_category_predict.split('+') idx_pred1 = np.where(datasets['training'].pert_categories == \ pert1)[0] idx_pred2 = np.where(datasets['training'].pert_categories == \ pert2)[0] y_pred1 = datasets['training'].genes[idx_pred1, :].numpy() y_pred2 = datasets['training'].genes[idx_pred2, :].numpy() x1 = float(pert1.split('_')[2]) x2 = float(pert2.split('_')[2]) x = float(pert_category.split('_')[2]) yp_m1 = y_pred1.mean(axis=0) yp_m2 = y_pred2.mean(axis=0) yp_v1 = y_pred1.var(axis=0) yp_v2 = y_pred2.var(axis=0) yp_m = linear_interp(yp_m1, yp_m2, x1, x2, x) yp_v = linear_interp(yp_v1, yp_v2, x1, x2, x) # yp_m = (y_pred1.mean(axis=0) + y_pred2.mean(axis=0))/2 # yp_v = (y_pred1.var(axis=0) + y_pred2.var(axis=0))/2 else: idx_pred = np.where(datasets['training'].pert_categories == \ pert_category_predict)[0] print(pert_category_predict, len(idx_pred)) y_pred = datasets['training'].genes[idx_pred, :].numpy() # predicted means and variances yp_m = y_pred.mean(axis=0) yp_v = y_pred.var(axis=0) mean_score = r2_score(yt_m, yp_m) var_score = r2_score(yt_v, yp_v) mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx]) var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx]) scores.loc[icond] = pert_category.split('_') + \ [mean_score, mean_score_de, var_score, var_score_de, \ len(idx), pert_category_predict, 'benchmark'] icond += 1 return scores