Graph Neural Networks for Likelihood-Free Inference in Diversification Models
This article has been reviewed by the following groups
Listed in
- Evaluated articles (Arcadia Science)
Abstract
A common approach to inferring the processes that gave rise to past speciation and extinction rates across taxa, space, and time is to formulate hypotheses in the form of probabilistic diversification models and estimate their parameters from extant phylogenies using Maximum Likelihood or Bayesian inference. A drawback of this approach is that likelihoods can easily become computationally intractable, limiting our ability to extend current diversification models with new hypothesized mechanisms. Neural networks have been proposed as a likelihood-free alternative for parameter inference in stochastic models, but so far there is little experience in using this method for diversification models, and the quality of the results is likely to depend on finding the right network architecture and data representation. As phylogenies are essentially graphs, graph neural networks (GNNs) appear to be the most natural architecture, but previous results on their performance are conflicting, with some studies reporting poor accuracy of GNNs in practice. Here, we show that this underperformance was likely caused by optimization issues and inappropriate pooling operations that flatten the information along the phylogeny and make it harder to extract relevant information about the diversification parameters. When equipped with PhyloPool, a new time-informed pooling procedure, GNNs show similar or better performance compared to all other architectures and data representations (including Maximum Likelihood Estimation) that we tested for two common diversification models, the Constant Rate Birth-Death and the Binary State Speciation and Extinction models. We conclude that GNNs could serve as a generic tool for estimating the parameters of complex diversification models with intractable likelihoods.
Article activity feed
> We anticipate that layers that account for this depth order, e.g. through convolutions or possibly self-attention (as used in spatio-temporal graphs (e.g. Guo et al. 2019, Su et al. 2020)), will often be complementary to other layers acting on the topology (encoded in the phylogenetic graph), e.g. through graph convolutions.

Related to the pooling operator, I think large gains may come from the use of 1) edge weights in your GCN layers, so that not all neighbors are treated equally by the message passing mechanism, and 2) alternative MPNN layer types, including the graph attention mechanism (i.e. GAT) or graph transformers, which use attention to learn which neighbors are more "important." I suspect that even with simple mean-pooling, these alternative layer types will be much more performant and generalizable (e.g. from CRBD to BiSSE). In effect, the GCN layers (particularly without edge weights) are more akin to the CRBD in that they assume a uniform, homogeneous contribution by all neighbors to feature updates.
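To make the two suggestions concrete, here is a minimal sketch in PyTorch Geometric; the layer sizes, two-layer depth, and class names are illustrative assumptions, not the architecture evaluated in the paper.

```python
# Sketch of 1) edge-weighted GCN layers and 2) an attention-based (GAT)
# alternative. All dimensions are illustrative.
import torch
from torch_geometric.nn import GCNConv, GATConv

class EdgeWeightedGCN(torch.nn.Module):
    """GCN whose message passing is rescaled by per-edge weights
    (e.g. branch lengths), so neighbors are not treated equally."""
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, edge_weight):
        x = self.conv1(x, edge_index, edge_weight).relu()
        return self.conv2(x, edge_index, edge_weight)

class AttentionGNN(torch.nn.Module):
    """GAT variant: attention coefficients learn which neighbors
    matter most, instead of fixed degree normalization."""
    def __init__(self, in_dim, hidden_dim, heads=4):
        super().__init__()
        self.conv1 = GATConv(in_dim, hidden_dim, heads=heads)
        self.conv2 = GATConv(hidden_dim * heads, hidden_dim, heads=1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)
```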
> the LTT-based statistics are less useful under BiSSE, which explains why the PhyloPool procedure loses part of its edge against global pooling (used in GNN-avg): preserving the phylogenetic order is intuitively less important when estimating under the BiSSE model, where consecutive nodes may be under different states.

Again, I think this is a case for exploring more general pooling operators, such as EdgePooling, that might capture the relevant signal without imposing such rigid inductive biases on the architecture that they could prove harmful to more general application.
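As a rough illustration of the EdgePooling idea, a hedged sketch follows: the class name `EdgePoolGNN`, the dimensions, and the single pooling step are assumptions for exposition, not a tested architecture.

```python
# Sketch of iterative edge contraction with EdgePooling, which coarsens
# the graph gradually instead of flattening it in one global pool.
import torch
from torch_geometric.nn import GCNConv, EdgePooling, global_mean_pool

class EdgePoolGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, n_params):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        # Learns edge scores and contracts high-scoring edges,
        # roughly halving the node count per application.
        self.pool1 = EdgePooling(hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.head = torch.nn.Linear(hidden_dim, n_params)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x, edge_index, batch, _ = self.pool1(x, edge_index, batch)
        x = self.conv2(x, edge_index).relu()
        return self.head(global_mean_pool(x, batch))
```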
> In graph convolutional layers, the update for each node starts from an average of the node itself and its direct neighbors in the graph, normalized by the degree of the nodes (see Fig. 2.a).

Related to the above (regarding the number of GNN layers used), have you considered using either gated residual connections or something like jumping knowledge (https://arxiv.org/abs/1806.03536) between layers to mitigate the oversmoothing that deeper GNNs tend to be prone to? This could be another simple modification with outsized benefits.
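For readers unfamiliar with jumping knowledge: the readout aggregates the embeddings from every layer rather than only the last, so shallow, less-smoothed features survive in deep stacks. A minimal sketch in PyTorch Geometric, with illustrative sizes and depth:

```python
# Sketch of jumping knowledge (Xu et al. 2018) to mitigate oversmoothing.
import torch
from torch_geometric.nn import GCNConv, JumpingKnowledge

class JKGCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, n_layers=4):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * (n_layers - 1)
        self.convs = torch.nn.ModuleList(
            GCNConv(d, hidden_dim) for d in dims)
        self.jk = JumpingKnowledge(mode='max')  # or 'cat' / 'lstm'

    def forward(self, x, edge_index):
        xs = []
        for conv in self.convs:
            x = conv(x, edge_index).relu()
            xs.append(x)
        # Aggregate across depths, not just the final (most smoothed) layer.
        return self.jk(xs)
```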
> In GNNs, each node has an initial embedding vector that is then iteratively updated using the embeddings of its neighbors through graph convolutional layers, a common update scheme.

I presume you're using the "base" GCN layer here? If so, I'd be clear about this: unlike the CNN and MLP, there is a massive diversity of MPNN layer types, so it's worth being very explicit about what you're using, since this choice could have massive impacts on the performance of these architectures!
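For reference, and assuming the "base" layer means the Kipf & Welling (2017) graph convolution, the degree-normalized update described in the excerpt is

$$h_i^{(\ell+1)} = \sigma\!\left( \sum_{j \in \mathcal{N}(i)\,\cup\,\{i\}} \frac{1}{\sqrt{\hat{d}_i \hat{d}_j}}\, W^{(\ell)} h_j^{(\ell)} \right)$$

where $\hat{d}_i$ is the degree of node $i$ including a self-loop. Other MPNN layer types replace this fixed normalization with learned, per-neighbor weights, which is exactly the reviewer's point about layer choice mattering.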
> We consider two GNN architectures, both starting with graph convolutional layers. The first architecture (GNN-avg) aggregates the outcome of the last convolution through a global average pooling layer (as in Lajaaiti et al. 2023), the second (GNN-PhyloPool) through our PhyloPool procedure (see Table S4 and S5 in the Appendix for details on both GNNs).

Have you considered pooling layers other than mean-pooling? For instance, existing operators like EdgePooling might confer similar benefits by retaining the temporal signal through iterative edge contraction: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.EdgePooling.html#torch_geometric.nn.pool.EdgePooling
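For readers new to these readouts, here is a plausible minimal rendering of the GNN-avg pattern in PyTorch Geometric; sizes and names are illustrative, and PhyloPool itself is the paper's custom procedure, not reproduced here.

```python
# Sketch of the GNN-avg readout described in the excerpt: graph
# convolutions followed by one global average pooling step.
import torch
from torch_geometric.nn import GCNConv, global_mean_pool

class GNNAvg(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, n_params):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.head = torch.nn.Linear(hidden_dim, n_params)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        # Global average pooling collapses all node embeddings into one
        # vector, discarding where along the tree each embedding sits.
        return self.head(global_mean_pool(x, batch))
```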
> We provide the GNN with the phylogeny’s topology and one attribute per node as its embedding: the distance to the tips, with the tips assigned 0 and the root at a negative coordinate corresponding to the depth of the phylogeny.

Is there a reason you chose not to include edge weights/attributes corresponding to each respective branch length? I suspect this would be quite informative as additional context for the message passing layers and could boost performance further. Otherwise, each neighbor is assumed to contribute equally to feature updates, which is not an assumption we would expect to hold.
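A small sketch of what this encoding could look like in PyTorch Geometric, combining the paper's node feature with the branch-length edge weights the reviewer suggests; the three-tip tree and all values are toy assumptions.

```python
# Toy phylogeny: tips 0-2, internal node 3, root 4.
import torch
from torch_geometric.data import Data

# Node feature = distance to the tips (tips at 0, root negative),
# as described in the excerpt.
x = torch.tensor([[0.0], [0.0], [0.0], [-1.0], [-2.5]])

# Undirected parent-child edges of the tree topology
# (each branch appears in both directions).
edge_index = torch.tensor([[4, 4, 3, 3, 0, 1, 2, 3],
                           [3, 2, 0, 1, 3, 3, 4, 4]])

# One weight per directed edge: the corresponding branch length,
# letting message passing discount distant neighbors.
edge_weight = torch.tensor([1.5, 2.5, 1.0, 1.0, 1.0, 1.0, 2.5, 1.5])

tree = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)
```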
> These are followed by pooling layers to capture global features of the graph, and finally fully connected layers.

True, but pooling is only necessary when the prediction task is graph-level; it is not needed for many other common GNN prediction tasks, including node or link prediction. This is a bit of a nitpick, but I think it's useful to distinguish since this architecture is so new to the field!
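A toy contrast of the two task types, under the assumption that per-node embeddings have already been computed by the GNN layers; shapes and heads are illustrative.

```python
# Graph-level prediction needs pooling; node-level prediction does not.
import torch
from torch_geometric.nn import global_mean_pool

hidden = torch.randn(100, 32)                # per-node embeddings from the GNN
batch = torch.zeros(100, dtype=torch.long)   # all 100 nodes belong to one graph

graph_head = torch.nn.Linear(32, 2)  # e.g. two diversification parameters
node_head = torch.nn.Linear(32, 2)   # e.g. a per-node prediction

graph_pred = graph_head(global_mean_pool(hidden, batch))  # shape [1, 2]
node_pred = node_head(hidden)                             # shape [100, 2], no pooling
```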
> We passed the CDV representation through a convolutional neural network (CNN) to predict the parameters of interest (see Table S3 in the Appendix for details on this CNN).

I know these details are present in the supplement, but here and for the other NNs it could be useful to specify the number of layers used, particularly for the GNN, where this relates to the effective "diameter" of visibility or neighborhood size (i.e. how many hops away the message passing mechanism aggregates neighbor information).
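To illustrate the point about depth, a short sketch, with the layer count k = 3 and dimensions chosen arbitrarily:

```python
# The number of message-passing layers sets the receptive field: after
# k GCN layers, each node's embedding summarizes its k-hop neighborhood.
import torch
from torch_geometric.nn import GCNConv

k = 3
convs = torch.nn.ModuleList(GCNConv(16, 16) for _ in range(k))

def encode(x, edge_index):
    for conv in convs:   # each pass aggregates one hop further out
        x = conv(x, edge_index).relu()
    return x             # each row now depends on a 3-hop neighborhood
```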