Source code for cpa._module

import numpy as np
import torch
import torch.nn as nn
from scvi import settings
from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
from scvi.module import Classifier
from scvi.module.base import BaseModuleClass, auto_move_data
from scvi.nn import Encoder, DecoderSCVI
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence as kl
from torchmetrics.functional import accuracy, pearson_corrcoef, r2_score

from ._metrics import knn_purity
from ._utils import PerturbationNetwork, VanillaEncoder, CPA_REGISTRY_KEYS

from typing import Optional


[docs] class CPAModule(BaseModuleClass): """ CPA module using Gaussian/NegativeBinomial/Zero-InflatedNegativeBinomial Likelihood Parameters ---------- n_genes: int Number of input genes n_perts: int, Number of total unique perturbations covars_encoder: dict Dictionary of covariates with keys as each covariate name and values as unique values of the corresponding covariate n_latent: int dimensionality of the latent space recon_loss: str Autoencoder loss (either "gauss", "nb" or "zinb") doser_type: str Type of dosage network (either "logsigm", "sigm", or "linear") n_hidden_encoder: int Number of hidden units in encoder n_layers_encoder: int Number of layers in encoder n_hidden_decoder: int Number of hidden units in decoder n_layers_decoder: int Number of layers in decoder n_hidden_doser: int Number of hidden units in dosage network n_layers_doser: int Number of layers in dosage network use_batch_norm_encoder: bool Whether to use batch norm in encoder use_layer_norm_encoder: bool Whether to use layer norm in encoder use_batch_norm_decoder: bool Whether to use batch norm in decoder use_layer_norm_decoder: bool Whether to use layer norm in decoder dropout_rate_encoder: float Dropout rate in encoder dropout_rate_decoder: float Dropout rate in decoder variational: bool Whether to use variational inference seed: int Random seed """ def __init__(self, n_genes: int, n_perts: int, covars_encoder: dict, drug_embeddings: Optional[np.ndarray] = None, n_latent: int = 128, recon_loss: str = "nb", doser_type: str = "logsigm", n_hidden_encoder: int = 256, n_layers_encoder: int = 3, n_hidden_decoder: int = 256, n_layers_decoder: int = 3, n_hidden_doser: int = 128, n_layers_doser: int = 2, use_batch_norm_encoder: bool = True, use_layer_norm_encoder: bool = False, use_batch_norm_decoder: bool = True, use_layer_norm_decoder: bool = False, dropout_rate_encoder: float = 0.0, dropout_rate_decoder: float = 0.0, variational: bool = False, seed: int = 0, ): super().__init__() recon_loss = recon_loss.lower() assert recon_loss in ['gauss', 'zinb', 'nb'] torch.manual_seed(seed) np.random.seed(seed) settings.seed = seed self.n_genes = n_genes self.n_perts = n_perts self.n_latent = n_latent self.recon_loss = recon_loss self.doser_type = doser_type self.variational = variational self.covars_encoder = covars_encoder if variational: self.encoder = Encoder( n_genes, n_latent, var_activation=nn.Softplus(), n_hidden=n_hidden_encoder, n_layers=n_layers_encoder, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, dropout_rate=dropout_rate_encoder, activation_fn=nn.ReLU, return_dist=True, ) else: self.encoder = VanillaEncoder( n_input=n_genes, n_output=n_latent, n_cat_list=[], n_hidden=n_hidden_encoder, n_layers=n_layers_encoder, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, dropout_rate=dropout_rate_encoder, activation_fn=nn.ReLU, output_activation='linear', ) # Decoder components if self.recon_loss in ['zinb', 'nb']: # setup the parameters of your generative model, as well as your inference model self.px_r = torch.nn.Parameter(torch.randn(self.n_genes)) # decoder goes from n_latent-dimensional space to n_input-d data self.decoder = DecoderSCVI( n_input=n_latent, n_output=n_genes, n_layers=n_layers_decoder, n_hidden=n_hidden_decoder, use_batch_norm=use_batch_norm_decoder, use_layer_norm=use_layer_norm_decoder, ) elif recon_loss == "gauss": self.decoder = Encoder(n_input=n_latent, n_output=n_genes, n_layers=n_layers_decoder, n_hidden=n_hidden_decoder, dropout_rate=dropout_rate_decoder, use_batch_norm=use_batch_norm_decoder, use_layer_norm=use_layer_norm_decoder, var_activation=None, ) else: raise Exception('Invalid Loss function for Autoencoder') # Embeddings # 1. Drug Network self.pert_network = PerturbationNetwork(n_perts=n_perts, n_latent=n_latent, doser_type=doser_type, n_hidden=n_hidden_doser, n_layers=n_layers_doser, drug_embeddings=drug_embeddings, ) # 2. Covariates Embedding self.covars_embeddings = nn.ModuleDict( { key: torch.nn.Embedding(len(unique_covars), n_latent) for key, unique_covars in self.covars_encoder.items() } ) self.metrics = { 'pearson_r': pearson_corrcoef, 'r2_score': r2_score }
[docs] def mixup_data(self, tensors, alpha: float = 0.0, opt=False): """ Returns mixed inputs, pairs of targets, and lambda """ alpha = max(0.0, alpha) if alpha == 0.0: mixup_lambda = 1.0 else: mixup_lambda = np.random.beta(alpha, alpha) x = tensors[CPA_REGISTRY_KEYS.X_KEY] y_perturbations = tensors[CPA_REGISTRY_KEYS.PERTURBATION_KEY] perturbations = tensors[CPA_REGISTRY_KEYS.PERTURBATIONS] perturbations_dosages = tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES] batch_size = x.size()[0] index = torch.randperm(batch_size).to(x.device) mixed_x = mixup_lambda * x + (1. - mixup_lambda) * x[index, :] tensors[CPA_REGISTRY_KEYS.X_KEY] = mixed_x tensors[CPA_REGISTRY_KEYS.X_KEY + '_true'] = x tensors[CPA_REGISTRY_KEYS.X_KEY + '_mixup'] = x[index] tensors[CPA_REGISTRY_KEYS.PERTURBATION_KEY + '_mixup'] = y_perturbations[index] tensors[CPA_REGISTRY_KEYS.PERTURBATIONS + '_mixup'] = perturbations[index] tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES + '_mixup'] = perturbations_dosages[index] for covar, encoder in self.covars_encoder.items(): tensors[covar + '_mixup'] = tensors[covar][index] return tensors, mixup_lambda
def _get_inference_input(self, tensors): x = tensors[CPA_REGISTRY_KEYS.X_KEY] # batch_size, n_genes perts = { 'true': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS], 'mixup': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS + '_mixup'] } perts_doses = { 'true': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES], 'mixup': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES + '_mixup'], } covars_dict = dict() for covar, unique_covars in self.covars_encoder.items(): encoded_covars = tensors[covar].view(-1, ) # (batch_size,) encoded_covars_mixup = tensors[covar + '_mixup'].view(-1, ) # (batch_size,) covars_dict[covar] = encoded_covars covars_dict[covar + '_mixup'] = encoded_covars_mixup return dict( x=x, perts=perts, perts_doses=perts_doses, covars_dict=covars_dict, )
[docs] @auto_move_data def inference( self, x, perts, perts_doses, covars_dict, mixup_lambda: float = 1.0, n_samples: int = 1, covars_to_add: Optional[list] = None, ): batch_size = x.shape[0] if self.recon_loss in ['nb', 'zinb']: # log the input to the variational distribution for numerical stability x_ = torch.log(1 + x) library = torch.log(x.sum(1)).unsqueeze(1) else: x_ = x library = None, None if self.variational: qz, z_basal = self.encoder(x_) else: qz, z_basal = None, self.encoder(x_) if self.variational and n_samples > 1: sampled_z = qz.sample((n_samples,)) z_basal = self.encoder.z_transformation(sampled_z) if self.recon_loss in ['nb', 'zinb']: library = library.unsqueeze(0).expand( (n_samples, library.size(0), library.size(1)) ) z_pert_true = self.pert_network(perts['true'], perts_doses['true']) if mixup_lambda < 1.0: z_pert_mixup = self.pert_network(perts['mixup'], perts_doses['mixup']) z_pert = mixup_lambda * z_pert_true + (1. - mixup_lambda) * z_pert_mixup else: z_pert = z_pert_true z_covs = torch.zeros_like(z_basal) # ([n_samples,] batch_size, n_latent) z_covs_wo_batch = torch.zeros_like(z_basal) # ([n_samples,] batch_size, n_latent) batch_key = CPA_REGISTRY_KEYS.BATCH_KEY if covars_to_add is None: covars_to_add = list(self.covars_encoder.keys()) for covar, encoder in self.covars_encoder.items(): if covar in covars_to_add: z_cov = self.covars_embeddings[covar](covars_dict[covar].long()) if len(encoder) > 1: z_cov_mixup = self.covars_embeddings[covar](covars_dict[covar + '_mixup'].long()) z_cov = mixup_lambda * z_cov + (1. - mixup_lambda) * z_cov_mixup z_cov = z_cov.view(batch_size, self.n_latent) # batch_size, n_latent z_covs += z_cov if covar != batch_key: z_covs_wo_batch += z_cov z = z_basal + z_pert + z_covs z_corrected = z_basal + z_pert + z_covs_wo_batch z_no_pert = z_basal + z_covs z_no_pert_corrected = z_basal + z_covs_wo_batch return dict( z=z, z_corrected=z_corrected, z_no_pert=z_no_pert, z_no_pert_corrected=z_no_pert_corrected, z_basal=z_basal, z_covs=z_covs, z_pert=z_pert.sum(dim=1), library=library, qz=qz, mixup_lambda=mixup_lambda, )
def _get_generative_input(self, tensors, inference_outputs, **kwargs): if 'latent' in kwargs.keys(): if kwargs['latent'] in inference_outputs.keys(): # z, z_corrected, z_no_pert, z_no_pert_corrected, z_basal z = inference_outputs[kwargs['latent']] else: raise Exception('Invalid latent space') else: z = inference_outputs["z"] library = inference_outputs['library'] return dict( z=z, library=library, )
[docs] @auto_move_data def generative( self, z, library=None, ): if self.recon_loss == 'nb': px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) px_r = torch.exp(self.px_r) px = NegativeBinomial(mu=px_rate, theta=px_r) elif self.recon_loss == 'zinb': px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) px_r = torch.exp(self.px_r) px = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) else: px_mean, px_var, x_pred = self.decoder(z) px = Normal(loc=px_mean, scale=px_var.sqrt()) pz = Normal(torch.zeros_like(z), torch.ones_like(z)) return dict(px=px, pz=pz)
[docs] def loss(self, tensors, inference_outputs, generative_outputs): """Computes the reconstruction loss (AE) or the ELBO (VAE)""" x = tensors[CPA_REGISTRY_KEYS.X_KEY] px = generative_outputs['px'] recon_loss = -px.log_prob(x).sum(dim=-1).mean() if self.variational: qz = inference_outputs["qz"] pz = generative_outputs['pz'] kl_divergence_z = kl(qz, pz).sum(dim=1) kl_loss = kl_divergence_z.mean() else: from scvi.model import SCVI kl_loss = torch.zeros_like(recon_loss) return recon_loss, kl_loss
[docs] def r2_metric(self, tensors, inference_outputs, generative_outputs, mode: str = 'lfc'): mode = mode.lower() assert mode in ['direct'] x = tensors[CPA_REGISTRY_KEYS.X_KEY] # batch_size, n_genes indices = tensors[CPA_REGISTRY_KEYS.CATEGORY_KEY].view(-1,) unique_indices = indices.unique() r2_mean = 0.0 r2_var = 0.0 px = generative_outputs['px'] for ind in unique_indices: i_mask = indices == ind x_i = x[i_mask, :] if self.recon_loss == 'gauss': x_pred_mean = px.loc[i_mask, :] x_pred_var = px.scale[i_mask, :] ** 2 if CPA_REGISTRY_KEYS.DEG_MASK_R2 in tensors.keys(): deg_mask = tensors[f'{CPA_REGISTRY_KEYS.DEG_MASK_R2}'][i_mask, :] x_i *= deg_mask x_pred_mean *= deg_mask x_pred_var *= deg_mask x_pred_mean = torch.nan_to_num(x_pred_mean, nan=0, posinf=1e3, neginf=-1e3) x_pred_var = torch.nan_to_num(x_pred_var, nan=0, posinf=1e3, neginf=-1e3) r2_mean += torch.nan_to_num(self.metrics['r2_score'](x_pred_mean.mean(0), x_i.mean(0)), nan=0.0).item() r2_var += torch.nan_to_num(self.metrics['r2_score'](x_pred_var.mean(0), x_i.var(0)), nan=0.0).item() elif self.recon_loss in ['nb', 'zinb']: x_i = torch.log(1 + x_i) x_pred = px.mu[i_mask, :] x_pred = torch.log(1 + x_pred) x_pred = torch.nan_to_num(x_pred, nan=0, posinf=1e3, neginf=-1e3) if CPA_REGISTRY_KEYS.DEG_MASK_R2 in tensors.keys(): deg_mask = tensors[f'{CPA_REGISTRY_KEYS.DEG_MASK_R2}'][i_mask, :] x_i *= deg_mask x_pred *= deg_mask r2_mean += torch.nan_to_num(self.metrics['r2_score'](x_pred.mean(0), x_i.mean(0)), nan=0.0).item() r2_var += torch.nan_to_num(self.metrics['r2_score'](x_pred.var(0), x_i.var(0)), nan=0.0).item() n_unique_indices = len(unique_indices) return r2_mean / n_unique_indices, r2_var / n_unique_indices
[docs] def disentanglement(self, tensors, inference_outputs, generative_outputs, linear=True): z_basal = inference_outputs['z_basal'].detach().cpu().numpy() z = inference_outputs['z'].detach().cpu().numpy() perturbations = tensors[CPA_REGISTRY_KEYS.PERTURBATION_KEY].view(-1, ) perturbations_names = perturbations.detach().cpu().numpy() knn_basal = knn_purity(z_basal, perturbations_names.ravel(), n_neighbors=min(perturbations_names.shape[0] - 1, 30)) knn_after = knn_purity(z, perturbations_names.ravel(), n_neighbors=min(perturbations_names.shape[0] - 1, 30)) for covar, unique_covars in self.covars_encoder.items(): if len(unique_covars) > 1: target_covars = tensors[f'{covar}'].detach().cpu().numpy() knn_basal += knn_purity(z_basal, target_covars.ravel(), n_neighbors=min(target_covars.shape[0] - 1, 30)) knn_after += knn_purity(z, target_covars.ravel(), n_neighbors=min(target_covars.shape[0] - 1, 30)) return knn_basal, knn_after
[docs] def get_expression(self, tensors, n_samples=1, covars_to_add=None, latent='z'): """Computes gene expression means and std. Only implemented for the gaussian likelihood. Parameters ---------- tensors : dict Considered inputs """ tensors, _ = self.mixup_data(tensors, alpha=0.0) inference_outputs, generative_outputs = self.forward( tensors, inference_kwargs={'n_samples': n_samples, 'covars_to_add': covars_to_add}, get_generative_input_kwargs={'latent': latent}, compute_loss=False, ) z = inference_outputs['z'] z_corrected = inference_outputs['z_corrected'] z_no_pert = inference_outputs['z_no_pert'] z_no_pert_corrected = inference_outputs['z_no_pert_corrected'] z_basal = inference_outputs['z_basal'] px = generative_outputs['px'] if self.recon_loss == 'gauss': output_key = 'loc' else: output_key = 'mu' reconstruction = getattr(px, output_key) return dict( px=reconstruction, z=z, z_corrected=z_corrected, z_no_pert=z_no_pert, z_no_pert_corrected=z_no_pert_corrected, z_basal=z_basal, )
[docs] def get_pert_embeddings(self, tensors, **inference_kwargs): inputs = self._get_inference_input(tensors) drugs = inputs['perts'] doses = inputs['perts_doses'] return self.pert_network(drugs, doses)