A Brief Primer for Applying Deep Graph Learning to Molecular Graphs

10 minute read
Content level: Intermediate

A brief introduction to using deep graph learning methods in the context of understanding biological and molecular data. Specifically, this primer will focus on how computational biologists can leverage the Deep Graph Library to apply deep learning to their graph data.

Authors: Joshua Broyde and Shamika Ariyawansa

Introduction

A key challenge that computational biology and bioinformatics scientists sometimes face is leveraging state-of-the-art machine learning tools for scientific problems. In the context of graph analysis, many new methods rely on deep learning approaches. This post is meant to be a basic primer for computational biologists, bioinformaticians, and systems biologists who are starting to think about applying state-of-the-art deep graph learning to their projects, but may not yet be familiar with these methods. In this post, we walk through the core concepts presented in this notebook, with the aim of explaining the foundations of the code shown there. We encourage you to read through the notebook; you can even execute it in your own AWS account to run the machine learning models.

Why is Deep Learning being applied to biological networks?

If you are familiar with classical network analysis, you have probably encountered various graph measures, such as betweenness centrality and degree centrality, or methods for analyzing networks, such as random walk with restart. These methods have been used in the past for tasks like calculating properties of nodes and analyzing groups of disease genes. However, many methods that rely on these types of centralities suffer from the fact that they are transductive, meaning they can only generate features for one particular graph. As a result, tasks such as edge prediction or graph classification, where many different graphs are involved, can be difficult. See this paper for further discussion of this issue.

(A quick note on nomenclature: we use the term "graphs" to refer to biological networks; we reserve the term "network" for a neural network. Although it is common in the computational biology field to refer to biological graphs as networks, in the deep learning field, "network" refers almost exclusively to a neural network.)

One of the key ideas behind applying deep learning to graphs is that convolutional neural networks, commonly used in computer vision, can, with certain modifications, also be used to analyze graphs. Convolutions, among other properties, allow for inductive learning, whereby features can be learned across different graph topologies. These convolutions transform the underlying information in the graph's nodes and edges. While a single convolutional layer is generally not sufficient for most tasks, many convolutions can be stacked to form a deep graph convolutional neural network. By combining layers of graph convolutions, you can build complex neural networks that perform graph classification (i.e., predicting the class of an entire graph), link prediction (predicting missing edges in a graph), and many other tasks.
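
To make this concrete, below is a minimal sketch of a two-layer graph convolutional network built from DGL's GraphConv layer. The toy graph, feature dimensions, and the TwoLayerGCN class name are all illustrative and are not part of the notebook.

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv

# A toy two-layer graph convolutional network: each layer passes messages
# between neighboring nodes and transforms the aggregated information.
class TwoLayerGCN(nn.Module):
    def __init__(self, in_feats, hidden_feats, out_feats):
        super().__init__()
        self.conv1 = GraphConv(in_feats, hidden_feats)
        self.conv2 = GraphConv(hidden_feats, out_feats)

    def forward(self, g, node_feats):
        h = F.relu(self.conv1(g, node_feats))
        return self.conv2(g, h)

# Illustrative usage on a small graph: 5 nodes, 3 input features per node
g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]), num_nodes=5))
gcn = TwoLayerGCN(in_feats=3, hidden_feats=8, out_feats=2)
node_embeddings = gcn(g, torch.randn(5, 3))  # shape: (5, 2), one embedding per node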

With deep learning models, it is also possible to incorporate different edge types as well as external information about edges and nodes. This makes deep learning an attractive approach for analyzing and making predictions about graphs, since biological networks are frequently very heterogeneous, combining diverse data types such as metabolic, biophysical, proteomic, and functional assays, and gene regulatory networks. For example, this blog post shows how a knowledge graph with diverse node and edge types can be used to predict drug repurposing candidates.

While researchers can of course create their own convolutional layers, the deep learning community has already built many convolutions and architectures that have proven useful across applications. For example, GraphSAGE can be used for predicting protein-protein interactions. Another commonly used approach is the Graph Attention Network (GAT).

For a deeper overview of deep graph learning, and how it is being used to analyze biological data, see this review paper. You may also find this tutorial useful.

What is the Deep Graph Library (DGL)? When Should you Use it?

Simply put, the Deep Graph Library (DGL) allows researchers and developers to easily and quickly apply deep graph learning approaches to their data by abstracting away much of the difficult deep learning work and code. The DGL library comes with a number of prebuilt layers, so researchers don't have to reimplement them themselves. For example, this page shows the method call within DGL for the GraphSAGE convolution, while this page shows the method call for GAT. You can take a look at the many other convolutions DGL already provides here. Of course, you have the flexibility to create your own layers and architectures as well.
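
As a small illustration of how little code a prebuilt layer requires, the sketch below instantiates DGL's SAGEConv (GraphSAGE) and GATConv layers on a toy graph. The dimensions, aggregator choice, and number of attention heads are arbitrary and chosen only for the example.

import dgl
import torch
from dgl.nn import SAGEConv, GATConv

# Small toy graph with self-loops and 16 input features per node
g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
feats = torch.randn(3, 16)

# GraphSAGE-style convolution with mean aggregation over neighbors
sage = SAGEConv(in_feats=16, out_feats=32, aggregator_type='mean')
h_sage = sage(g, feats)  # shape: (3, 32)

# Graph attention convolution with 4 attention heads
gat = GATConv(in_feats=16, out_feats=8, num_heads=4)
h_gat = gat(g, feats)    # shape: (3, 4, 8), one output per attention head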

Furthermore, the DGL-LifeSci Python package provides a further abstraction on top of DGL, so that computational biologists, biochemists, and bioinformaticians who wish to leverage deep graph methods can easily do so for common use cases and common operations when analyzing small and large molecules. If you want to learn more about how to use the DGL library, we recommend getting started with this tutorial.

Applying Deep Graph Learning to Molecular Property Prediction

Human immunodeficiency virus type 1 (HIV-1) is the most common cause of Acquired Immunodeficiency Syndrome (AIDS), and an ongoing area of research is identifying compounds that inhibit HIV-1 viral replication. The Drug Therapeutics Program AIDS Antiviral Screen has tested the ability of 43,850 compounds to inhibit viral replication. You can read more about this dataset and the assays here. The DGL library has a pre-processed version of this dataset in which each compound is classified as either Confirmed Inactive (CI; labeled as 0) or Confirmed Moderately Active/Confirmed Active (CM, CA; labeled as 1). If you download and inspect the raw dataset from here (you can download a CSV using this link), you will see that the data looks something like this:

smiles                                                      activity    HIV_active
CC(C)(CCC(=O)O)CCC(=O)O                                     CI          0
O=C(O)Cc1ccc(SSc2ccc(CC(=O)O)cc2)cc1                        CM          1
O=C(O)c1ccccc1SSc1ccccc1C(=O)O                              CI          0
CCCCCCCCCCCC(=O)Nc1ccc(SSc2ccc(NC(=O)CCCCCCCCCCC)cc2)cc1    CI          0
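
For instance, a minimal sketch for loading and inspecting the raw CSV with pandas might look like the following; the file name HIV.csv is illustrative and should point to wherever you saved the download.

import pandas as pd

# Load the raw screen data; the column names follow the table above
df = pd.read_csv("HIV.csv")  # illustrative path; point this at the downloaded file

print(df[["smiles", "activity", "HIV_active"]].head())
print(df["HIV_active"].value_counts())  # number of inactive (0) vs. active (1) compounds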

Since each compound is labeled either 0 (Confirmed Inactive) or 1 (Confirmed Moderately Active/Confirmed Active), this is a graph classification problem. Each molecule (described as a SMILES string) will be constructed as a graph, with the goal of classifying each molecule as either active or inactive. In this graph, each atom is a node, and each bond between two atoms is an edge.
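
To make the atoms-as-nodes, bonds-as-edges picture concrete, here is a small RDKit sketch that parses one of the SMILES strings from the table above and enumerates its atoms and bonds; it is purely illustrative and separate from the notebook code.

from rdkit import Chem

# Parse one of the SMILES strings from the table above
mol = Chem.MolFromSmiles("O=C(O)c1ccccc1SSc1ccccc1C(=O)O")

# Each atom becomes a node in the molecular graph ...
print("atoms (nodes):", mol.GetNumAtoms())

# ... and each bond becomes an edge between the two atoms it connects
for bond in mol.GetBonds():
    print(bond.GetBeginAtomIdx(), "-", bond.GetEndAtomIdx(), bond.GetBondType())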

We will now dive deeper into this specific notebook. We encourage you to look through that notebook in a step-by-step fashion. In this post, we will only delve into the specific components of the problem that are related to deep learning and DGL.

The key lines for reading in the dataset in the notebook are:

from functools import partial

from dgllife.data import HIV
from dgllife.utils import CanonicalAtomFeaturizer, smiles_to_bigraph

node_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')
edge_featurizer = None
dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
              node_featurizer=node_featurizer,
              edge_featurizer=edge_featurizer,
              n_jobs=num_workers)

To break this down, the node_featurizer creates features for each atom. As described here, these features describe the atom type, charge, hybridization, and other chemical characteristics. Note that we set edge_featurizer to None because we are not going to create features for the edges. We use the smiles_to_graph functionality to convert the SMILES strings to graphs. We use the add_self_loop option because graph convolutions rely on message passing, which in turn relies on information from a node's neighbors; self-loops help ensure that a particular node's own information does not become drowned out by that of its neighbors.
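
To see what this featurization produces for a single molecule, you can build one graph directly, as in the sketch below. The SMILES string is taken from the table above, and 'feat' matches the atom_data_field set earlier; this is an illustrative snippet rather than part of the notebook.

from dgllife.utils import CanonicalAtomFeaturizer, smiles_to_bigraph

node_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')

# Build a single molecular graph with self-loops, the same way the dataset loader does
g = smiles_to_bigraph("CC(C)(CCC(=O)O)CCC(=O)O",
                      add_self_loop=True,
                      node_featurizer=node_featurizer)

print(g)                      # DGLGraph: one node per atom, edges for bonds plus self-loops
print(g.ndata['feat'].shape)  # (number of atoms, number of atom features)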

In the notebook, you next use the RDKit chemoinformatics toolkit to visualize, explore, and understand the dataset. Note that while this data exploration is essential for understanding your data, it does not itself create new features for the deep graph network. We will not describe the RDKit operations in detail in this post, but feel free to explore them in the notebook.

The key section describing the model is:

from dgllife.model import GCNPredictor
import torch.nn.functional as F

model = GCNPredictor(
            in_feats=10,
            hidden_feats=[10, 4],
            activation=[F.relu, F.relu],
            residual=[False] * 2
            )

There is a lot to unpack here (the dgllife library has abstracted a lot for us!), so let's go through it. As described here, the Graph Convolutional Network predictor (GCNPredictor) can be used for regression and classification of graphs. in_feats describes the number of input features for the nodes; in this notebook, we are using 10 features. hidden_feats describes the number of features at each hidden layer of the neural network. Note that the length of this vector is the number of hidden layers in the network. We are using [10, 4] (which has length 2), which means that this network has two hidden layers.

It is worthwhile to understand more deeply what this is doing. If you type model, you can see the entire architecture (we are presenting a simplified version of the output):

GCNPredictor(
  (gnn): GCN(
    (gnn_layers): ModuleList(
      (0): GCNLayer(
        (graph_conv): GraphConv(in=10, out=10, activation=relu)
        (dropout): Dropout(p=0.0)
        (bn_layer): BatchNorm1d(10)
      )
      (1): GCNLayer(
        (graph_conv): GraphConv(in=10, out=4, activation=relu)
        (dropout): Dropout(p=0.0)
        (bn_layer): BatchNorm1d(4)
      )
    )
  )
  (readout): WeightedSumAndMax(
    (weight_and_sum): WeightAndSum(
      (atom_weighting): Sequential(
        (0): Linear(in_features=4, out_features=1, bias=True)
        (1): Sigmoid()
      )
    )
  )
  (predict): MLPPredictor(
    (predict): Sequential(
      (0): Dropout(p=0.0)
      (1): Linear(in_features=8, out_features=128, bias=True)
      (2): ReLU()
      (3): BatchNorm1d(128)
      (4): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)

For those not familiar with neural network architectures, this can be a bit daunting, so let’s step through it a bit more:

First, we create a GCN layer. The input dimension is 10, as is the output dimension, and the layer uses the standard ReLU activation function. Furthermore, we use batch normalization (bn_layer) and dropout, which respectively improve training stability/performance and reduce overfitting. The output is then fed to a second GCN layer, which is nearly identical, except that its output dimension is 4.

This output is then passed to the WeightedSumAndMax readout layer, which aggregates the 4 node-level features into a graph-level representation; the Linear layer followed by a Sigmoid computes a weight for each atom. This is in turn passed to a multi-layer perceptron (MLP), which makes the final prediction. Note that for the MLP, in_features is 8, since WeightedSumAndMax doubles the size of the feature vector from 4 to 8 (it concatenates a weighted sum over the nodes with a max over the nodes), as described in the documentation here. The number 128 corresponds to the number of hidden features in the MLP. The output layer is a linear layer with dimension equal to 1, so the model produces a single score per molecule that, after a sigmoid, can be interpreted as the probability that the molecule is active.
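
To tie the pieces together, here is a minimal sketch of a forward pass through the predictor, reusing the model object constructed above. The toy graphs, random node features, and labels are purely illustrative; the actual notebook handles batching and training separately.

import dgl
import torch
import torch.nn.functional as F

# Illustrative batch: two toy graphs with random 10-dimensional node features,
# matching the in_feats=10 used when constructing the GCNPredictor above
g1 = dgl.add_self_loop(dgl.graph(([0, 1], [1, 2]), num_nodes=3))
g2 = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 3]), num_nodes=4))
bg = dgl.batch([g1, g2])
node_feats = torch.randn(bg.num_nodes(), 10)

logits = model(bg, node_feats)         # shape: (2, 1); model is the GCNPredictor defined above
probs = torch.sigmoid(logits)          # interpret each score as a probability of being HIV-active
labels = torch.tensor([[0.0], [1.0]])  # toy labels: first molecule inactive, second active
loss = F.binary_cross_entropy_with_logits(logits, labels)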

Next Steps

In this post, we have not explained the SageMaker-specific components of leveraging DGL, such as training DGL models at scale, using SageMaker Experiments to track performance, and deploying the model. Instead, this primer has explained the basics of DGL itself and some of the convolutions used in that approach and notebook. If you are interested in more, we encourage you to explore the rest of the notebook, as well as the DGL documentation here.

If you are just beginning your journey with deep learning, you can learn more broadly about deep learning here. For graph learning specifically, you can explore this resource.

About the Authors

Joshua is a Senior AI/ML Solutions Architect within the Healthcare and Life Sciences Industry Business Unit at AWS. Shamika is an AI/ML Solutions Architect within the Healthcare and Life Sciences Industry Business Unit at AWS.
