GRouNdGAN: GRN-guided simulation of single-cell RNA-seq data using causal generative adversarial networks – Nature Communications

Datasets and preprocessing

We downloaded the human peripheral blood mononuclear cell (PBMC Donor A) dataset, containing the single-cell gene expression profiles of 68,579 PBMCs represented by UMI counts, from 10x Genomics ( This dataset contains a large number of well-annotated cells and has been used by other models to generate synthetic scRNA-seq samples24. We also downloaded the scRNA-seq profiles of 2730 cells corresponding to the differentiation of hematopoietic stem cells into different lineages from mouse bone marrow34 (“BoneMarrow” dataset) from the Gene Expression Omnibus (GEO) (accession number: GSE72857). We also obtained another hematopoietic dataset, corresponding to the scRNA-seq (10x Genomics) profiles of 44,802 mouse bone marrow hematopoietic stem and progenitor cells (HSPCs) differentiating towards different lineages, from GEO (accession number: GSE107727)46 (called the Dahlin dataset here). Finally, we obtained the batch-corrected scRNA-seq (10x Genomics) profiles of 136,147 cells, comprising malignant cells as well as cells in the tumor microenvironment (called the Tumor-All dataset here), from 20 fresh core needle biopsies of follicular lymphoma patients from (

We followed pre-processing steps similar to those of scGAN and cscGAN24 using scanpy (version 1.8.2)45. In each dataset, cells with nonzero counts in fewer than ten genes were removed. Similarly, genes with nonzero counts in fewer than three cells were discarded. The top 1000 highly variable genes were then selected using the dispersion-based method described by Satija et al.54. Finally, library-size normalization was performed on the counts of each cell with a library size of 20,000, to be consistent with previous studies24. See Supplementary Data 1 – Sheet 3 for the number of cells, highly variable genes, and TFs present in each final dataset.
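
The filtering and normalization steps above can be sketched in NumPy as follows. This is a minimal illustration, not the scanpy calls used in the study: the highly-variable-gene step ranks genes by a plain dispersion (variance over mean) of log-normalized data, a simplified stand-in for the binned, z-scored dispersion method of Satija et al., and normalization is assumed to be applied to the retained genes.

```python
import numpy as np

def preprocess(counts, min_genes=10, min_cells=3, n_top=1000, target_sum=20_000):
    """Filter cells and genes, keep the top-dispersion genes, normalize library size.

    counts: cells x genes array of raw UMI counts.
    Returns the normalized matrix and the indices of the kept genes
    (relative to the gene-filtered matrix).
    """
    counts = np.asarray(counts, dtype=float)
    # Drop cells with nonzero counts in fewer than `min_genes` genes.
    counts = counts[(counts > 0).sum(axis=1) >= min_genes]
    # Drop genes with nonzero counts in fewer than `min_cells` cells.
    counts = counts[:, (counts > 0).sum(axis=0) >= min_cells]

    # Simplified dispersion ranking on log1p of per-cell normalized counts
    # (an assumption, not scanpy's binned "seurat" flavor).
    lib = counts.sum(axis=1, keepdims=True)
    per_cell = np.divide(counts, lib, out=np.zeros_like(counts), where=lib > 0)
    logn = np.log1p(per_cell * 1e4)
    mean = logn.mean(axis=0)
    disp = np.divide(logn.var(axis=0), mean, out=np.zeros_like(mean), where=mean > 0)
    hvg = np.sort(np.argsort(disp)[::-1][:n_top])
    counts = counts[:, hvg]

    # Library-size normalization: each cell's counts sum to `target_sum`.
    totals = counts.sum(axis=1, keepdims=True)
    norm = np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0) * target_sum
    return norm, hvg
```

With scanpy itself, the corresponding calls would be `sc.pp.filter_cells(adata, min_genes=10)`, `sc.pp.filter_genes(adata, min_cells=3)`, `sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="seurat")` and `sc.pp.normalize_total(adata, target_sum=20000)`; the exact order and arguments used in the study are not stated here.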

GRouNdGAN’s model architecture

GRouNdGAN’s architecture consists of 5 components, each implemented using separately parameterized neural networks: a causal controller, target generators, a critic, a labeler, and an anti-labeler (Fig. 1C).

Causal controller

The role of the causal controller is to generate the expression of TFs that causally control the expression of their target genes based on a user-defined gene regulatory network (Fig. 1C). To achieve this, it is first pre-trained (Fig. 1B) as the generator of a Wasserstein GAN with gradient penalty (WGAN-GP) (see Supplementary Notes for the formulation of the Wasserstein distance). GANs are a class of deep learning models that can learn to simulate data from distributions without assuming a parametric form55. They typically involve simultaneously training a generative model (the “generator”), which produces new samples from noise, and its adversary, a discriminative model (the “discriminator”), which tries to distinguish between real and generated samples. The generator’s goal is to generate samples so realistic that the discriminator cannot determine whether they are real or simulated (i.e., the discriminator’s accuracy is close to 0.5). Through adversarial training, the generator and discriminator receive feedback from each other, allowing them to co-evolve.

The main difference between a WGAN and a traditional GAN is that the former uses the Wasserstein distance, instead of the Kullback–Leibler or Jensen–Shannon divergence, to quantify the similarity between the probability distribution of the real data and that of the generated data. The Wasserstein distance has been shown to stabilize GAN training and to mitigate mode collapse32. The detailed formulation of the Wasserstein distance used as the loss function in this study is provided in Supplementary Notes. In addition, instead of a discriminator, a WGAN uses a “critic” that estimates the Wasserstein distance between the real and generated data distributions. In our model, we added a gradient penalty term for the critic (proposed by Gulrajani et al.56 as an alternative to the weight clipping used in the original WGAN) to overcome vanishing/exploding gradients and capacity underuse.
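
To make the critic and its gradient penalty concrete, the NumPy sketch below uses a toy one-hidden-layer ReLU critic, whose input gradient can be written analytically, instead of a deep network trained with automatic differentiation. The penalty is evaluated at random interpolates between real and generated samples, as in WGAN-GP. The penalty weight `lam=10` is the default proposed by Gulrajani et al. and is an assumption here; the study's hyperparameters are listed in Supplementary Table 1.

```python
import numpy as np

def critic(x, W1, b1, w2):
    """Toy one-hidden-layer ReLU critic: f_c(x) = w2 . relu(W1 x + b1), one score per row."""
    return np.maximum(x @ W1.T + b1, 0.0) @ w2

def critic_input_grad(x, W1, b1, w2):
    """Analytic gradient of the toy critic with respect to each input row."""
    active = (x @ W1.T + b1 > 0).astype(float)   # ReLU derivative, shape (n, hidden)
    return (active * w2) @ W1                     # shape (n, d)

def gradient_penalty(real, fake, W1, b1, w2, lam=10.0, rng=None):
    """lam * E[(||grad f_c(x_hat)||_2 - 1)^2] at random real/fake interpolates."""
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.uniform(size=(real.shape[0], 1))
    x_hat = eps * real + (1.0 - eps) * fake
    grad_norm = np.linalg.norm(critic_input_grad(x_hat, W1, b1, w2), axis=1)
    return lam * np.mean((grad_norm - 1.0) ** 2)

def critic_loss(real, fake, W1, b1, w2, lam=10.0, rng=None):
    """Loss the critic minimizes: -(E[f_c(real)] - E[f_c(fake)]) + gradient penalty."""
    w_estimate = critic(real, W1, b1, w2).mean() - critic(fake, W1, b1, w2).mean()
    return -w_estimate + gradient_penalty(real, fake, W1, b1, w2, lam, rng)
```

A critic whose gradient has unit norm everywhere (for example, a single always-active unit with weights `[1, 0]`) incurs zero penalty, which is the 1-Lipschitz behavior the penalty encourages.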

In the pretraining step, we trained a WGAN-GP consisting of a generator (containing an input layer, three hidden layers, and an output layer), a library-size normalization (LSN) layer24, and a critic (containing an input layer, three hidden layers, and an output node). A noise vector of length 128, whose elements are independent and identically distributed standard Gaussian variables, was used as the input to the generator. The output of the generator was then fed into the LSN layer to generate the gene and TF expression values. The details of hyperparameters and architectural choices of this WGAN-GP are provided in Supplementary Table 1. Although we were only interested in using the generator of this WGAN-GP to generate TF expression (in the second step of the pipeline), the model was trained on all genes and TFs to properly enforce the library-size normalization. Once trained, we discarded the critic and the LSN layer, froze the weights of the generator, and used it as the “causal controller”31 to generate TF expression (Fig. 1C).
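
The LSN layer can be sketched as a per-cell rescaling so that every simulated cell sums to the reference library size (20,000 here, matching the preprocessing). The sketch assumes this simple rescaling form; the random linear-plus-ReLU map standing in for the trained generator, and the batch size of 8, are hypothetical.

```python
import numpy as np

def lsn(raw, library_size=20_000.0, eps=1e-12):
    """Rescale each cell (row) of non-negative generator output to `library_size`.

    Returns the normalized cells and the per-cell scaling factors.
    """
    totals = raw.sum(axis=1, keepdims=True)
    scale = library_size / np.maximum(totals, eps)   # guard against all-zero rows
    return raw * scale, scale

rng = np.random.default_rng(0)
z = rng.standard_normal((8, 128))          # 8 noise vectors of length 128, i.i.d. N(0, 1)
W = 0.1 * rng.standard_normal((128, 50))   # hypothetical stand-in for a trained generator
raw = np.maximum(z @ W, 0.0)               # ReLU output is non-negative
cells, scale = lsn(raw)
```

Because the rescaling removes differences in total counts, the generator only has to learn relative expression within each cell.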

Target generators

The role of the target generators is to generate the expression of genes causally regulated by TFs based on the topology of a GRN. Consider a target gene $G_j$ regulated by a set of TFs $\{TF_1, TF_2, \ldots, TF_n\}$. Under the causal sufficiency assumption, and because the TFs’ expressions are generated from independent noise variables, we can write $E_{G_j} = f_{G_j}(E_{TF_1}, E_{TF_2}, \ldots, E_{TF_n}, N_{G_j})$ and $E_{TF_i} = f_{TF_i}(N_{TF_i})$ for $i = 1, 2, \ldots, n$, where $E$ represents expression, $N$ represents a noise variable, and $f$ represents a function (to be approximated using neural networks). All noise variables are jointly independent. Following the theoretical and empirical results of CausalGAN31, we can use feedforward neural networks to represent the functions $f$ by making the generator inherit its neural network connections from the causal GRN. To achieve this, we generate each gene in the GRN with a separate generator such that target gene generators do not share any neural connections (Fig. 1C).
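
The factorization above can be illustrated with a toy structural causal model in NumPy: each TF is a function of its own noise, and each target gene is a function of its parent TFs and its own independent noise only. The three-gene GRN, the softplus used for $f_{TF_i}$, and the random one-hidden-layer networks standing in for the learned $f_{G_j}$ are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical bipartite GRN: target gene -> its regulating TFs (parents).
grn = {"G1": ["TF1", "TF2"], "G2": ["TF2"], "G3": ["TF1", "TF3"]}
tfs = sorted({tf for parents in grn.values() for tf in parents})

def make_mlp(n_in, width=8):
    """Random one-hidden-layer ReLU network standing in for a learned f_G."""
    W1 = rng.standard_normal((n_in, width))
    w2 = rng.standard_normal(width)
    return lambda v: np.maximum(v @ W1, 0.0) @ w2

# One separate function per target gene; input = [parent TFs..., own noise].
f_gene = {g: make_mlp(len(parents) + 1) for g, parents in grn.items()}

def sample_tfs(n_cells):
    """E_TF = f_TF(N_TF): here a softplus of independent standard Gaussian noise."""
    return {tf: np.log1p(np.exp(rng.standard_normal(n_cells))) for tf in tfs}

def generate_targets(tf_expr, gene_noise):
    """E_G = f_G(E_parents(G), N_G); no gene sees non-parent TFs or other genes' noise."""
    out = {}
    for g, parents in grn.items():
        inputs = np.column_stack([tf_expr[p] for p in parents] + [gene_noise[g]])
        out[g] = f_gene[g](inputs)
    return out

n = 5
tf_expr = sample_tfs(n)
noise = {g: rng.standard_normal(n) for g in grn}
targets = generate_targets(tf_expr, noise)
```

Because G1 and G2 do not have TF3 as a parent, shifting TF3 while holding the noise fixed leaves them unchanged; only G3 can respond.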

As input, the generator of each gene accepts a vector formed by concatenating a noise variable and a vector of non-library size normalized TF expressions from the causal controller, corresponding to the TFs that regulate the gene in the imposed GRN (i.e., its parents in the graph). The expression values of TFs and the generated values of target genes are arranged into a vector, which is passed to an LSN layer for normalization. We used target generators with three hidden layers of equal width. The width of the hidden layers of a generator is dynamically set as twice the number of its regulators (the noise variable and the set of regulating TFs). If the imposed GRN is relatively dense and contains more than 5000 edges, we set the depth to 2 and the width multiplier to 1 to be able to train on a single GPU. The details of hyperparameters and architectural choices are provided in Supplementary Table 2.
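
The sizing rule for each target generator can be written as a small helper. The function and parameter names are hypothetical, biases are counted but batch-normalization parameters are ignored, and the parameter count assumes a single scalar output per gene.

```python
def target_generator_shape(n_regulating_tfs, n_grn_edges, depth=3, width_multiplier=2,
                           dense_threshold=5000):
    """Return (number of hidden layers, hidden width) for one target gene's generator.

    Regulators = the regulating TFs plus one noise variable. For GRNs with more
    than `dense_threshold` edges, depth and width multiplier drop to 2 and 1.
    """
    n_regulators = n_regulating_tfs + 1
    if n_grn_edges > dense_threshold:
        depth, width_multiplier = 2, 1
    return depth, width_multiplier * n_regulators

def n_parameters(n_regulating_tfs, n_grn_edges):
    """Weights and biases of the fully connected layers (batch-norm parameters ignored)."""
    depth, width = target_generator_shape(n_regulating_tfs, n_grn_edges)
    n_in = n_regulating_tfs + 1
    first = n_in * width + width                      # input -> first hidden layer
    hidden = (depth - 1) * (width * width + width)    # hidden -> hidden
    output = width + 1                                # last hidden -> scalar expression
    return first + hidden + output
```

For example, a gene with four regulating TFs in a GRN with 1000 edges gets three hidden layers of width 10 (291 weights and biases), whereas in a GRN with more than 5000 edges it gets two hidden layers of width 5.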

In practice, generating each target gene’s expression using separate neural networks introduces excess overhead. This is because in every forward pass, all target genes’ expressions must be first generated before collectively being sent into the LSN layer. As a result, instead of parallelizing the generation of each target gene’s expression (which due to the bottleneck above does not provide a significant computational benefit), we implemented target generators using a single large sparse network. This allows us to reduce the overhead and the training time to train the model on a single GPU, and to benefit from GPU’s large matrix multiplication. We mask weights and gradients to follow the causal graph, while keeping the generation of genes independent from each other. From a logical standpoint, our implementation has the same architecture described earlier, but is significantly more computationally efficient. See Supplementary Notes for details.
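
The masking idea can be sketched as one wide network whose weight matrices are multiplied elementwise by fixed 0/1 masks, so that each target gene's output depends only on its parent TFs and its own noise, as if it had its own generator. Biases and batch normalization are omitted for brevity, the equal per-gene width is a simplification, and the helper names are hypothetical; the study's implementation is described in the Supplementary Notes.

```python
import numpy as np

def build_masks(parents, tf_names, width_per_gene, depth):
    """Masks for one large network that behaves like separate per-gene generators.

    parents: list (one entry per target gene) of regulating-TF names.
    Input vector = [all TF expressions, one noise variable per target gene].
    Hidden units are split into one contiguous block of `width_per_gene` per gene.
    """
    n_genes, n_tf = len(parents), len(tf_names)
    tf_index = {tf: i for i, tf in enumerate(tf_names)}
    H = n_genes * width_per_gene
    first = np.zeros((n_tf + n_genes, H))
    block = np.zeros((H, H))
    out = np.zeros((H, n_genes))
    for g, regs in enumerate(parents):
        cols = slice(g * width_per_gene, (g + 1) * width_per_gene)
        for tf in regs:
            first[tf_index[tf], cols] = 1.0   # parent TF -> gene g's block
        first[n_tf + g, cols] = 1.0           # gene g's own noise -> its block
        block[cols, cols] = 1.0               # no cross-gene hidden connections
        out[cols, g] = 1.0                    # block g -> output of gene g only
    return [first] + [block] * (depth - 1) + [out]

def masked_forward(x, weights, masks):
    """Forward pass with weights multiplied elementwise by their masks (ReLU hidden layers)."""
    h = x
    for i, (W, M) in enumerate(zip(weights, masks)):
        h = h @ (W * M)
        if i < len(weights) - 1:
            h = np.maximum(h, 0.0)
    return h

def masked_sgd_step(W, grad, M, lr):
    """Plain SGD step with the gradient masked, so pruned connections stay fixed."""
    return W - lr * (grad * M)
```

The dense matrix products keep the GPU busy, while the masks guarantee that no information flows between different genes' blocks.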

Critic

Similar to a traditional WGAN, the objective of GRouNdGAN’s critic (Fig. 1C) is to estimate the Wasserstein distance between the real and generated data distributions. We used the same critic architecture as in the WGAN-GP trained in the first stage.

Labeler and anti-labeler

Although the main role of the target generators is to produce realistic cells that confuse the critic, it is crucial that they rely on the TFs’ expression (in addition to noise) in doing so. One potential risk is that the target generators disregard the expression of TFs and rely solely on the noise variables. This is particularly likely when the imposed GRN does not conform to the underlying gene expression programs of real cells in the reference dataset; in such a scenario, it is easier for a WGAN to produce realistic-looking simulated cells by simply ignoring the strict constraints of the GRN and relying solely on noise.

To overcome this issue, we used auxiliary tasks and neural networks known as the “labeler” and “anti-labeler”31. The task of these two networks is to estimate the causal controller’s TF expressions (here called labels) from the target genes’ expressions alone, by minimizing the squared L2 norm between each cell’s TF estimates and their true values. More specifically, the corresponding loss for a batch of size $N_{batch}$ is $\frac{1}{N_{batch}}\sum_{i=1}^{N_{batch}}\|\hat{\boldsymbol{y}}_i - \boldsymbol{y}_i\|_2^2$, where $\hat{\boldsymbol{y}}_i$ is the vector of TF expression values estimated by the labeler or anti-labeler. For the anti-labeler, $\boldsymbol{y}_i$ corresponds to the TF expression values output by the causal controller; for the labeler, this vector can also correspond to TF expression values from the real training data. This resembles the idea behind an autoencoder and ensures that the model does not disregard the expression of TFs when generating the expression of their target genes. The anti-labeler is trained solely on the outputs of the target generators, while the labeler uses both the outputs of the target generators and the expression of real cells (Fig. 1C). Both are implemented as fully connected networks with a width of 2000 and a depth of 3, and are optimized using the AMSGrad57 algorithm. Each layer, except the last one, uses a ReLU activation function and batch normalization. In addition to the WGAN-GP losses, we add the labeler and anti-labeler losses to the generator’s objective so that it minimizes both. This differs from the approach used in CausalGAN, where the anti-labeler’s loss is maximized in the early training stages to ensure that the generator does not fall into label-conditioned mode collapse. In GRouNdGAN, the causal controller is pretrained and generates continuous labels (TF expression), so it does not face a similar issue. As a result, we instead minimized the loss of the anti-labeler from the beginning, so that the labeler and anti-labeler both act as auxiliary tasks ensuring that the generated gene expression values take advantage of TF expression.
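
The batch loss above, and the way both auxiliary losses are added to the generator's objective, can be sketched in NumPy. Weighting the three terms equally is an assumption made for illustration; any loss weights used in the study are not stated here.

```python
import numpy as np

def labeler_loss(y_hat, y):
    """(1 / N_batch) * sum_i ||y_hat_i - y_i||_2^2 for (N_batch, n_TFs) arrays."""
    return np.mean(np.sum((y_hat - y) ** 2, axis=1))

def generator_total_loss(wgan_generator_loss, labeler_pred, anti_labeler_pred, tf_expr):
    """Both auxiliary losses are added to (i.e., minimized with) the WGAN generator loss.

    tf_expr: TF expression from the causal controller for the generated batch.
    Equal weights on the three terms are an assumption.
    """
    return (wgan_generator_loss
            + labeler_loss(labeler_pred, tf_expr)
            + labeler_loss(anti_labeler_pred, tf_expr))
```

In CausalGAN, by contrast, the anti-labeler term would enter with the opposite sign during the early training stages.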

Training procedure and hyperparameter tuning

We follow a two-step training procedure, comprising the training of two separate WGAN-GPs (Fig. 1), to train GRouNdGAN. The generators and critics in both GANs are implemented as fully connected neural networks with rectified linear unit (ReLU) activation functions in each layer, except for the last layer of the critic. The weights were initialized using He initialization58 for layers with ReLU activations and Xavier initialization59 for the other layers (with linear activations). We used Batch Normalization60 to normalize layer inputs for each training minibatch, except in the critic. Batch normalization in the critic would invalidate the gradient penalty’s objective: the penalty is defined on the norm of the critic’s gradient with respect to each input independently, whereas batch normalization makes the critic’s output for one input depend on the entire batch. An LSN layer24 was used in both WGAN-GPs (Fig. 1B, C) to scale the counts in each simulated cell to make them consistent with the library size of the input reference dataset. This normalization dramatically decreases convergence time and smooths training by mitigating the inherent heterogeneity of scRNA-seq data.
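
The initialization rule can be sketched as below. Using the normal (rather than uniform) variants of He and Xavier initialization, and zero biases, are assumptions; the section only names the two schemes.

```python
import numpy as np

def init_layer(fan_in, fan_out, activation, rng):
    """Return (weights, biases) for a fully connected layer.

    ReLU layers: He normal, std = sqrt(2 / fan_in).
    Linear layers: Xavier (Glorot) normal, std = sqrt(2 / (fan_in + fan_out)).
    Zero biases are an assumption.
    """
    if activation == "relu":
        std = np.sqrt(2.0 / fan_in)
    else:
        std = np.sqrt(2.0 / (fan_in + fan_out))
    return rng.normal(0.0, std, size=(fan_in, fan_out)), np.zeros(fan_out)
```

He initialization doubles the variance relative to a fan-in-only scheme to compensate for ReLU zeroing about half of its inputs, which keeps activation magnitudes roughly stable across layers.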

We would like to point out that both steps of training (i.e., the pretraining of Fig. 1B and training of Fig. 1C) are performed on the exact same training set, and when data corresponding to a new dataset is to be generated, these steps need to be repeated. The goal of the first step is to train the causal controller to learn the distribution of training set and generate realistic TF expression values (without imposing TF-gene relationships). The goal of the second step is to use the TF expression values generated by the trained causal controller to generate expression of target genes while imposing the TF-gene causal relationships.

GRouNdGAN solves a min-max game between the generator ($f_g$) and the critic ($f_c$), with the following objective function:

$$\min_{f_g}\max_{\|f_c\|_L \le 1}\ \mathbb{E}_{x \sim \mathbb{P}_r}\left[f_c(x)\right] - \mathbb{E}_{x \sim f_g(\mathbb{P}_{noise})}\left[f_c(x)\right] \tag{1}$$


The training objective of the critic is to maximize, with respect to its parameters, the difference between the average score assigned to real samples ($\mathbb{E}_{x \sim \mathbb{P}_r}[f_c(x)]$) and the average score assigned to generated samples ($\mathbb{E}_{x \sim f_g(\mathbb{P}_{noise})}[f_c(x)]$), following Eq. (1); the constraint $\|f_c\|_L \le 1$ restricts the critic to 1-Lipschitz functions, which the gradient penalty enforces softly. Conversely, the generator attempts to minimize this difference, i.e., to raise the average score that the critic assigns to generated samples. Through this adversarial game, the critic and generator co-evolve, and the generator learns a mapping $f_g$ from a simple standard Gaussian noise distribution ($\mathbb{P}_{noise}$) to a distribution $f_g(\mathbb{P}_{noise})$ that approximates the real data distribution $\mathbb{P}_r$. Note that the critic does not directly compute the Wasserstein distance (mathematically defined in Supplementary Notes – Details regarding the Wasserstein distance). Instead, its training process encourages it to provide a meaningful estimate of the Wasserstein distance between the distribution of real data ($\mathbb{P}_r$) and the distribution of generated data ($f_g(\mathbb{P}_{noise})$).

We alternated between minimizing the generator loss for one iteration and maximizing the critic objective for five iterations. We employed the AMSGrad57 optimizer with exponential decay rates for the first- and second-moment estimates of $\beta_1 = 0.5$ and $\beta_2 = 0.9$, and used an exponentially decaying learning rate for the optimizers of both the critic and the generator.
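
The update schedule and optimizer settings can be sketched as follows: five critic updates per generator update, an AMSGrad step in which $\beta_1$ and $\beta_2$ are the decay rates of the first- and second-moment estimates, and an exponentially decaying learning rate. The AMSGrad step follows the original formulation without bias correction (library implementations may additionally apply bias correction), and the base learning rate, decay rate, and decay interval in the toy usage are hypothetical values.

```python
import numpy as np

def training_schedule(n_generator_steps, n_critic=5):
    """Order of updates: `n_critic` critic steps before every generator step."""
    for _ in range(n_generator_steps):
        for _ in range(n_critic):
            yield "critic"
        yield "generator"

def lr_at(step, base_lr, decay_rate, decay_steps):
    """Exponentially decaying learning rate."""
    return base_lr * decay_rate ** (step / decay_steps)

def amsgrad_step(param, grad, state, lr, beta1=0.5, beta2=0.9, eps=1e-8):
    """One AMSGrad update; beta1/beta2 decay the first/second moment estimates."""
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2
    state["v_max"] = np.maximum(state["v_max"], state["v"])   # never decreases
    return param - lr * state["m"] / (np.sqrt(state["v_max"]) + eps)

# Toy usage: minimize f(x) = x**2 with a decaying learning rate.
x = np.array(3.0)
state = {"m": 0.0, "v": 0.0, "v_max": 0.0}
for step in range(2000):
    x = amsgrad_step(x, 2 * x, state, lr=lr_at(step, 0.05, 0.9, 1000))
```

Keeping the running maximum of the second-moment estimate is what distinguishes AMSGrad from Adam: the effective step size can only shrink, which avoids the non-convergence cases reported for Adam.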

The hyperparameters were tuned on a validation set of 1000 cells from the PBMC-CTL dataset (Supplementary Tables 1 and 2) based on the Euclidean distance and the RF AUROC score, which consistently agreed. The same hyperparameters were used for all other analyses and datasets.

Causal GRN preparation

This section describes the creation of the causal graph that is input to GRouNdGAN to impose a causal structure on the model. GRouNdGAN accepts as input a GRN in the form of a bipartite directed acyclic graph (DAG) representing the relationships between TFs and their target genes. In this study, we created the causal graph using the 1000 most highly variable genes and TFs identified in the preprocessing step (Supplementary Data 1 – Sheet 3). First, the TFs among the highly variable genes were identified based on the AnimalTFDB3.0 database61, and a GRN was inferred from the training reference dataset using GRNBoost28 (with the list of TFs provided). It is important to note that the regulatory edges identified from the reference dataset using GRNBoost2 are not necessarily causal (and they do not need to be for the purpose of forming the input GRN), but they are consistent with the patterns in the data. However, when this (potentially non-causal) GRN is imposed by GRouNdGAN, it is imposed in a causal manner and represents the causal data-generating graph of the simulated data (not of the reference data).
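
Turning an inferred edge list into the bipartite input graph can be sketched as below. The edge-list format mirrors a (TF, target, importance) table such as the one GRNBoost2 produces, and dropping edges that would make a TF a target (so that TFs appear only as sources) is an assumption illustrating the bipartite constraint; any edge-selection thresholds are not shown.

```python
def build_bipartite_grn(edges, tfs):
    """Map each non-TF target gene to the sorted list of TFs that regulate it.

    edges: iterable of (regulator, target, importance) tuples.
    Edges from a non-TF regulator, or into a TF, are dropped (assumption), so
    TFs are only sources and the graph is a bipartite DAG by construction.
    """
    tfs = set(tfs)
    parents = {}
    for regulator, target, _importance in edges:
        if regulator in tfs and target not in tfs:
            parents.setdefault(target, set()).add(regulator)
    return {target: sorted(regs) for target, regs in sorted(parents.items())}

edges = [("TF1", "G1", 0.9), ("TF2", "G1", 0.5), ("TF1", "TF2", 0.3),
         ("G2", "G1", 0.1), ("TF2", "G3", 0.7)]
grn = build_bipartite_grn(edges, ["TF1", "TF2"])
```

Since every edge points from a TF to a non-TF gene, no cycle can form, which is why a bipartite graph of this kind is automatically a DAG.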

Evaluation of the resemblance of real and simulated cells

We evaluated all models using held-out test sets containing randomly selected cells from each reference dataset (500 cells for BoneMarrow and 1000 cells for all other datasets) (see Supplementary Data 1 – Sheet 3 for other statistics about the datasets). To quantify the similarity between real and generated cells, we employed several metrics. For each cell, represented as a datapoint in a low-dimensional embedding (e.g., t-SNE or UMAP), the integration local inverse Simpson’s Index (iLISI)36 captures the effective number of data types (real or simulated) represented in its local neighborhood, using Gaussian kernel-based weights over the neighbors. The miLISI is the mean of these per-cell scores and, in our study, ranges from 1 (poor mixing of real and simulated cells) to 2 (perfect mixing of real and simulated cells). Additionally, we calculated the cosine and Euclidean distances between the centroids of real cells and simulated cells, where each centroid is the vector of mean expression of each gene across all real (or all simulated) cells.
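
The centroid-based metrics can be sketched in NumPy, with cells as rows and genes as columns; taking cosine distance as one minus cosine similarity is an assumption about the exact convention.

```python
import numpy as np

def centroid_distances(real, sim):
    """Cosine and Euclidean distances between per-gene mean expression vectors."""
    c_real, c_sim = real.mean(axis=0), sim.mean(axis=0)   # mean of each gene across cells
    euclidean = np.linalg.norm(c_real - c_sim)
    cosine = 1.0 - c_real @ c_sim / (np.linalg.norm(c_real) * np.linalg.norm(c_sim))
    return cosine, euclidean
```

Cosine distance ignores overall scale and compares only the relative expression profile, whereas Euclidean distance is also sensitive to magnitude; reporting both captures complementary aspects of the mismatch.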

To estimate the proximity of the high-dimensional distributions of real and simulated cells without computing centroids, we used the maximum mean discrepancy (MMD)35. Given two probability distributions $p$ and $q$ and sets of independently and identically distributed (i.i.d.) samples from them, denoted by $X$ and $Y$, the MMD with respect to a function class $\mathcal{F}$ is defined as

$$\mathrm{MMD}[\mathcal{F}, X, Y] := \sup_{f \in \mathcal{F}}\left(\mathbb{E}_x\left[f(x)\right] - \mathbb{E}_y\left[f(y)\right]\right),$$


where $\sup$ refers to the supremum and $\mathbb{E}$ denotes expectation. When the function class $\mathcal{F}$ is the unit ball in a reproducing kernel Hilbert space (RKHS) $\mathcal{H}$ with a characteristic kernel $k$ (such as a Gaussian kernel), the population MMD is zero if and only if $p = q$ and positive if $p \ne q$. The squared MMD can be written as the squared distance between the mean embeddings $\mu_p$ and $\mu_q$ of distributions $p$ and $q$, which can be expressed in terms of kernel functions:

$$\begin{aligned}\mathrm{MMD}^2[\mathcal{F}, p, q] &= \|\mu_p - \mu_q\|_{\mathcal{H}}^2\\ &= \mathbb{E}_{x,x'}\left[k(x, x')\right] - 2\,\mathbb{E}_{x,y}\left[k(x, y)\right] + \mathbb{E}_{y,y'}\left[k(y, y')\right]\end{aligned}$$

Here, $x, x'$ are independent samples from $p$ and $y, y'$ are independent samples from $q$. Following existing implementations of MMD in the single-cell domain24,62, we chose a kernel that is the sum of three Gaussian kernels, which makes it sensitive to a wider range of distances:

$$k(x, y) = \sum_{i=1}^{3} \exp\left(-\frac{\|x - y\|^2}{\sigma_i^2}\right)$$

where the bandwidths $\sigma_i$ were set to the median (over all points) of the average distance between a point and its 25 nearest neighbors, divided by factors of 0.5, 1, and 2 in the three kernels, respectively.
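
A minimal NumPy sketch of the biased (V-statistic) estimate of the squared MMD with this three-kernel sum follows. Computing the bandwidths on the pooled real and simulated cells, and using the biased rather than the unbiased estimator, are assumptions; the cited single-cell implementations may differ in these details.

```python
import numpy as np

def pairwise_sq_dists(A, B):
    """Squared Euclidean distances between rows of A and rows of B."""
    d = (A ** 2).sum(1)[:, None] + (B ** 2).sum(1)[None, :] - 2.0 * A @ B.T
    return np.maximum(d, 0.0)

def bandwidths(Z, k=25, factors=(0.5, 1.0, 2.0)):
    """sigma_i = median over points of the mean distance to the k nearest neighbors, / factor_i."""
    d = np.sqrt(pairwise_sq_dists(Z, Z))
    knn = np.sort(d, axis=1)[:, 1:k + 1]       # skip the zero self-distance
    base = np.median(knn.mean(axis=1))
    return [base / f for f in factors]

def mmd2(X, Y, sigmas):
    """Biased estimate of squared MMD with k(x, y) = sum_i exp(-||x - y||^2 / sigma_i^2)."""
    def kernel(A, B):
        d2 = pairwise_sq_dists(A, B)
        return sum(np.exp(-d2 / s ** 2) for s in sigmas)
    return kernel(X, X).mean() - 2.0 * kernel(X, Y).mean() + kernel(Y, Y).mean()
```

The three terms map one-to-one onto $\mathbb{E}_{x,x'}[k(x,x')]$, $-2\,\mathbb{E}_{x,y}[k(x,y)]$, and $\mathbb{E}_{y,y'}[k(y,y')]$, each replaced by an average over sample pairs.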

We also trained a random forest (RF) classifier and used its area under the receiver operating characteristic curve (AUROC) to determine whether real and simulated cells can be distinguished from each other. Consistent with previous studies24,63, we first performed dimensionality reduction using principal component analysis (PCA) and used the top 50 PCs of each cell as the input features to the RF model, which improves the computational efficiency of this analysis. The RF model comprised 1000 trees, and the Gini impurity was used to measure the quality of a split.
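
The two steps surrounding the classifier can be sketched in NumPy: projecting cells onto their top principal components, and computing AUROC from classifier scores through the Mann–Whitney U statistic. The random forest itself is left to a library (for example, scikit-learn's `RandomForestClassifier(n_estimators=1000, criterion="gini")`), and how held-out scores are obtained is not shown, so the scores in the checks below are illustrative toy values.

```python
import numpy as np

def top_pcs(X, n_pcs=50):
    """Project cells (rows) onto their top principal components via a centered SVD."""
    Xc = X - X.mean(axis=0)
    U, S, _ = np.linalg.svd(Xc, full_matrices=False)
    n = min(n_pcs, S.size)
    return U[:, :n] * S[:n]

def auroc(labels, scores):
    """AUROC via the Mann-Whitney U statistic; tied scores receive their average rank."""
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(scores, kind="mergesort")
    ranks = np.empty(scores.size)
    ranks[order] = np.arange(1, scores.size + 1)
    for s in np.unique(scores):                 # average ranks over ties
        tie = scores == s
        ranks[tie] = ranks[tie].mean()
    n_pos, n_neg = labels.sum(), (~labels).sum()
    return (ranks[labels].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
```

An AUROC near 0.5 means the classifier cannot separate real from simulated cells better than chance, which is the desired outcome for a simulator.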

Baseline simulator models

We compared the performance of GRouNdGAN to scDESIGN225, SPARSim26, and three GAN-based methods: scGAN24, cscGAN with projection-based conditioning24, and a conditional WGAN (cWGAN). The cWGAN method conditions by concatenation following the cGAN framework64. More specifically, it concatenates a one-hot encoded vector (representing the cluster number or cell type) to the noise vector input to the generator and to the cells forwarded to the discriminator. We did not train the cWGAN or cscGAN on the PBMC-CTL dataset, since it contains only one cell type. For the PBMC-All and BoneMarrow datasets, we trained all of the models above. Additionally, we simulated data using scDESIGN2 and SPARSim with and without cell cluster information, as both accept such side information during training.

To train models that use cell cluster information, we performed Louvain clustering and provided the cluster labels and the ratio of cells per cluster during training. Clustering followed the Cell Ranger pipeline33 and was based on the raw unprocessed dataset (independent of the pre-processing steps described earlier for training simulators). First, genes with no UMI count in any of the cells were removed. Then, the gene expression profile of each cell was normalized by the total UMI count of all (remaining) genes, and highly variable genes were identified. The gene expression profile of each cell was then re-normalized by the total UMI count of the retained highly variable genes, and each gene vector (representing its expression across cells) was z-score normalized. Given this normalized gene expression matrix, we computed the top 50 principal components (PCs) using PCA. These PCs were then used to compute a neighborhood graph with a local neighborhood size of 15, which was used for Louvain clustering with a resolution of 0.15.
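
The normalization steps leading up to clustering can be sketched in NumPy as below. Selecting highly variable genes by the variance of the first-pass normalized data is a simplified, hypothetical stand-in (the exact selection rule is not given above), and the neighborhood graph and Louvain step are only indicated in a comment; in scanpy they would correspond to `sc.pp.neighbors(adata, n_neighbors=15)` followed by `sc.tl.louvain(adata, resolution=0.15)`.

```python
import numpy as np

def cellranger_style_pcs(counts, n_hvg=1000, n_pcs=50):
    """Gene filtering, two rounds of total-UMI normalization, gene z-scoring, and PCA."""
    counts = np.asarray(counts, dtype=float)
    counts = counts[:, counts.sum(axis=0) > 0]           # drop genes with no UMI in any cell
    lib = counts.sum(axis=1, keepdims=True)
    first_pass = np.divide(counts, lib, out=np.zeros_like(counts), where=lib > 0)
    # Hypothetical HVG rule: highest variance of the first-pass normalized data.
    hvg = np.argsort(first_pass.var(axis=0))[::-1][:n_hvg]
    sub = counts[:, hvg]
    hvg_lib = sub.sum(axis=1, keepdims=True)
    sub = np.divide(sub, hvg_lib, out=np.zeros_like(sub), where=hvg_lib > 0)  # re-normalize
    sd = sub.std(axis=0)
    z = (sub - sub.mean(axis=0)) / np.where(sd > 0, sd, 1.0)  # z-score each gene across cells
    U, S, _ = np.linalg.svd(z, full_matrices=False)
    n = min(n_pcs, S.size)
    # These PCs would then feed a 15-nearest-neighbor graph and Louvain clustering (resolution 0.15).
    return U[:, :n] * S[:n]
```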

For SPARSim, we set all sample library sizes to 20,000 and estimated gene expression level intensities and simulation parameters by providing both the raw and the normalized count matrices. When cell cluster information was provided, separate SPARSim simulation parameters were estimated for each cluster. scDESIGN2 accepts input matrices whose entries are integer counts; we therefore rounded the expression matrix before fitting scDESIGN2. With cluster information provided, a separate scDESIGN2 model was fit to the cells of each cluster and, similar to conditional GANs, the ratio of cells per cluster was provided to the method.

In silico perturbation experiments using GRouNdGAN

To perform perturbation experiments using GRouNdGAN, we put the trained model in a deterministic mode of operation. This is necessary to ensure that the perturbation experiments are performed on the same batch (i.e., replicate) of generated cells, forming matched case/control experiments. To do this, we performed a forward pass through the generator and saved the input noise to the causal controller, the input noise to the target generators, the TF expression values generated by the causal controller, and the per-cell scaling factor of the LSN layer. Subsequent passes through the generators reused these saved values, so that they always output the same batch of cells (instead of generating new, unmatched cells).
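
The deterministic mode can be sketched as a small class that caches every stochastic quantity from one forward pass and replays it. The callables `controller` and `targets`, the toy stand-ins in the usage example, and the use of a TF override (for instance, setting a TF to zero to mimic a knockout) are hypothetical; the scaling factors are cached from the unperturbed pass, as described above.

```python
import numpy as np

class FrozenBatch:
    """Cache the noise, TF values, and LSN factors of one pass so later passes are matched."""

    def __init__(self, controller, targets, n_cells, n_noise_cc, n_noise_tg, rng,
                 library_size=20_000.0):
        self.targets = targets
        self.z_cc = rng.standard_normal((n_cells, n_noise_cc))   # causal-controller noise
        self.z_tg = rng.standard_normal((n_cells, n_noise_tg))   # target-generator noise
        self.tf = controller(self.z_cc)                          # TF expression, saved once
        raw = np.hstack([self.tf, targets(self.tf, self.z_tg)])
        self.scale = library_size / raw.sum(axis=1, keepdims=True)  # per-cell LSN factor

    def generate(self, tf_override=None):
        """Re-run the target generators with the cached noise; optionally edit TF columns."""
        tf = self.tf.copy()
        if tf_override:
            for col, value in tf_override.items():
                tf[:, col] = value
        raw = np.hstack([tf, self.targets(tf, self.z_tg)])
        return raw * self.scale

def toy_controller(z):
    return np.abs(z[:, :2]) + 0.1                      # toy: 2 TFs

def toy_targets(tf, z):
    return np.abs(2.0 * tf[:, [0]] + z[:, :1]) + 0.1   # toy: 1 target regulated by TF 0

rng = np.random.default_rng(0)
batch = FrozenBatch(toy_controller, toy_targets, n_cells=6, n_noise_cc=4, n_noise_tg=3, rng=rng)
control = batch.generate()
knockout = batch.generate({0: 0.0})   # matched case: same cells, TF 0 set to zero
```

Because both `control` and `knockout` are computed from the same cached noise and scaling factors, any difference between them is attributable to the edited TF alone.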

Trajectory inference and pseudo-time analysis

Following the official PAGA tutorial for the BoneMarrow dataset (, we used partition-based graph abstraction (PAGA)38 for trajectory inference and analysis. We built force-directed graphs65 (with ForceAtlas266) using the top 20 principal components of the data (obtained by PCA) and a neighborhood graph of observations computed using UMAP (to estimate connectivities). We then denoised the graph by representing it in diffusion map space and recomputed distances and neighbors using this new representation. After denoising, we ran the Louvain clustering algorithm with a resolution of 0.6. Finally, we ran the PAGA algorithm on the identified clusters and used the resulting graph to initialize and rebuild the force-directed graph.

GRN inference methods

In our GRN benchmarking analysis, we focused on eight GRN inference algorithms: GENIE39, GRNBoost28, PPCOR37, PIDC10, LEAP12, SCODE13 and SINCERITIES14, which were used in the BEELINE study19, as well as CeSpGRN48. Of these methods, LEAP requires pseudo-time ordering of cells, while SCODE and SINCERITIES require both pseudo-time ordering and pseudo-time values. Since not all algorithms inferred the edge directionality or its sign (activatory or inhibitory nature), we did not consider these factors in our analysis to be consistent among different models.

For the methods available in the BEELINE study, we ran them as Docker containers using the Docker images provided in BEELINE’s GitHub repository ( with the default parameters used in BEELINE. These methods were applied to nine datasets: the data simulated by GRouNdGAN, the data simulated by scGAN, and the original real training data, for each of PBMC-CTL, PBMC-All, and BoneMarrow. To ensure consistency, the same number of cells as in the real training set was simulated using GRouNdGAN and scGAN for each dataset: n = 19773 for PBMC-CTL, n = 67579 for PBMC-All, and n = 2230 for BoneMarrow. The number of genes and TFs present in the GRN for each dataset is provided in Supplementary Data 1 – Sheet 3. To benchmark algorithms requiring pseudo-time ordering of cells, we computed the pseudo-times of GRouNdGAN-simulated data (based on the BoneMarrow dataset) using PAGA38 and diffusion pseudotime39, following the methodology described earlier. In the GRN inference benchmark, we did not provide the list of TFs to GRNBoost2, for consistency with the other GRN inference methods.

We also included CeSpGRN, a cell-specific GRN inference method. Since this method first infers one GRN per cell, it requires a large amount of memory. As a result, we were only able to benchmark it on a subset of the data consisting of n = 1000 cells (for each of the three datasets) and 100 genes (we did not change the number of TFs or the GRN edges connecting them to the considered genes). Following the method described in the original study, we then averaged the absolute edge weights across all cells to form a consensus GRN using CeSpGRN.
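
The consensus step, together with the removal of edge direction and sign used for consistency across methods, can be sketched as follows. Representing each cell-specific GRN as a dense genes x genes matrix and keeping the larger of the two directed weights for each gene pair are assumptions.

```python
import numpy as np

def consensus_grn(cell_grns):
    """Average the absolute edge weights over cells; cell_grns has shape (cells, genes, genes)."""
    return np.abs(cell_grns).mean(axis=0)

def undirected_ranking(W, genes):
    """Rank gene pairs by max(|W_ab|, |W_ba|), ignoring edge direction and sign."""
    U = np.maximum(np.abs(W), np.abs(W).T)
    pairs = [(genes[a], genes[b], U[a, b])
             for a in range(len(genes)) for b in range(a + 1, len(genes))]
    return sorted(pairs, key=lambda p: p[2], reverse=True)
```

Taking absolute values before averaging prevents an activating edge in some cells and an inhibiting edge in others from cancelling out in the consensus network.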

Statistics and reproducibility

The sample sizes were determined by the original studies that produced the datasets used for training GRouNdGAN. Preprocessing steps are described in the Datasets and preprocessing section. Other than standard cell-level and gene-level filtering, there were no data exclusions. The statistical tests used for each analysis are described in the corresponding sections and include the Wilcoxon signed-rank and Mann–Whitney U tests.

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.