Graph Neural Networks Part 2. Graph Attention Networks vs. GCNs

Author: Murphy  |  View: 22117  |  Time: 2025-03-22 20:09:46

Welcome to the second post about GNN architectures! In the previous post, we saw a staggering improvement in accuracy on the Cora dataset by incorporating the graph structure in the model using a Graph Convolutional Network (GCN). This post explains Graph Attention Networks (GATs), another fundamental architecture of graph neural networks. Can we improve the accuracy even further with a GAT?

First, let's talk about the difference between GATs and GCNs. Then let's train a GAT and compare the accuracy with the GCN and basic neural network.

This blog post is part of a series. Are you new to GNNs? I recommend starting with the first post, which explains graphs, neural networks, the dataset, and GCNs.


Graph Attention Networks

In my previous post, we saw a GCN in action. Let's take it a step further and look at Graph Attention Networks (GATs). As you might remember, GCNs treat all neighbors equally. For GATs, this is different. GATs allow the model to learn different importance (attention) scores for different neighbors. They aggregate neighbor information by using attention mechanisms (this might ring a bell because these mechanisms are also used in transformers).

How does this work? In the GCN, the weight of each neighbor depended only on the degrees of the nodes. GATs, on the other hand, also take the feature values into account to assign attention scores to different neighbors.
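For comparison, in the standard GCN layer the weight that neighbor j contributes to node i is fixed by the graph structure alone. Writing d_i for the degree of node i including its self-loop (notation assumed here, not taken from this post), that fixed coefficient is:

```latex
c_{ij} = \frac{1}{\sqrt{d_i \, d_j}}
```

Nothing in this expression depends on the node features, which is exactly what the attention mechanism changes.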

So instead of treating all neighbors equally, an attention mechanism is introduced that assigns varying levels of importance to different neighbors. This allows the network to focus on the most relevant parts of the graph structure, essentially learning "where to look" when making predictions.

So, how exactly does the attention mechanism work in GATs? Let's break it down.

Step 1: Computing Attention Scores

For each node, we calculate an attention score for every neighboring node. This score is a measure of how important a specific neighbor's features are when updating the current node's features. The score is learned during training, so the model decides which nodes matter most for each task.

There are multiple ways of computing attention scores in GATs. In this post, I explain the second version (GATv2) rather than the original GAT, because it tends to be more effective than the original.

Mathematically, given a node i and its neighbor j, the attention coefficient is computed as follows:

Feature Transformation: We start with two feature vectors of nodes i and j, and the first step is to apply a shared weight matrix W to the features:
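Writing h_i and h_j for the input feature vectors of the two nodes and z_i, z_j for the transformed ones (symbols assumed here for illustration), this step can be written as:

```latex
z_i = W h_i, \qquad z_j = W h_j
```

Note that PyTorch Geometric's GATv2Conv uses separate weight matrices for the two nodes by default; passing share_weights=True gives the shared W described here.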

Next, the transformed features are summed (in the original GAT version the features were concatenated):
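With h_i and h_j denoting the input features of nodes i and j, and s_ij a name assumed here for the summed vector, the sum reads:

```latex
s_{ij} = W h_i + W h_j
```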

Score Calculation: Now we can calculate the raw (unnormalized) attention score, using a LeakyReLU function:
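In the GATv2 formulation, the LeakyReLU is applied to the summed features and the result is projected to a single number with a learnable attention vector a. Using e_ij for the raw score and h_i, h_j for the input features (notation assumed here), a sketch of the formula is:

```latex
e_{ij} = a^{\top} \, \mathrm{LeakyReLU}\left( W h_i + W h_j \right)
```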

Step 2: Normalizing Attention Scores

The raw attention scores from the previous step are normalized across all neighbors of node i using the softmax function. This ensures that the coefficients are easy to interpret (as they sum to 1 for each node):
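With e_ij for the raw score between node i and neighbor j, and N(i) for the set of neighbors of node i (notation assumed here), the softmax normalization is:

```latex
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}
```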

The normalized attention coefficients determine how much weight each neighbor j contributes to the new feature representation of node i.

Step 3: Aggregating Neighbor Information

Finally, node i's new feature representation is computed as a weighted sum of its neighbors' transformed features, where the weights are given by the attention coefficients:
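Using α_ij for the normalized attention coefficient, W h_j for the transformed features of neighbor j, and N(i) for the neighbors of node i (notation assumed here), the update can be written as:

```latex
h_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} \, W h_j
```

A nonlinearity such as ELU is usually applied to the result afterwards.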

We now have our attention scores and updated feature representations! Let's continue with another important aspect of GATs: multi-head attention.
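Before moving on, the three steps above can be sketched in plain Python for a single attention head. This is a minimal illustration, not the PyTorch Geometric implementation: the function names, the dict-based graph representation, and the toy values are all assumptions made for this example.

```python
import math

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def leaky_relu(x, negative_slope=0.2):
    return x if x > 0 else negative_slope * x

def gat_v2_update(h, neighbors, i, W, a):
    """Single-head, GATv2-style update for node i (illustrative sketch).

    h: dict mapping node id -> feature vector
    neighbors: dict mapping node id -> list of neighbor ids
    W: shared weight matrix, a: attention vector
    Returns (attention coefficients, new features of node i).
    """
    # Feature transformation with the shared matrix W
    z = {n: matvec(W, h[n]) for n in [i] + neighbors[i]}

    # Step 1: raw scores, LeakyReLU on the summed features, then projection with a
    e = {}
    for j in neighbors[i]:
        summed = [zi + zj for zi, zj in zip(z[i], z[j])]
        e[j] = sum(ak * leaky_relu(s) for ak, s in zip(a, summed))

    # Step 2: softmax over the neighbors of i (shifted by the max for stability)
    m = max(e.values())
    exp_e = {j: math.exp(score - m) for j, score in e.items()}
    total = sum(exp_e.values())
    alpha = {j: value / total for j, value in exp_e.items()}

    # Step 3: attention-weighted sum of the transformed neighbor features
    dim = len(z[i])
    new_h = [sum(alpha[j] * z[j][k] for j in neighbors[i]) for k in range(dim)]
    return alpha, new_h

# Toy example (made-up values): node 0 with neighbors 1 and 2
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [1, 2]}
W = [[1.0, 0.0], [0.0, 1.0]]  # identity, so the transformed features equal the inputs
a = [1.0, 1.0]
alpha, new_h = gat_v2_update(h, neighbors, 0, W, a)
print(alpha, new_h)
```

In this toy setup neighbor 2 gets the larger raw score (3 versus 2), so it also gets the larger attention coefficient, and the coefficients sum to 1.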

Multi-Head Attention

Just like transformers, GATs often use multi-head attention to improve their performance. But what does multi-head attention mean, and why is it so beneficial?

Multi-head attention refers to running several separate attention mechanisms, or heads, in parallel. Each of these heads independently computes attention scores for the neighbors of a node, learning to focus on different aspects of the graph structure or node features. After these heads process the graph, their outputs are either concatenated or averaged to form the final node representation.

So one of the key reasons for using multiple heads instead of one is to learn diverse patterns, because each attention head has its own learnable parameters and can learn to focus on different parts of the neighborhood. Another reason is that it stabilizes the training process. You can compare it with an ensemble: other heads can compensate for a "noisy head".

A center node with 6 neighbors. Two different attention heads are represented by the blue and green arrows. The thickness of the arrows represents the varying levels of importance (the attention scores). Image by author.

How is multi-head attention implemented in GATs? The first step is that each attention head computes its own set of attention scores and new node features independently. For N heads, and a given node i, we'll end up with N different sets of transformed features. Next up, all outputs are concatenated (stacked) or averaged. Concatenation is more common because it increases the model's expressiveness, but on the other hand the output dimension will be larger. Averaging helps to smooth out the differences between the heads. A general rule is to use concatenation when it's a hidden layer in the network and averaging when it's the last layer. When all attention heads are combined, we hope to get a comprehensive view of the graph, because the different heads have different perspectives on the relationships in the graph.
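The difference between concatenating and averaging is easiest to see on the output shape. The sketch below combines the per-head outputs for a single node; the helper name combine_heads and the toy numbers are assumptions for illustration only.

```python
def combine_heads(head_outputs, concat=True):
    """Combine per-head feature vectors for one node (illustrative helper).

    head_outputs: list of equally sized feature vectors, one per head.
    concat=True stacks them end to end (typical for hidden layers);
    concat=False averages them element-wise (typical for the last layer).
    """
    if concat:
        return [value for head in head_outputs for value in head]
    n_heads = len(head_outputs)
    return [sum(values) / n_heads for values in zip(*head_outputs)]

# Three heads, each producing two features for the same node (made-up values)
heads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(combine_heads(heads, concat=True))   # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(combine_heads(heads, concat=False))  # [3.0, 4.0]
```

With concatenation the output size grows to the number of heads times the per-head size, which is why a layer that follows a concatenating GAT layer needs hidden_dim * heads input features.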

Multiple heads according to Dall·E.

PyTorch Implementation

Let's implement a GAT in Python, and train it on the Cora dataset. You can use the same setup as in the previous post.

import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv  # use GATConv for the first GAT version

class GAT(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, heads=8):
        super().__init__()
        # hidden layer: head outputs are concatenated (concat=True is the default),
        # so each node ends up with hidden_dim * heads features
        self.gat1 = GATv2Conv(input_dim, hidden_dim, heads=heads)
        # for the last GAT layer we use concat=False to average the outputs of the heads
        self.gat2 = GATv2Conv(hidden_dim * heads, output_dim, heads=heads, concat=False)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # dropout is only active in training mode
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.gat2(x, edge_index)
        # log-probabilities per class
        return F.log_softmax(x, dim=1)

Looking back, the results for the MLP and GCN were as follows:

MLP Test Accuracy: 54.35 ± 1.06
GCN Test Accuracy: 78.76 ± 0.38

Will we improve this with the GAT? Let's run the code (same as previous post):

for model_class in [MLP, GCN, GAT]:
    results[model_class.__name__] = []
    # same training loop as before...

The model is training… The GAT model takes a bit longer than the GCN and MLP…

And here is the result:

GAT Test Accuracy: 78.45 ± 1.11

The GAT performance is comparable to the GCN: the difference between the two means is well within the reported ± ranges. This can happen, and it looks like the choice between the two models matters little on the Cora dataset with these settings. But we didn't tune either model, so the GAT may still come out ahead after tuning.
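As a quick sanity check on "comparable", the reported numbers can be compared directly. The snippet below is a rough heuristic rather than a statistical test, and it assumes the ± values are standard deviations over repeated runs, which the results above do not state explicitly. The helper name within_noise is made up for this example.

```python
def within_noise(mean_a, std_a, mean_b, std_b):
    """Rough check: is the gap between two means smaller than the larger std?"""
    return abs(mean_a - mean_b) < max(std_a, std_b)

# (mean, std) pairs copied from the results in this post
mlp = (54.35, 1.06)
gcn = (78.76, 0.38)
gat = (78.45, 1.11)

print(within_noise(*gcn, *gat))  # True: the 0.31 gap is smaller than 1.11
print(within_noise(*gcn, *mlp))  # False: the 24.41 gap is far larger than 1.06
```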

According to the original GAT paper (version 1), GATs outperform GCNs on the benchmark datasets.

Considerations for GATs

While GATs have shown great promise in improving accuracy (not in this post, but it's better to trust the paper here), there are a few things to keep in mind:

  • The attention mechanism in GATs adds additional complexity to the model, both in terms of computation and the number of parameters. This makes GATs more resource-intensive and slower to train than GCNs.
  • Multi-head attention helps stabilize training, but there is still a risk of overfitting, especially when using many attention heads or deep GAT architectures. Using techniques like dropout and early stopping can help to mitigate this.
  • One advantage of GATs is that they provide interpretability through attention scores. These scores can be analyzed to understand which nodes are most influential in making predictions, offering insights into the graph structure.
  • Another point I didn't address in the previous post is how to finetune GNNs. Many steps in finetuning GNNs are similar to those for traditional neural networks: testing different values for the hyperparameters and preventing overfitting with early stopping. For example, with GATs you need to tune the number of attention heads. Small changes to node and edge features can have an impact on GNN performance, so it might help to experiment with different feature combinations or to create new features. Augmenting data can improve generalization. You can do this by adding noise to edges, randomly dropping nodes, or by performing subgraph sampling.
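Early stopping, mentioned in the list above, can be sketched with a small framework-independent helper. The class name, its parameters, and the validation losses below are all hypothetical and not taken from this series' training loop.

```python
class EarlyStopping:
    """Stop training when the validation loss stops improving (illustrative sketch)."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record a validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Usage with made-up validation losses, one per epoch
stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.85, 0.7]):
    if stopper.step(val_loss):
        print(f"stopping at epoch {epoch}")  # stopping at epoch 3
        break
```

In a real training loop you would typically also keep a copy of the model weights from the best epoch and restore them after stopping.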

Conclusion

GATs take GCNs a step further by introducing attention mechanisms that assign different levels of importance to each node's neighbors. This added flexibility allows GATs to achieve better performance in many cases compared to GCNs. However, this comes at the cost of increased computational complexity and the need to tune extra hyperparameters like the number of attention heads.

GATs and GCNs represent just two foundational architectures of GNNs. Each has its strengths and trade-offs, and the choice of which to use depends on the dataset and prediction task. For many tasks, GATs can offer a performance boost, especially when the relationships between nodes are not uniformly important.

Are you curious about other architectures? I'm not sure yet how I will follow up on these two blog posts on GNNs. For those who can't wait, here are some other interesting architectures and papers to investigate:

  • If you are looking for an overview of GNN models and a general design pipeline, this paper is a good place to start.
  • Relational Graph Convolutional Networks (R-GCNs) are an extension of GCNs. R-GCNs are specifically designed for situations where edges can have different types or relations. R-GCNs use relation-specific weights to handle different types of edges and their unique relationships.
  • GraphSAGE samples a fixed number of neighbors for each node, instead of aggregating features from all neighbors (like GCNs and GATs do). It's an interesting architecture for efficient, large-scale graph representation learning.
  • SEAL is specifically designed for link prediction. It extracts a local subgraph around each target link and shows great performance. Here you can find the GitHub repo.
  • Is it possible to generate realistic graphs? GraphRNN uses auto-regressive models to model graph generation, where nodes and edges are treated as time-based events.
