1 Introduction

Drug development costs billions of dollars and takes many years, with multiple stages of refinement and trial (Mullard 2014). For this reason, repurposing approved drugs is a critical alternative to the full development cycle, offering huge savings in money, time, and lives. A canonical task in drug repurposing is predicting a drug's effect on multiple related diseases, which can naturally be formulated as a multilabel learning problem. Unlike standard domains such as text or images, drugs are usually represented as variable-size graphs of atoms linked by chemical bonds. The irregularity and complexity of rich graph structures make multilabel learning over molecular graphs very challenging. At the same time, graphs bring new kinds of information not present in unstructured data, as evidenced by the recent surge of research in graph representation learning (Hamilton et al. 2017b; Zhang et al. 2017). Hence, a proper treatment of multilabel learning over graphs is needed.

We hypothesize that the key to classification performance and explainability lies in uncovering the relations between labels and subgraphs. Towards this goal, we design a new graph neural network (Scarselli et al. 2009) called \(\text {GAML}\) (which stands for Graph Attention model for Multi-Label learning). \(\text {GAML}\) treats all label classes as nodes (termed label nodes) and merges them with the nodes of an input graph (called input nodes) to form a unified label-input graph. In the joint graph, relations between labels and substructures can be effectively captured through interaction between the label nodes and the input nodes. Specifically, we leverage the message passing algorithm (Pham et al. 2017; Schlichtkrull et al. 2017; Gilmer et al. 2017) to simultaneously update the local substructure at every input node and to propagate substructure-encoding messages from all the input nodes to the label nodes. By using attention (Bahdanau et al. 2014; Xu et al. 2015), each label node can extract the most relevant substructures to update its own state, which is later used to predict the presence of the corresponding class. Attention also enables insightful visualization which helps explain the prediction. To handle a large number of classes and large input graphs, we propose a new type of attention named hierarchical attention. Different from the standard approach that calculates the score matrix between every input and label node directly, our attention mechanism introduces intermediate attentional factors to save computation. In our model, implicit dependencies among the labels are captured via commonly attended substructures. When explicit dependencies among the labels are available (e.g., through expert knowledge), \(\text {GAML}\) can easily integrate them by adding links and exchanging messages between related label nodes. Moreover, since the node update procedure runs iteratively, our model can learn the label-subgraph (or label-substructure) relations at various resolution scales.

The flexibility and scalability of \(\text {GAML}\) make it attractive for many real-world problems. In this paper, we focus on two major drug–multitarget prediction problems: predicting drug–protein binding and drug–cancer response. In the first problem, a drug is tested against multiple target proteins; in the second, a drug is tested against multiple cancer types. We also evaluate our method on classical vector input, which can be seen as a special graph with a single node. In both cases, \(\text {GAML}\) proves superior to rival multilabel learning techniques. Finally, to get a clear picture of the learned label-substructure patterns, we generate visualizations using real drug molecules extracted from our datasets.

In summary, our contributions are:

  • Proposing a novel graph neural network named \(\text {GAML}\) that addresses the open problem of multilabel classification over graphs. Our model can effectively capture the (multi-way) relations among the labels and the input subgraphs. It can also incorporate explicit label dependencies and is scalable to many labels and large graphs.

  • Demonstrating the advantages of \(\text {GAML}\) through a comprehensive suite of experiments with quantitative evaluation and visualization.

2 Related work

Multilabel classification with label dependencies Most work in multilabel learning focuses on capturing implicit or explicit label dependencies. One strategy is to apply Canonical Correlation Analysis (CCA) to map the input and the labels into a common latent space, from which the model reconstructs the target labels. Extensions of this approach include both shallow (Li and Guo 2015; Sun et al. 2011) and deep (Yeh et al. 2017) models. Among graphical model-based approaches, the work in Ghamrawi and McCallum (2005) uses Conditional Random Fields to model the three-way relation between every pair of labels i, j and the input \({\varvec{x}}\) using a feature function \(\phi (y_{i},y_{j},{\varvec{x}})\). Meanwhile, the work in Guo and Gu (2011) constructs a fully connected cyclic Bayesian Network over labels and performs structure learning on this network. The probability of a label \(y_{i}\) conditioned on the input \({\varvec{x}}\) and the other labels \(y_{\lnot i}\) is modeled using logistic regression. Both methods are computationally expensive and require inexact inference for a large number of labels.

To model the joint distribution of labels while keeping computation reasonable, some methods exploit a chain rule factorization. The most notable one is the Probabilistic Classifier Chain (Dembczynski et al. 2010), which builds a separate binary classifier for each label whose input is the combination of the original input and the previously predicted labels. Other methods follow this idea but use recurrent neural networks (Chen et al. 2017; Wang et al. 2016) to learn the relations better. However, the critical issues of these methods are label ordering and fragile inference: the output label at one step depends on the values of the previously predicted labels rather than their distribution, which is very unstable. Although tricks like beam search (Wang et al. 2016) or automatic order selection (Chen et al. 2017) have been used to improve the results, they only solve part of the problem.

Expert knowledge about label dependencies represented as trees (Deng et al. 2014) or graphs (Bi and Kwok 2011; Chen et al. 2018) has been exploited for multilabel/multiclass classification. In Chen et al. (2018), the authors build a graph neural network over the predefined label graph. The input vector is copied for every label node and concatenated with the label embedding vector to form an initial state for that label node. Their method, however, is limited to vector input, whereas our model works directly on graph input, with vector input as a special case.

Multilabel classification with graph inputs Although graph classification has attracted significant interest in recent years (Takigawa and Mamitsuka 2017), there has been a limited body of work on multilabel graph classification (Kong and Philip 2012). The line of work on image tagging considers multilabel learning over a grid of pixels (Gong et al. 2013; Wang et al. 2016; Wei et al. 2016). However, the standard CNN treatment usually focuses on attention over feature maps instead of exploiting the structural relations of objects in the original image. A recent work in visual question answering that pushes forward the idea of object graphs is Teney et al. (2017), but the QA setting differs from ours. A special case of our multilabel learning over graphs is multilabel learning over sets (Pham et al. 2017), where the input is a collection of nodes with no explicit links.

Graph neural networks By leveraging the representation power of deep neural networks such as CNNs and RNNs, a wide range of methods for learning over graphs (Defferrard et al. 2016; Gilmer et al. 2017; Hamilton et al. 2017a; Kipf and Welling 2016a; Li et al. 2016; Niepert et al. 2016; Pham et al. 2017; Scarselli et al. 2009) have been proposed recently. These methods can be grouped into broader categories such as spectral graph based (Bruna et al. 2013; Defferrard et al. 2016; Kipf and Welling 2016a), message passing based (Gilmer et al. 2017; Pham et al. 2017; Schlichtkrull et al. 2017), random walk based (Grover and Leskovec 2016; Perozzi et al. 2014), and neural net based (Li et al. 2016). Among them, Message Passing Graph Neural Networks (MPGNNs) are very powerful since they can handle various kinds of graphs, including attributed graphs whose edges and nodes both have types. MPGNNs have found many applications in bioinformatics such as drug activity classification (Pham et al. 2018), chemical property prediction (Gilmer et al. 2017), protein interface prediction (Fout et al. 2017) and drug generation (Jin et al. 2018). However, none of these methods properly handles multilabel classification problems in which modeling multi-way relations among labels and molecular subgraphs is the key factor.

Graph representation learning Many supervised learning problems over graphs (including the multilabel classification problem we are working on) assume precomputed graph embedding. Unsupervised learning methods for graphs (Narayanan et al. 2016; Shervashidze et al. 2011; Yanardag and Vishwanathan 2015) often exploit the common substructures among graphs to ensure that graphs with similar structure will be represented as close points in the embedding vector space. Graph embedding can also be achieved through graph reconstruction or generation. This approach includes VAE/GAN based models (Kipf and Welling 2016b; Simonovsky and Komodakis 2018; Wang et al. 2017) and sequence based models (Li et al. 2018; You et al. 2018).

3 Preliminaries

In this section we provide the mathematical formulation of multilabel classification and the background on graph neural networks on which \(\text {GAML}\) is built. For clarity and consistency, we use the following notation throughout the paper (unless stated otherwise): bold letters denote vectors (\({\varvec{x}}\) is a vector); capital letters denote matrices (W is a matrix); normal letters denote scalars (s is a scalar); \(\{\cdot \}\) denotes a set; \(f(\cdot )\) denotes a function f with arguments separated by commas. Table 1 lists the most common notations used in the paper.

Table 1 Notations used in the paper

3.1 Multilabel classification

Multilabel classification is a supervised learning problem in which each input example may be associated with more than one output class. Denote by X the input vector space and by \({{\mathcal {C}}}\equiv \{1,2,...,C\}\) the set of all classes labeled from 1 to C. Multilabel learning estimates a function f that maps X to the power set of \({{\mathcal {C}}}\), written as \(f:X\mapsto \mathcal {P}_{{{\mathcal {C}}}}\). Because each element of \(\mathcal {P}_{{{\mathcal {C}}}}\) is a subset of \({{\mathcal {C}}}\), it can be represented as a binary vector \({\varvec{y}}\) of length C, where \(y_{c}=1\) indicates that class c appears in the subset and \(y_{c}=0\) otherwise (\(c=\overline{1,C}\)).

3.2 Graph notations

Consider an attributed graph \({\mathcal {G}}=\left( {\mathcal {V}},{\mathcal {E}}\right) \) where \({\mathcal {V}}\) is the set of nodes and \({\mathcal {E}}\) is the set of edges. Each node i is associated with a node feature vector \({\varvec{v}}_{i}\) which captures important properties of the node. For example, if the graph \({\mathcal {G}}\) is a drug molecule (as depicted in Fig. 1), each node is an atom and the node’s properties could be its atomic number, its charge, its valence, etc. In our current work, we only use the atomic number information. Thus, \({\varvec{v}}_{i}\) is the embedding vector of the corresponding atom type. Similar to nodes, each edge (i, j) is also associated with an edge type \(e_{ij}\) (for molecules, a bond type).

Fig. 1
figure 1

A drug molecule (PubChem SID \(=\) 502937) represented as an attributed graph. This graph has 17 nodes numbered from 0 to 16 corresponding to 17 atoms. Nodes are characterized by atom types (sulfur, oxygen, nitrogen, carbon) and edges are specified by bond types (single, double, aromatic)

3.3 Message passing graph neural network

Let \({\varvec{x}}_{i}\) be the state of node i and \({\mathcal {N}}(i)=\left\{ j\mid (i,j)\in {\mathcal {E}}\right\} \) denote the neighborhood of node i. In a message passing graph neural network (Gilmer et al. 2017; Pham et al. 2017; Scarselli et al. 2009), a node uses information from its neighbors to update its own state as follows:

$$\begin{aligned} {\varvec{x}}_{i}^{t}= & {} f\left( {\varvec{x}}_{i}^{t-1},\left\{ \left( {\varvec{x}}_{j}^{t-1},e_{ij}\right) \right\} _{j\in {\mathcal {N}}(i)}\right) \end{aligned}$$
(1)

where t denotes the update step; and \(f(\cdot )\) is a non-linear function (e.g., a multi-layer perceptron (MLP)). At \(t=0\), we set \({\varvec{x}}_{i}^{0}={\varvec{v}}_{i}\).

Equation (1) is generic for most graph neural network models. In practice, it can be divided into two steps: message aggregation and state update. In the message aggregation step, we combine multiple messages sent to node i into a single message vector \({\varvec{m}}_{i}\):

$$\begin{aligned} {\varvec{m}}_{i}^{t}= & {} g^{\text {a}}\left( {\varvec{x}}_{i}^{t-1},\left\{ \left( {\varvec{x}}_{j}^{t-1},e_{ij}\right) \right\} _{j\in {\mathcal {N}}(i)}\right) \end{aligned}$$
(2)

where \(g^{\text {a}}(\cdot )\) can be an attention (Bahdanau et al. 2014; Xu et al. 2015) or a pooling architecture. For example, the message aggregated using mean pooling has the following formula:

$$\begin{aligned} {\varvec{m}}_{i}^{t}= & {} \frac{1}{|{\mathcal {N}}(i)|}\sum _{j\in {\mathcal {N}}(i)}W_{e_{ij}}{\varvec{x}}_{j}^{t-1} \end{aligned}$$
(3)

where \(|{\mathcal {N}}(i)|\) is the number of neighbor nodes of node i and \(W_{e_{ij}}\) is a projection matrix corresponding to the edge type \(e_{ij}\). Despite its simplicity, Eq. (3) has been shown to encode graph structures well in several message passing models (Gilmer et al. 2017; Pham et al. 2017; Schlichtkrull et al. 2017).
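As a concrete illustration of Eq. (3), the following minimal NumPy sketch performs edge-type-aware mean pooling on a toy three-node path graph; all sizes, parameter values and the toy graph are illustrative assumptions rather than settings used in the paper.

```python
import numpy as np

d = 4                                 # node state dimension (illustrative)
num_edge_types = 3                    # e.g., single / double / aromatic bonds
W_e = np.random.randn(num_edge_types, d, d) * 0.1   # one projection matrix per edge type

def aggregate_mean(i, states, neighbors, edge_types):
    """Eq. (3): average the edge-type-projected states of node i's neighbors."""
    msgs = [W_e[edge_types[i, j]] @ states[j] for j in neighbors[i]]
    return np.mean(msgs, axis=0)

# Toy 3-node path graph: 0 - 1 - 2
states = np.random.randn(3, d)
neighbors = {0: [1], 1: [0, 2], 2: [1]}
edge_types = {(0, 1): 0, (1, 0): 0, (1, 2): 1, (2, 1): 1}
m_1 = aggregate_mean(1, states, neighbors, edge_types)
print(m_1.shape)   # (4,)
```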

During the state update step, the node state is updated as follows:

$$\begin{aligned} {\varvec{x}}_{i}^{t}= & {} g^{\text {u}}\left( {\varvec{x}}_{i}^{t-1},{\varvec{m}}_{i}^{t}\right) \end{aligned}$$
(4)

where \(g^{\text {u}}(\cdot )\) can be any type of deep neural networks such as MLP (Kipf and Welling 2016a; Hamilton et al. 2017a), RNN (Scarselli et al. 2009), GRU (Li et al. 2016) or Highway Network (Pham et al. 2017). In our model, we use Highway Network (Srivastava et al. 2015) for \(g^{\text {u}}(\cdot )\) as it has been shown to be effective for long range dependencies thanks to its skip-connection and gating mechanism. As a result, Eq. (4) now becomes:

$$\begin{aligned} {\varvec{x}}_{i}^{t}= & {} \left( 1-{\varvec{\eta }}_{i}^{t}\right) \odot {\varvec{x}}_{i}^{t-1}+{\varvec{\eta }}_{i}^{t}\odot \hat{{\varvec{x}}}_{i}^{t} \end{aligned}$$
(5)

where \({\varvec{\eta }}_{i}^{t}\in (\varvec{0},\varvec{1})\) and \(\hat{{\varvec{x}}}_{i}^{t}\) are the gate vector and the non-linear candidate vector of node i at time t, respectively; \(\odot \) is the element-wise product. The formulas of \({\varvec{\eta }}_{i}^{t}\) and \(\hat{{\varvec{x}}}_{i}^{t}\) are provided below:

$$\begin{aligned} {\varvec{\eta }}_{i}^{t}= & {} \text {sigmoid}\left( W_{\eta }{\varvec{x}}_{i}^{t-1}+U_{\eta }{\varvec{m}}_{i}^{t}\right) \end{aligned}$$
(6)
$$\begin{aligned} \hat{{\varvec{x}}}_{i}^{t}= & {} \text {relu}\left( W_{x}{\varvec{x}}_{i}^{t-1}+U_{x}{\varvec{m}}_{i}^{t}\right) \end{aligned}$$
(7)

where \(W_{\eta },U_{\eta },W_{x},U_{x}\) are parameters which can be either distinct or shared among layers. In our experiments, we observed that models with parameter sharing run faster while providing comparable results; hence, we applied this sharing scheme to our model. We abstract Eqs. (4)–(7) into:

$$\begin{aligned} {\varvec{x}}_{i}^{t}= & {} \text {Highway}\left( {\varvec{x}}_{i}^{t-1},{\varvec{m}}_{i}^{t}\right) \end{aligned}$$
(8)

After T steps of message passing, \({\varvec{x}}_{i}^{T}\) captures the graph substructure centered at node i with radius T. The graph summary vector (also called the graph representation vector) \({\varvec{x}}_{{\mathcal {G}}}\) combines the state vectors of all nodes in the graph at step T. In the simplest form, \({\varvec{x}}_{{\mathcal {G}}}\) is the average of \(\left\{ {\varvec{x}}_{i}^{T}\right\} _{i\in {\mathcal {V}}}\):

$$\begin{aligned} {\varvec{x}}_{{\mathcal {G}}}= & {} \frac{1}{|{\mathcal {V}}|}\sum _{i\in {\mathcal {V}}}{\varvec{x}}_{i}^{T} \end{aligned}$$
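The gated update of Eqs. (5)–(7) and the mean readout above can be sketched as follows; the parameter shapes and the toy instantiation are illustrative assumptions, not values used in the paper.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def highway_update(x_prev, m, W_eta, U_eta, W_x, U_x):
    """Eqs. (5)-(7): gated mix of the previous state and a relu candidate."""
    eta = sigmoid(W_eta @ x_prev + U_eta @ m)        # gate, Eq. (6)
    x_hat = np.maximum(0.0, W_x @ x_prev + U_x @ m)  # candidate, Eq. (7)
    return (1.0 - eta) * x_prev + eta * x_hat        # Eq. (5)

def graph_readout(states):
    """Simplest graph summary vector: the mean of all final node states."""
    return states.mean(axis=0)

d = 4
rng = np.random.default_rng(0)
W_eta, U_eta, W_x, U_x = (rng.normal(scale=0.1, size=(d, d)) for _ in range(4))
x_new = highway_update(rng.normal(size=d), rng.normal(size=d), W_eta, U_eta, W_x, U_x)
x_G = graph_readout(rng.normal(size=(7, d)))   # 7 toy final node states
print(x_new.shape, x_G.shape)
```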

4 Method

In this section, we present our main contribution—Graph Attention model for Multi-Label learning (\(\text {GAML}\)).

4.1 Multilabel classification over graphs

We generalize the definition of multilabel classification in Sect. 3.1 to the situation in which inputs are graphs by considering a problem of learning a function \(f:X_{{\mathcal {G}}}\mapsto \mathcal {P}_{{{\mathcal {C}}}}\) where \(X_{{\mathcal {G}}}\) is the space of input graph representation vectors. We argue that in order to perform well on this task, two types of relation must be captured: those within the label set and those between the label set and input subgraphs.

For an input graph \({\mathcal {G}}\), we consider all the C classes as auxiliary nodes (called label nodes) alongside \(|{\mathcal {V}}|\) existing nodes of the input graph \({\mathcal {G}}\). Each label node c connects to all input nodes and has the initial state \({\varvec{l}}_{c}^{0}\in {\mathbb {R}}^{d_{l}}\) which is the embedding of class c to a vector space. On the other hand, each input node i also connects to all label nodes. It results in a joint graph of \(C+|{\mathcal {V}}|\) nodes, which naturally lends itself to the message passing scheme in the graph neural network presented in Sect. 3.3. The idea is that by iteratively updating the states of input and label nodes using message passing, complex label-label and label-substructure dependencies emerge. See Fig. 2 for an illustration.

Fig. 2
figure 2

Message passing in the joint graph of input nodes and label nodes. In (b, c), dashed red links indicate message passing with attention while solid blue links indicate message passing with mean pooling (Color figure online)

4.1.1 Input node update

Since an input node i connects to its neighbor nodes \(j\in {\mathcal {N}}(i)\) and all the label nodes \(c\in \overline{1,C}\), the message passing update of the input node i at step t is formulated as follows:

$$\begin{aligned} {\varvec{x}}_{i}^{t}= & {} f\left( {\varvec{x}}_{i}^{t-1},\left\{ \left( {\varvec{x}}_{j}^{t-1},{\varvec{e}}_{ij}\right) \right\} _{j\in {\mathcal {N}}(i)},\left\{ {\varvec{l}}_{c}^{t-1}\right\} _{c\in \overline{1,C}}\right) \end{aligned}$$
(9)

Note that Eq. (9) is derived from Eq. (1) with the introduction of new arguments \(\left\{ {\varvec{l}}_{c}^{t-1}\right\} _{c\in \overline{1,C}}\).

There are two types of message sent to the input node i: one carries structure information from neighbor input nodes, the other carries label-related information from label nodes. Because these messages have different meanings, they should be aggregated into separate message vectors. For the neighbor input nodes, we use mean pooling to combine them, similarly to Eq. (3):

$$\begin{aligned} {\varvec{\mu }}_{i}^{t}= & {} \frac{1}{|{\mathcal {N}}(i)|}\sum _{j\in {\mathcal {N}}(i)}W_{{\varvec{e}}_{ij}}{\varvec{x}}_{j}^{t-1} \end{aligned}$$

However, mean pooling may not be ideal for aggregating label messages since it treats every class as equally important to the input node i. To overcome this issue, we use the attention mechanism (Bahdanau et al. 2014; Xu et al. 2015) to compute a weighted sum of all the label nodes as follows:

$$\begin{aligned} {\varvec{m}}_{i}^{t}= & {} \sum _{c=1}^{C}a_{ic}^{t}{\varvec{l}}_{c}^{t-1} \end{aligned}$$
(10)

where \(a_{ic}^{t}>0,\ \sum _{c=1}^{C}a_{ic}^{t}=1\) is the attention coefficient from the input node i to a label node c at time t, computed as:

$$\begin{aligned} s_{ic}^{t}= & {} {\varvec{u}}_{s}^{\intercal }\tanh \left( W_{s}{\varvec{x}}_{i}^{t-1}+U_{s}{\varvec{l}}_{c}^{t-1}+{\varvec{b}}_{s}\right) \end{aligned}$$
(11)
$$\begin{aligned} a_{ic}^{t}= & {} \frac{\exp \left( s_{ic}^{t}\right) }{\sum _{c'=1}^{C}\exp \left( s_{ic'}^{t}\right) } \end{aligned}$$
(12)

The set of all unnormalized attention scores \(s_{ic}^{t}\) in Eq. (11) forms a matrix \(S^{t}\in {\mathbb {R}}^{|{\mathcal {V}}|\times C}\), which we will reuse later.

For generality, Eqs. (10)–(12) are written in a more compact form:

$$\begin{aligned} {\varvec{m}}_{i}^{t}= & {} \text {Attention}\left( {\varvec{x}}_{i}^{t-1},\left\{ {\varvec{l}}_{c}^{t-1}\right\} _{c\in \overline{1,C}}\right) \end{aligned}$$
(13)

We call the attention in Eq. (13) input-to-label attention.

In the state update phase, the new state \({\varvec{x}}_{i}^{t}\) of the input node i is computed as:

$$\begin{aligned} {\varvec{x}}_{i}^{t}=\text {Highway}\left( {\varvec{x}}_{i}^{t-1},\left[ {\varvec{\mu }}_{i}^{t},{\varvec{m}}_{i}^{t}\right] \right) \end{aligned}$$

where \(\left[ \cdot \right] \) denotes vector concatenation and \(\text {Highway}(.)\) is defined in Eq. (8).
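To make Eqs. (10)–(12) concrete, here is a minimal NumPy sketch of one input-node attention step; all sizes and random parameters are illustrative assumptions, and the mean-pooled neighbor message and the Highway update are summarized only in comments.

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Illustrative sizes: |V| input nodes, C labels, state sizes d_x, d_l, score size d_z.
V, C, d_x, d_l, d_z = 5, 3, 8, 6, 4
rng = np.random.default_rng(0)
X = rng.normal(size=(V, d_x))          # input-node states x_i^{t-1}
L = rng.normal(size=(C, d_l))          # label-node states l_c^{t-1}
W_s = rng.normal(size=(d_z, d_x))
U_s = rng.normal(size=(d_z, d_l))
b_s = rng.normal(size=d_z)
u_s = rng.normal(size=d_z)

# Eq. (11): unnormalized score matrix S^t of shape |V| x C (reused in Sect. 4.1.2).
S = np.tanh((X @ W_s.T)[:, None, :] + (L @ U_s.T)[None, :, :] + b_s) @ u_s

# Eq. (12): softmax over the labels, so each row of A sums to 1.
A = softmax(S, axis=1)

# Eq. (10): attention-weighted label message m_i^t for every input node i.
M_in = A @ L                           # shape |V| x d_l
# m_i^t is then concatenated with the mean-pooled neighbor message mu_i^t and
# fed to the Highway update to obtain x_i^t.
print(S.shape, A.sum(axis=1), M_in.shape)
```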

4.1.2 Label node update

By connecting to every input node, a label node c can receive information about various substructures in the graph \({\mathcal {G}}\) through multiple steps of message passing. Among these substructures, only a few are related to the class c. Therefore, we use the attention mechanism to extract the most useful substructures for predicting class c and store them in the message vector as follows:

$$\begin{aligned} {\varvec{m}}_{c}^{t}= & {} \text {Attention}\left( {\varvec{l}}_{c}^{t-1},\left\{ {\varvec{x}}_{i}^{t-1}\right\} {}_{i\in \overline{1,|{\mathcal {V}}|}}\right) \end{aligned}$$
(14)

where \(\text {Attention}(.)\) is similar to the function defined in Eq. (13) with the roles of input nodes and label nodes swapped. We call this function label-to-input attention. The unnormalized score matrix \(S^{t}\) from Eq. (11) is reused here to save computation and improve consistency. However, the attention coefficients are now normalized over the input nodes (i.e., along each column of \(S^{t}\)) rather than over the labels, i.e.,

$$\begin{aligned} a_{ci}^{t}= & {} \frac{\exp \left( s_{ic}^{t}\right) }{\sum _{i'=1}^{|{\mathcal {V}}|}\exp \left( s_{i'c}^{t}\right) } \end{aligned}$$

Finally, we compute the new state of the label node c using a different Highway Network as:

$$\begin{aligned} {\varvec{l}}_{c}^{t}= & {} \text {Highway}\left( {\varvec{l}}_{c}^{t-1},{\varvec{m}}_{c}^{t}\right) \end{aligned}$$
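A short continuation of the previous sketch shows the label-node side: the same score matrix is renormalized over the input nodes, as described above. The stand-in values are illustrative only.

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
S = rng.normal(size=(5, 3))    # stand-in for the |V| x C score matrix of Eq. (11)
X = rng.normal(size=(5, 8))    # input-node states x_i^{t-1}

# Label-to-input attention: normalize over the input nodes (columns of S^t sum to 1).
A_li = softmax(S, axis=0)      # A_li[i, c] = a_ci^t
M_lab = A_li.T @ X             # C x d_x, row c is the message m_c^t to label node c
# l_c^t is then obtained as Highway(l_c^{t-1}, m_c^t) with a separate Highway network.
print(A_li.sum(axis=0), M_lab.shape)
```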

4.1.3 A priori label dependencies

When explicit label dependencies are available, a label graph can be formed in the same way as the input graph. Messages between label nodes are aggregated using mean pooling as in Eq. (3):

$$\begin{aligned} {\varvec{\mu }}_{c}^{t}= & {} \frac{1}{|{\mathcal {N}}(c)|}\sum _{f\in {\mathcal {N}}(c)}W_{{\varvec{e}}_{cf}}{\varvec{l}}_{f}^{t-1} \end{aligned}$$

The state of the label node c is updated as:

$$\begin{aligned} {\varvec{l}}_{c}^{t}= & {} \text {Highway}\left( {\varvec{l}}_{c}^{t-1},\left[ {\varvec{m}}_{c}^{t},{\varvec{\mu }}_{c}^{t}\right] \right) \end{aligned}$$

4.1.4 Vector input as a special case

In many traditional multilabel classification problems, the input is represented as a vector instead of a graph. This can be seen as a special case of our model where the input graph \({\mathcal {G}}\) collapses into a single node \({\varvec{x}}\). With this observation, the state update of the input node at step t is:

$$\begin{aligned} {\varvec{m}}^{t}= & {} \text {Attention}\left( {\varvec{x}}^{t-1},\left\{ {\varvec{l}}_{c}^{t-1}\right\} _{c\in \overline{1,C}}\right) \\ {\varvec{x}}^{t}= & {} \text {Highway}\left( {\varvec{x}}^{t-1},{\varvec{m}}^{t}\right) \end{aligned}$$

The state of a label node c is updated as:

$$\begin{aligned} {\varvec{l}}_{c}^{t}= & {} \text {Highway}\left( {\varvec{l}}_{c}^{t-1},{\varvec{x}}^{t-1}\right) \end{aligned}$$

4.2 Learning

After T steps of message passing, we pass each class-specific final state vector \({\varvec{l}}_{c}^{T}\) to a multi-layer perceptron (MLP) with a sigmoid activation on top to predict the presence of class c:

$$\begin{aligned} o_{c}=\text {MLP}\left( {\varvec{l}}_{c}^{T}\right) \end{aligned}$$

Here the value of \(o_{c}\) is in (0, 1). The MLPs for all classes share the same parameters. For learning, we use a binary cross-entropy loss function which is defined as:

$$\begin{aligned} \mathcal {L}=-\mathbb {E}_{\text {train}}\left( \sum _{c=1}^{C}y_{c}\log o_{c}+(1-y_{c})\log (1-o_{c})\right) \end{aligned}$$

where \(\mathbb {E}_{\text {train}}\) denotes the mean over all training data.
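A minimal sketch of the per-class readout and the loss for a single example is shown below; a single linear layer stands in for the shared MLP, and all values are illustrative assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def predict_and_loss(L_final, y, W_o, b_o):
    """Per-class sigmoid readout o_c (single linear layer for brevity) and the
    binary cross-entropy summed over classes for one example."""
    # L_final: C x d_l final label-node states, y: length-C binary target vector
    o = sigmoid(L_final @ W_o + b_o)
    eps = 1e-8
    loss = -np.sum(y * np.log(o + eps) + (1 - y) * np.log(1 - o + eps))
    return o, loss

rng = np.random.default_rng(0)
C, d_l = 9, 6
o, loss = predict_and_loss(rng.normal(size=(C, d_l)), rng.integers(0, 2, size=C),
                           rng.normal(size=d_l), 0.0)
print(o.shape, float(loss))
```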

4.3 Scale to big graphs and many labels

When the number of nodes in the input graph (\(|{\mathcal {V}}|\)) and the number of classes (C) are large, it becomes expensive to calculate the unnormalized score matrix \(S^{t}\in {\mathbb {R}}^{|{\mathcal {V}}|\times C}\) in Eq. (11) for all steps \(t=1,2,...,T\). To handle this problem, we propose a new attention technique called hierarchical attention. At each layer, we define K (\(K\ll \min \left\{ |{\mathcal {V}}|,\ C\right\} \)) intermediate attentional factors between input nodes and label nodes. The input-label attention is broken down into two steps as follows:

  • For label-to-input attention, we do factor-to-input attention then label-to-factor attention.

  • For input-to-label attention, we do factor-to-label attention then input-to-factor attention.

Label-to-input message aggregation. More concretely, the label-to-input message aggregation in Eq. (10) is replaced by:

$$\begin{aligned} {\varvec{m}}_{i}^{t}= & {} \sum _{k=1}^{K}a_{ik}^{t}\varvec{\lambda }_{k}^{t-1};\quad \text {for}\quad \varvec{\lambda }_{k}^{t-1}=\sum _{c=1}^{C}b_{ck}^{t}{\varvec{l}}_{c}^{t-1} \end{aligned}$$

where \(\varvec{\lambda }_{k}^{t-1}\) is the kth intermediate factor (\(k\in \overline{1,K}\)) that aggregates all label nodes; \({\varvec{m}}_{i}^{t}\) is the message to the input node i; \(a_{ik}^{t}\) is factor-to-input attention probability (i.e., \(\sum _{k=1}^{K}a_{ik}^{t}=1\)); and \(b_{ck}^{t}\) is label-to-factor attention probability (i.e., \(\sum _{c=1}^{C}b_{ck}^{t}=1\)).

To compute \(a_{ik}^{t}\) and \(b_{ck}^{t}\) we define two score matrices \(S_{1}^{t}=\left[ s_{1;ik}^{t}\right] \in {\mathbb {R}}^{|{\mathcal {V}}|\times K}\) and \(S_{2}^{t}=\left[ s_{2;ck}^{t}\right] \in {\mathbb {R}}^{C\times K}\) as follows:

$$\begin{aligned} s_{1;ik}^{t}= & {} {\varvec{u}}_{1}^{\intercal }\tanh (W_{1}{\varvec{x}}_{i}^{t-1}+{\varvec{z}}_{k}^{t-1}) \end{aligned}$$
(15)
$$\begin{aligned} \text {and}\ \ s_{2;ck}^{t}= & {} {\varvec{u}}_{2}^{\intercal }\tanh (W_{2}{\varvec{l}}_{c}^{t-1}+{\varvec{z}}_{k}^{t-1}) \end{aligned}$$
(16)

where \({\varvec{u}}_{1},{\varvec{u}}_{2}\in {\mathbb {R}}^{d_{z}}\), \(W_{1}\in {\mathbb {R}}^{d_{z}\times d_{x}}\), \(W_{2}\in {\mathbb {R}}^{d_{z}\times d_{l}}\) and \({\varvec{z}}_{k}^{t}\in {\mathbb {R}}^{d_{z}}\) (\(k=\overline{1,K}\)) are parameters. The factor-to-input attention probability and the label-to-factor attention probability are then computed as:

$$\begin{aligned} a_{ik}^{t}= & {} \frac{\exp (s_{1;ik}^{t})}{\sum _{k'=1}^{K}\exp (s_{1;ik'}^{t})};\quad b_{ck}^{t}=\frac{\exp (s_{2;ck}^{t})}{\sum _{c'=1}^{C}\exp (s_{2;c'k}^{t})} \end{aligned}$$

Input-to-label message aggregation. Likewise, the two-step input-to-label message aggregation is computed as:

$$\begin{aligned} {\varvec{m}}_{c}^{t}=\sum _{k=1}^{K}\alpha _{ck}^{t}\varvec{\chi }_{k}^{t-1};\quad \text {for}\quad \varvec{\chi }_{k}^{t-1}=\sum _{i=1}^{|{\mathcal {V}}|}\beta _{ik}^{t}{\varvec{x}}_{i}^{t-1} \end{aligned}$$

where \(\varvec{\chi }_{k}^{t-1}\) is the kth intermediate factor (\(k\in \overline{1,K}\)) that aggregates all input nodes; \({\varvec{m}}_{c}^{t}\) is the message to the label node c; \(\alpha _{ck}^{t}\) is factor-to-label attention probability (i.e., \(\sum _{k}\alpha _{ck}^{t}=1\)); and \(\beta _{ik}^{t}\) is input-to-factor attention probability (i.e., \(\sum _{i}\beta _{ik}^{t}=1\)). The attention probabilities are respectively computed as:

$$\begin{aligned} \alpha _{ck}^{t}= & {} \frac{\exp (s_{2;ck}^{t})}{\sum _{k'=1}^{K}\exp (s_{2;ck'}^{t})};\quad \beta _{ik}^{t}=\frac{\exp (s_{1;ik}^{t})}{\sum _{i'=1}^{|{\mathcal {V}}|}\exp (s_{1;i'k}^{t})} \end{aligned}$$

where the scores \(s_{1;ik}^{t}\) and \(s_{2;ck}^{t}\) are computed using Eqs. (15) and (16).

It is clear that with this decomposition strategy, the number of computation steps reduces from \(\mathcal {O}\left( |{\mathcal {V}}|C\right) \) to \(\mathcal {O}\left( \left( |{\mathcal {V}}|+C\right) K\right) \) for \(K\ll \min \left\{ |{\mathcal {V}}|,\ C\right\} \).
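The two-step factor attention can be sketched as follows; it only ever forms the \(|{\mathcal {V}}|\times K\) and \(C\times K\) score matrices of Eqs. (15)–(16), never the full \(|{\mathcal {V}}|\times C\) matrix. All sizes and random values are illustrative assumptions.

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

V, C, K, d_x, d_l, d_z = 40, 50, 5, 8, 8, 6    # illustrative sizes, K << min(V, C)
rng = np.random.default_rng(0)
X = rng.normal(size=(V, d_x))                  # input-node states
L = rng.normal(size=(C, d_l))                  # label-node states
Z = rng.normal(size=(K, d_z))                  # factor vectors z_k
W1 = rng.normal(size=(d_z, d_x)); u1 = rng.normal(size=d_z)
W2 = rng.normal(size=(d_z, d_l)); u2 = rng.normal(size=d_z)

# Eqs. (15)-(16): only |V| x K and C x K scores instead of a |V| x C matrix.
S1 = np.tanh((X @ W1.T)[:, None, :] + Z[None, :, :]) @ u1   # V x K
S2 = np.tanh((L @ W2.T)[:, None, :] + Z[None, :, :]) @ u2   # C x K

# Label-to-input direction: labels -> factors -> input nodes.
b = softmax(S2, axis=0)            # b_ck, columns sum to 1 over labels
lam = b.T @ L                      # K x d_l factor summaries of the labels
a = softmax(S1, axis=1)            # a_ik, rows sum to 1 over factors
M_in = a @ lam                     # V x d_l, messages to input nodes

# Input-to-label direction: input nodes -> factors -> labels.
beta = softmax(S1, axis=0)         # beta_ik, columns sum to 1 over input nodes
chi = beta.T @ X                   # K x d_x factor summaries of the input graph
alpha = softmax(S2, axis=1)        # alpha_ck, rows sum to 1 over factors
M_lab = alpha @ chi                # C x d_x, messages to label nodes
print(M_in.shape, M_lab.shape)
```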

4.4 Detecting higher-order relation

Higher-order label-label relation The iterative message passing scheme spreads information to distant nodes. Two labels can interact indirectly after two steps of updates: a label sends messages to input nodes, which then redistribute the information back to the other labels. This brings about higher-order label correlations.

Multi-resolution substructure-label relation Likewise, after t steps, an input node accumulates information from other nodes within t hops. Because t varies from 1 to T, our model can detect label-specific substructures with multiple resolutions via the label-to-input attention. This capability is discussed in detail in Sect. 5.1.8.

5 Experiments

We present empirical results on two comprehensive sets of experiments: one on graph-structured input (Sect. 5.1) and the other on traditional unstructured input (Sect. 5.2).

5.1 Multilabel classification with graph-structured input

Our experiments focus on biochemical databases of potential drugs. A drug is a moderate-sized molecule with desirable bioactivities treated as labels. In the molecular graph of a drug, nodes represent atoms and edges represent bond types.

5.1.1 Datasets

We use two real-world biochemical datasets:

  • 9cancers For this dataset, the goal is to predict drug activity against nine types of cancer (see Table 2). The activity is binary, indicating whether there is a response, i.e., whether the drug reduces or prevents tumor growth. We first download nine separate datasets, one for each cancer type, from PubChem.Footnote 1 Then, we search for drug molecules that appear in all datasets, which results in about 22 thousand molecules in total. Among them, 3,356 molecules are active for at least one type of cancer. We select all the active molecules and 10,000 fully inactive molecules to create the final dataset for our experiments.

  • 50proteins This dataset concerns drug-protein binding prediction. Again, drugs are treated as input graphs while proteins are labels. We obtain the raw version from BindingDB.Footnote 2 In this dataset, the number of unique proteins (also called targets) is 595 and the number of unique drugs (or ligands) is 55,781. We select the top 50 proteins bound by the most ligands to construct our experimental dataset. There are 36,349 ligands in total, and the average number of proteins to which a ligand binds is 1.35.

Table 2 Assay ID and name of nine cancers in 9cancers dataset extracted from PubChem

We divide each dataset into train/valid/test sets with proportions of 0.6/0.2/0.2, respectively. The detailed statistics are shown in Table 3 and the number of label occurrences is shown in Fig. 3. The labels in 50proteins are sparse as each ligand links to at most 10 proteins (and the majority of ligands bind to only 1 or 2 proteins). Meanwhile, the labels in 9cancers are denser, with nearly a thousand drugs positive for all cancers.

Table 3 Statistics of all multilabel datasets with graph-structured inputs
Fig. 3
figure 3

Histogram of the number of labels associated with each instance in 9cancers and 50proteins

5.1.2 Baselines

For comparison, we employ the following data representations and associated multilabel classifiers:

Molecular fingerprint The first set of baselines works on molecular fingerprints. A molecular fingerprint is a binary vector, each element of which is associated with a particular type of substructure in the molecular graph. We use the well-known Morgan algorithm from RDKitFootnote 3 to generate multiple fingerprints with increasing radius from 1 to 5 to account for fine-grained levels of substructures. These fingerprints are then concatenated to form a final feature vector. For each radius, we set the length of the fingerprint hash vector to 100, resulting in a final feature vector of size 500 (a minimal generation sketch is given after the following list). We evaluate two models running on top of this vector representation:

  • The first model is an SVM with an RBF kernel used as the base classifier for the Binary Relevance algorithm (Tsoumakas and Katakis 2007). We denote this model as fp \(+\) SVM \(+\) BR.

  • The second model is a Highway Network (HWN) (Srivastava et al. 2015) followed by a fully connected neural network with a sigmoid activation function. All highway layers share parameters. We denote the combination of fingerprint and HWN as fp \(+\) HWN. In this model, the dependencies among classes are implicitly captured through the intermediate hidden layers.
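As an illustration of the fingerprint representation used by these two baselines, a concatenation of Morgan fingerprints (radius 1 to 5, 100 bits each) might be generated with RDKit as sketched below; the SMILES string is just an example molecule, not one taken from our datasets.

```python
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

def concat_morgan_fingerprint(smiles, radii=(1, 2, 3, 4, 5), n_bits=100):
    """Concatenate Morgan fingerprints of increasing radius into one 500-bit vector."""
    mol = Chem.MolFromSmiles(smiles)
    parts = []
    for r in radii:
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, r, nBits=n_bits)
        parts.append(np.array(list(fp), dtype=np.int8))
    return np.concatenate(parts)

x = concat_morgan_fingerprint("CC(=O)Oc1ccccc1C(=O)O")   # aspirin, for illustration
print(x.shape)   # (500,)
```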

String representation SMILES is one of the most popular string representations of molecules, encapsulating the graph structure in its grammar. We treat a SMILES string as a sequence of characters and model it with a GRU (Cho et al. 2014). When reaching the end of the sequence, the last state of the GRU is fed to a 2-layer MLP that outputs predictions for all labels. This SMILES+GRU combination has recently proven highly effective in drug evaluation and design (Segler et al. 2017).

Graph representation The last set of baselines handles graph-structured input directly. We select two representative models: Weisfeiler-Lehman Graph Kernels (WLs) (Shervashidze et al. 2011) for graph kernel based methods and Column Networks (CLNs) (Pham et al. 2017) for graph neural network based methods.

  • WL is an unsupervised graph2vec model that maps a graph into a characteristic representation vector. Each element of this vector is the count of a specific rooted subgraph (or tree) in the graph. Because the length of the graph representation vector equals the vocabulary size of the trees, which is very large, in practice the similarity (kernel) matrix for every pair of graphs is used instead. We precompute the kernel matrix for both training and testing data using the Weisfeiler-Lehman algorithm. The maximum height of the trees is chosen to be 3. For 9cancers, this results in about 49 thousand different tree structures for the entire graph dataset, while the total number of graphs is only about 13 thousand. Therefore, increasing the height beyond 3 adds very little information about graph similarity as the proportion of matching substructures approaches zero. The kernel matrix for the training graphs is used as input to an SVM wrapped by Binary Relevance (WL+SVM+BR) for multilabel classification.

  • CLN, on the other hand, is a supervised graph message passing neural network. We use the same model as in Pham et al. (2017) with a mean pooling layer on top of the message passing layers to compute the graph representation vector. This vector is then fed to a 2-layer MLP to predict all labels. Unlike our \(\text {GAML}\), CLN only captures the relations between the labels and the subgraphs at the topmost layer rather than at every layer.

The hyper-parameters of fp \(+\) HWN and SMILES \(+\) GRU are obtained through validation. Meanwhile, the hyper-parameters of CLN are set similar to the optimal hyper-parameters of our model (see below).

5.1.3 Model setting

In our model, the sizes of the node and edge embeddings are both set to 50. We perform a grid search for other hyper-parameters, with the label embedding size in \(\{10,30,50,70,100\}\), the number of factors in \(\{1,5,10,15,20\}\), and the number of message passing layers in \(\{2,4,6,8,10\}\). Dropout with a rate of 0.3 is applied to every graph input node. We do not use dropout for label nodes: although it reduces overfitting, it results in a lower F1 score. In addition, we set the batch size to 60 and 100 for 9cancers and 50proteins, respectively. We use the Adam optimizer (Kingma and Jimmy 2014) with an initial learning rate of 0.001. During training, the learning rate is halved if the validation loss does not improve after 20 consecutive epochs. We train our model for a maximum of 300 epochs and may stop early after the learning rate has been decayed 4 times.
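The learning-rate schedule above can be expressed with a standard plateau scheduler. The following PyTorch sketch is only an illustration of that schedule; the framework choice and the `model`, `train_epoch`, and `validate` stand-ins are assumptions, not our actual training code.

```python
import torch

# Hypothetical stand-ins so the schedule can be run in isolation; in practice these
# would be the GAML model, its training step and its validation loss.
model = torch.nn.Linear(10, 1)
def train_epoch(model, optimizer): pass
def validate(model): return 1.0

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
                                                       factor=0.5, patience=20)
for epoch in range(300):
    train_epoch(model, optimizer)
    scheduler.step(validate(model))                    # halve lr after 20 stagnant epochs
    if optimizer.param_groups[0]["lr"] <= 1e-3 * 0.5 ** 4:
        break                                          # stop after 4 decays
```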

5.1.4 Evaluation metrics

We use popular multilabel classification metrics: micro and macro (sometimes called per-label) F1, and micro and macro AUC. While micro F1 favors labels with many examples due to its global averaging, macro F1 treats all labels equally regardless of their sample size and is therefore a good indicator of model performance on small labels.
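For multilabel outputs encoded as binary indicator matrices, these metrics can be computed with scikit-learn as in the toy sketch below; the numbers are purely illustrative.

```python
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score

# Toy multilabel predictions for 4 examples and 3 labels (illustrative only).
y_true  = np.array([[1, 0, 0], [1, 1, 0], [0, 0, 1], [1, 0, 0]])
y_score = np.array([[0.9, 0.2, 0.1], [0.8, 0.7, 0.3], [0.2, 0.1, 0.6], [0.4, 0.3, 0.2]])
y_pred  = (y_score >= 0.5).astype(int)

print("micro-F1 :", f1_score(y_true, y_pred, average="micro"))
print("macro-F1 :", f1_score(y_true, y_pred, average="macro"))
print("micro-AUC:", roc_auc_score(y_true, y_score, average="micro"))
print("macro-AUC:", roc_auc_score(y_true, y_score, average="macro"))
```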

5.1.5 Parameter sensitivity

To better understand how \(\text {GAML}\) works on graph-structured input, we investigate the contribution of different hyper-parameters, including the number of message passing layers (Figs. 4, 5), the number of attention factors (Fig. 6), and the type of attention (Fig. 7). We report results for 50proteins; similar results are observed for 9cancers.

Fig. 4
figure 4

Learning curves on 50proteins with different number of message passing layers \(n\in \left\{ 2,4,6,8,10\right\} \). Best viewed in color

Fig. 5
figure 5

Micro AUC (a) and micro F1 (b) on 50proteins with different number of message passing layers \(n\in \left\{ 2,4,6,8,10\right\} \). Best viewed in color

Fig. 6
figure 6

Micro AUC (a) and micro F1 (b) on 50proteins with different number of factors \(k\in \left\{ 1,5,10,15,20\right\} \). Best viewed in color

Fig. 7
figure 7

Results on 50proteins with different types of attention. Label-to-Input refers to unidirectional attention from label to input nodes; Input-to-Label refers to attention in the reverse direction; Both refers to bidirectional attention. Best viewed in color

Figure 5 shows that when the number of layers n is small, e.g. \(n=2\), the model performs sub-optimally. Increasing the number of layers usually improves the results. We hypothesize that at higher levels, input nodes receive a wider range of structural information through message passing. However, when \(n\ge 6\), the improvement rate levels off and the model is more likely to overfit (see Fig. 5c). We believe there are two reasons for this: (i) the structure information from distant nodes is much less important than that from close neighbors; and (ii) the structure information at every node becomes more global and indistinguishable, making it difficult for the model to detect meaningful substructures during prediction.

Another factor that affects the model performance is the type of attention. Generally, using attention provides a better micro F1 score than not using it. However, the input-to-label attention seems redundant and can mislead the model: when it is enabled, the model often has higher loss and lower micro AUC (see Fig. 7a, c). Meanwhile, the label-to-input attention is important as it helps the label nodes focus on particular substructures of the input graph to give accurate predictions. Interestingly, the improvement in micro F1 from using attention mainly comes from micro Recall (see Fig. 7b, d, e); since the denominator of micro Recall is constant (equal to the number of positive examples in the dataset), the number of true positives actually increases.

\(\text {GAML}\) performs worst in terms of both micro AUC and micro F1 when the number of attention factors k is 1, which is equivalent to collapsing all the attended nodes into one aggregating vector. For other values from 10 to 20, the results are quite comparable, which suggests that a small value of k is usually sufficient.

5.1.6 Performance results

Table 4 shows the classification results for graph-structured input. \(\text {GAML}\) consistently beats all baselines on all evaluation metrics. In particular, our model achieves about 2–3% higher F1 and about 0.25–1% higher AUC than the second best method (CLN) on both datasets. We believe this improvement comes from the fact that our model can associate labels with useful multi-resolution substructures of the input graph through the attention mechanism, a capability that CLN lacks. Furthermore, models that learn directly on graphs such as WL+SVM+BR or CLN usually provide better results than those learning on strings or vectors. For example, CLN achieves roughly 2% improvement in terms of micro and macro F1 compared to its vector counterpart fp \(+\) HWN, while WL+SVM+BR produces about 2–4% higher macro and micro AUC than fp+SVM+BR.

Table 4 The performance in the multi-label classification with graph-structured input (m-X: micro average of X; M-X: macro average)

5.1.7 External knowledge of label dependencies

A priori label dependencies are known to improve model performance as they bring structural constraints to the output space (McCallum and Pereira 2001; Tsochantaridis et al. 2004). We consider the setting where label dependencies form a graph; multilabeling then becomes node classification in the label graph conditioned on the input graph. We investigate the case of 50proteins where the labels are sparse. We compute protein-protein interaction (PPI) scores using the Human Integrated Protein-Protein Interaction rEference (HIPPIE) (Alanis-Lobato et al. 2016). HIPPIE provides a normalized scoring scheme that integrates multiple PPI sources (Schaefer et al. 2012) and is therefore reliable. The PPI scores are already normalized to the range [0, 1]. We add an edge between two proteins if their interaction score is larger than a predefined threshold (set to 0.5 in our experiments). Since the interaction scores are asymmetric, the edges are directed. Table 5 reports the results of our model when external label dependencies are introduced. The results improve on the F1 measures but not on the AUC scores, suggesting that the external label constraints may help balance recall and precision when labels are sparse.
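Constructing the directed label graph from the interaction scores amounts to a simple thresholding step; a minimal sketch with a toy score matrix (not real HIPPIE values) is shown below.

```python
import numpy as np

def ppi_edges(scores, threshold=0.5):
    """Directed label-graph edges (p -> q) for protein pairs whose interaction
    score exceeds the threshold; scores is a C x C matrix in [0, 1]."""
    src, dst = np.where(scores > threshold)
    return [(p, q) for p, q in zip(src, dst) if p != q]

# Toy asymmetric score matrix for 3 proteins (illustrative only).
scores = np.array([[0.0, 0.8, 0.1],
                   [0.3, 0.0, 0.9],
                   [0.7, 0.2, 0.0]])
print(ppi_edges(scores))   # [(0, 1), (1, 2), (2, 0)]
```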

Table 5 Results on incorporating external knowledge of label dependencies

5.1.8 Attention visualization

Figure 8 shows the label-to-input attention scores at different message passing layers when our model runs on 9cancers, illustrating how the model matches labels to substructures of the input graph. At the first layer, the label nodes often attend to many input nodes, since input nodes at this level only contain information about their types. In addition, the attended input nodes are usually special atoms like Oxygen (8) or Nitrogen (7) rather than the common Carbon (6). However, the attention becomes more focused at higher layers since the structure information at each input node has been updated via message passing. Sometimes, new substructures emerge and the model may switch its attention to them if it finds them more appropriate.

Fig. 8
figure 8

Normalized label-to-input attention probability at 6 layers of \(\text {GAML}\) over 4 different molecular graphs sampled from 9cancers. Darker color refers to higher probability. Columns correspond to input graphs and rows correspond to layers, with the first layer drawn on top, then the second layer and so on. Each tick on the x-axis is labeled with the atomic number of the corresponding node in the input graph (6: Carbon, 7: Nitrogen, 8: Oxygen, 17: Chlorine). Best viewed in color (Color figure online)

From the label-to-input attention matrices in Fig. 8, we can map back to the molecular graph to detect meaningful substructures for the labels. In Fig. 9, we can observe the shift in the model's attention as structures evolve across layers. In particular, at layer 2, the model focuses most on the O-N substructure. At layer 3, however, the model shifts its attention to N[6], N[5] and C[11] instead of the O atoms. The reason is that the model becomes more interested in the appearance of two adjacent N atoms in an aromatic group, which cannot be captured within two hops starting from O. The model therefore performs an attention shift.

Note that although the attention shift looks disruptive in Fig. 8 because the model is highly attentive (due to thorough training), it is actually smooth when viewed graphically in Fig. 9, since the substructures rooted at N[6], N[5] and C[11] all contain the substructure O-N from the previous layer. This suggests a human-like concept transferring mechanism through attention, where the old concept is not totally discarded but persists as part of the new concept with less focus. From layer 3 to layer 6, the model performs one more small attention shift (from N[6] to N[8]). We hypothesize the model does this to keep itself attending to the left ring only (instead of both the left and the right rings). This is reasonable because when N[8] receives more redundant information about the right ring, its attention score drops from 0.48 (the 5th row) to 0.27 (the 6th row). The strong focus of the model on a particular substructure is also well demonstrated in Fig. 9. As we can see in the last row, although C[10] (at the last column) contains information about the whole molecular graph, its attention score is still significantly smaller than that of the substructure rooted at N[8].

Fig. 9
figure 9

Attention visualization on substructures of a molecule with PubChem SID 491286. This molecule is the second example in Fig. 8. Each row shows the top 3 substructures with the highest attention scores (sorted in descending order from left to right) at the corresponding layer. For each substructure at layer k, the root atom as well as its neighbor atoms and bonds up to k hops are highlighted in green. Each atom is displayed with its atomic number and its index (in square brackets) in the molecule. Best viewed in color (Color figure online)

To discover typical rooted substructures at a particular depth for a group of classes, for every molecular graph in the training data we select the node with the best attention score averaged over the classes present. Then, we perform clustering on the representation vectors of these nodes to find similar substructures. Figure 10 shows an example of such common substructures shared by different molecules that are typical of all classes in 9cancers.
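A minimal sketch of this substructure-discovery step is given below; the choice of k-means and all sizes are assumptions for illustration, since the text does not fix a particular clustering algorithm, and `node_vecs` is a hypothetical stand-in for the selected node states.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical stand-in: one row per training molecule, holding the state vector of
# the node with the best class-averaged attention score at a given depth.
rng = np.random.default_rng(0)
node_vecs = rng.normal(size=(200, 50))

clusters = KMeans(n_clusters=10, n_init=10, random_state=0).fit_predict(node_vecs)
# Molecules whose selected nodes fall in the same cluster share similar rooted
# substructures; the largest cluster gives candidates like those in Fig. 10.
largest = np.bincount(clusters).argmax()
print("members of the largest cluster:", np.where(clusters == largest)[0][:10])
```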

Fig. 10
figure 10

Common substructures shared by some molecules that are typical of all classes in 9cancers. Pictures from left to right show the evolution of the rooted substructures with depth from 2 to 5. Along the rows, the molecules are sorted by the average attention scores computed at the topmost layer. Best viewed in color (Color figure online)

5.2 Multilabel classification with unstructured input

We now test whether our proposed method can work on the traditional setting where the input is a vector.

5.2.1 Datasets

Four datasets are used in these experiments: media_mill, bookmarks, Corel5k and NUS-WIDE (see Table 6 for statistics). The former two belong to the text categorization domain, where each instance is a document represented as a binary bag-of-words vector. The latter two belong to the image classification domain, where each image is represented as a real-valued feature vector. For all datasets, we follow the predefined train/test split so that our results are comparable with others.

Table 6 Statistics of all multilabel datasets with unstructured input

5.2.2 Baselines

For comparison, we consider the following methods:

  • State-of-the-art classical methods for multilabel classification (MLC) evaluated in Madjarov et al. (2012), which are representative of broader classes of algorithms: RAkEL (Tsoumakas and Vlahavas 2007) for ensemble methods, ML-kNN (Zhang and Zhou 2007) for algorithm adaptation methods, HOMER (Tsoumakas et al. 2008) for label power set methods and Calibrated Label Ranking (Fürnkranz et al. 2008) for pairwise ranking. Most of these methods are implemented in well-known multilabel machine learning systems, such as Tsoumakas et al. (2011) and Read et al. (2016), with careful hyper-parameter tuning by the authors. Thus, their results are strong and reliable. For presentation compactness, we only report the best results in Madjarov et al. (2012).

  • Collective multilabel classification with CRF (CML) (Ghamrawi and McCallum 2005). This model can learn pairwise relations among labels via a CRF and is therefore a natural baseline for comparison. We use the Java implementation of CML released on GithubFootnote 4 and search for the optimal value of “train.gaussianVariance” in \(\{0.01,0.03,0.1,0.3,1,3,10\}\). However, we can only test this model on media_mill and NUS-WIDE since the other two datasets are not accepted by the implementation.

  • A Highway Network (HWN) (Srivastava et al. 2015), similar to the one described in Sect. 5.1.

5.2.3 Model setting

The label embedding size is set to 50 for NUS-WIDE and media_mill, 75 for bookmarks and 30 for Corel5k. We project the input vector to a low-dimensional space using a single-layer neural network with ReLU activation before feeding it to GAML. The size of the projected vector is 55 for NUS-WIDE, 75 for media_mill, 110 for bookmarks and 50 for Corel5k. For all datasets, the number of message passing layers is set to 6. In training, the batch size is 500 for NUS-WIDE and 100 for the other datasets. We use k-fold cross validation with k set to 9 for Corel5k and 5 for the other datasets. The optimizer is Adam with an initial learning rate of 0.001. We halve the learning rate if the validation loss does not decrease after 5 consecutive epochs for Corel5k and 20 epochs for the other datasets. The maximum number of epochs is 300 and training stops early after the learning rate has been decayed 4 times.

5.2.4 Results

The classification performance of all the methods is presented in Table 7. The deep networks (HWN and \(\text {GAML}\)) outperform the traditional methods and CML on most datasets, except for Corel5k, the smallest dataset. In particular, on bookmarks, our model improves the micro F1 and macro F1 over the best traditional methods by about 7% and 10%, respectively. In the case of Corel5k, the best traditional method is CLR (see Madjarov et al. 2012), a ranking-based method that uses SVM as a base classifier; SVM appears to be more robust than deep networks on small datasets like Corel5k. Compared to HWN, \(\text {GAML}\) achieves better results on all datasets. This supports our model's strength in learning relations between the labels and the input at multiple levels of abstraction.

Table 7 The performance in the multi-label classification with unstructured input (m-X: micro average of X, M-X: macro average of X)

5.3 Remark on running time

Kernel methods such as SVM and WL \(+\) SVM do not scale well with data size due to the quadratic storage and cubic running time required to invert the kernel matrix. Training these models may take tens of hours on a single CPU for a moderate dataset of 30K examples. Deep learning methods such as HWN, GRU, CLN and \(\text {GAML}\) do not suffer from this limitation and are trained on GPUs, hence run much faster. HWN runs significantly faster than \(\text {GAML}\) since it only deals with flat vectors instead of structured graphs. CLN is slightly faster than \(\text {GAML}\) (by about 10%) since it has similar message passing complexity.

6 Discussion

We introduced \(\text {GAML}\), a new graph neural network to tackle an open problem of multi-label learning over graph structured data. The key insight is to realize that label nodes and input nodes can be put into a joint graph to model the multi-way relations among labels and subgraphs. This is achieved through a message passing scheme that exchanges information between connected nodes across multiple steps and an attention mechanism that enables selective flowing of information between label nodes and input nodes. Our model is highly flexible and scalable. We evaluated \(\text {GAML}\) using an extensive set of experiments on both graph structured and unstructured inputs. Our results clearly demonstrate the efficacy of the proposed model.

This work opens up ample room for future research on both the applied and theoretical fronts. \(\text {GAML}\) is directly applicable to many other domains. One example is shopping basket recommendation, where users play the role of labels (with or without profiles) and the item basket is modeled as an input graph of items. Alternatively, item recommendation to user groups works in a similar way: the user group forms a social graph and items play the role of labels. On the modeling front, a next step is to extend \(\text {GAML}\) from label node classification to full graph prediction, where edges are also predicted. Additionally, the current setting is open to auxiliary tasks, e.g., when the input graph is node-labeled.