Graph Neural Network
[1] A Gentle Introduction to Graph Neural Networks
[2] Graph neural networks: A review of methods and applications
Graph Neural Networks (GNNs) are a class of deep learning models designed to process graph-structured data. Unlike traditional neural networks (e.g., CNNs for grids, RNNs for sequences), GNNs explicitly model relationships between entities by propagating information through nodes and edges. They have become essential tools for tasks involving relational or topological data.
Graph
Graph Data: A graph consists of:
- Nodes/Vertices (V): Represent entities (e.g., users, molecules, or cities).
- Edges (E): Represent relationships or interactions between nodes.
- Graph (G): The entire graph structure.
Node/Edge/Graph Features: Optional attributes associated with nodes, edges, and the graph (global feature or master node).
We can additionally specialize graphs by associating a direction with each edge (directed vs. undirected graphs).
data:image/s3,"s3://crabby-images/796b6/796b69b532575390d3bf0821115fc83009d85cef" alt="Graph representation"
We can use an adjacency matrix to describe a graph. However, it becomes inefficient when the graph is large: for $n$ nodes, the matrix has $n \times n$ entries, most of which are zero in a sparse graph. An adjacency list is often a better option, because it stores only the existing edges and so avoids computation and storage on pairs of nodes that are not connected.
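To make the storage difference concrete, here is a minimal plain-Python sketch. The 4-node graph and its edge list are illustrative values chosen for this example, not data from the text.

```python
# Illustrative undirected graph with 4 nodes, given as an edge list.
edges = [(0, 1), (1, 2), (2, 3)]
n = 4

# Dense adjacency matrix: always n * n entries, mostly zeros for sparse graphs.
adj_matrix = [[0] * n for _ in range(n)]
for u, v in edges:
    adj_matrix[u][v] = 1
    adj_matrix[v][u] = 1  # undirected: store both directions

# Adjacency list: each node keeps only the neighbors it is actually connected to.
adj_list = {i: [] for i in range(n)}
for u, v in edges:
    adj_list[u].append(v)
    adj_list[v].append(u)

print(sum(len(row) for row in adj_matrix))           # entries stored by the matrix
print(sum(len(nbrs) for nbrs in adj_list.values()))  # entries stored by the list
```

The matrix stores 16 entries regardless of how many edges exist, while the list stores 6 (two per undirected edge), and the gap grows quickly as $n$ increases.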
data:image/s3,"s3://crabby-images/b9cc2/b9cc27cf5ced37ace2e04abe4c8579264463f6fb" alt="Graph attributes representation"
Each node and edge can carry its own feature vector.
The key property of a graph is that it has no canonical ordering of its nodes, so functions on graphs should be permutation invariant. Permutation invariance means that the output of a function remains unchanged when the order of its inputs is rearranged. For example, summation is permutation invariant because the order of the elements does not affect the result (e.g., $1 + 2 + 3 = 3 + 1 + 2$). This requires a GNN to be an optimizable transformation on all attributes of the graph (nodes, edges, global context) that preserves these graph symmetries (permutation invariances).
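A quick numerical check of this idea, as a minimal NumPy sketch: the random features, the random adjacency matrix, and the particular node relabeling `perm` are illustrative assumptions. Relabeling the nodes corresponds to applying a permutation matrix $P$ to the features ($PX$) and to both sides of the adjacency matrix ($PAP^\top$).

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
X = rng.normal(size=(n, 3))                                 # one feature row per node
A = np.triu((rng.random((n, n)) < 0.4).astype(float), k=1)
A = A + A.T                                                 # random undirected graph, no self-loops

perm = np.array([4, 2, 0, 3, 1])  # the same nodes, listed in a different order
P = np.eye(n)[perm]               # permutation matrix: (P @ X)[i] == X[perm[i]]
X_p, A_p = P @ X, P @ A @ P.T     # the same graph under the new node order

# Sum readout over all nodes: unchanged by the reordering.
sum_invariant = np.allclose(X.sum(axis=0), X_p.sum(axis=0))
# An order-dependent readout (flattening the rows) changes with the reordering.
flatten_invariant = np.allclose(X.reshape(-1), X_p.reshape(-1))
# Summing each node's neighbor features: the per-node results are simply reordered,
# so every node still receives the same value it had before.
neighbor_sum_consistent = np.allclose(A_p @ X_p, P @ (A @ X))

print(sum_invariant, flatten_invariant, neighbor_sum_consistent)  # True False True
```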
Task
GNN tasks fall into three levels: node-level, edge-level, and graph-level.
- Node-level prediction problems are analogous to image segmentation, where we are trying to label the role of each pixel in an image. With text, a similar task would be predicting the parts-of-speech of each word in a sentence (e.g., noun, verb, adverb, etc.).
- Edge-level problems predict whether two nodes share an edge (link prediction) or what the value of an existing edge is.
- Graph-level tasks are analogous to image classification problems with MNIST and CIFAR, where we want to associate a label to an entire image. With text, a similar problem is sentiment analysis, where we want to identify the mood or emotion of an entire sentence at once.
Graphs have up to four types of information that we may want to use to make predictions: nodes, edges, global context, and connectivity. After several rounds of message passing, we can apply a classification layer to the learned feature vectors (node, edge, or global) to make predictions.
data:image/s3,"s3://crabby-images/01601/01601736e53b88fdcbfcfaeba040d93b8633bc20" alt="An end-to-end prediction task with a GNN model"
Message Passing
Neighboring nodes or edges exchange information and influence each other’s updated embeddings.
Message passing works in three steps:
- For each node in the graph, gather all the neighboring node embeddings (or messages).
- Aggregate all messages with an aggregation function (such as sum).
- Pass the pooled messages through an update function, usually a learned neural network.
This is the node-node message passing process.
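The three steps can be traced in a minimal NumPy sketch. Using separate weight matrices `W_self` and `W_neigh` followed by ReLU as the update function is an assumption of this sketch (one common design, not the only one); the tiny graph and the identity weights are illustrative.

```python
import numpy as np

def message_passing_step(adj_list, H, W_self, W_neigh):
    """One round of node-to-node message passing with sum aggregation.

    adj_list[v] lists the neighbors of node v; row v of H is node v's embedding.
    """
    H_new = np.zeros((H.shape[0], W_self.shape[1]))
    for v, neighbors in enumerate(adj_list):
        # Step 1: gather the neighboring embeddings (the messages).
        messages = H[neighbors] if neighbors else np.zeros((0, H.shape[1]))
        # Step 2: aggregate the messages with a sum (the order does not matter).
        m_v = messages.sum(axis=0)
        # Step 3: update with a learned function, here two linear maps + ReLU.
        H_new[v] = np.maximum(H[v] @ W_self + m_v @ W_neigh, 0.0)
    return H_new


adj_list = [[1, 2], [0], [0]]  # node 0 is linked to nodes 1 and 2
H = np.eye(3)                  # one-hot embeddings make the messages easy to read
W_self = W_neigh = np.eye(3)   # identity "weights" for illustration only
print(message_passing_step(adj_list, H, W_self, W_neigh))
```

Because node 0 receives messages from both neighbors, its new embedding is $[1, 1, 1]$, while nodes 1 and 2 each mix in only node 0's embedding.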
data:image/s3,"s3://crabby-images/a61f5/a61f5efdfda12f4f92248dbb8ac22138d179ae94" alt="Node-node message passing"
data:image/s3,"s3://crabby-images/0dabd/0dabd32ddbb619d9c99ef313b77096863a88014d" alt="Node-node message passing"
In the figure, $\rho$ denotes the pooling operation (steps 1 and 2 of message passing).
Information can also be passed from nodes to edges, from nodes to the global context, and from edges to the global context, as well as in the reverse directions.
data:image/s3,"s3://crabby-images/facda/facda3f337fb7ebf3f90ff70d9285732c748cb1b" alt="Node-edge-node"
Which graph attributes we update, and in which order, is one design decision when constructing GNNs. We could update node embeddings before edge embeddings or the other way around. This is an open area of research with a variety of solutions. For example, we could update in a ‘weave’ fashion, where four updated representations are combined into new node and edge representations: node to node (linear), edge to edge (linear), node to edge (edge layer), and edge to node (node layer).
data:image/s3,"s3://crabby-images/b3a79/b3a79ed3debaf9732ec0215702b78a4f45829d04" alt="Weave node-edge-node, edge-node-edge"
When we want to add a global representation (graph feature), one solution is to add a master node or context vector: a virtual node that does not correspond to any real entity. This global context vector is connected to all other nodes and edges in the network and can act as a bridge between them to pass information, building up a representation of the graph as a whole. This creates a richer and more complex representation of the graph than could otherwise have been learned.
So, we have node, edge, and global representations. We can choose which of them to aggregate and update in the iterations.
data:image/s3,"s3://crabby-images/63eb4/63eb4eb9a8d7c2dd90b6471916417091acefa444" alt="Conditioning information"
Key GNN Architectures
GCN
A Graph Convolutional Network (GCN) is a neural network architecture designed to process graph-structured data. It extends convolutional operations to irregular graphs by aggregating features from a node's local neighborhood. GCNs are widely used for tasks like node classification, link prediction, and graph classification.
Key Components:
Graph Structure:
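- Adjacency matrix $A \in \mathbb{R}^{N \times N}$, where $N$ is the number of nodes.
- Degree matrix $D$, where $D_{ii} = \sum_j A_{ij}$.
- Node feature matrix $X \in \mathbb{R}^{N \times F}$, where $F$ is the feature dimension.
Self-Loops: Add self-connections to $A$ to include a node’s own features:
$$\tilde{A} = A + I_N$$
Normalized Adjacency Matrix:
$$\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$$
where $\tilde{D}$ is the degree matrix of $\tilde{A}$, i.e. $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$.
Layer-wise Propagation Rule
The output of the $l$-th GCN layer is computed as:
$$H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)}\right)$$
- $H^{(l)}$: Input features of the $l$-th layer ($H^{(0)} = X$).
- $W^{(l)}$: Trainable weight matrix.
- $\sigma$: Activation function (e.g., ReLU).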
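Reading the matrix rule one row at a time shows what a single node computes. Assuming an unweighted adjacency matrix $A$ without self-loops, writing $h_i^{(l)}$ for the $i$-th row of $H^{(l)}$, $\mathcal{N}(i)$ for the neighbors of node $i$, and $\tilde{d}_i$ for the degree of node $i$ once its self-loop is added, the rule expands to:

$$h_i^{(l+1)} = \sigma\left(\sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\tilde{d}_i\,\tilde{d}_j}}\, h_j^{(l)} W^{(l)}\right), \qquad \tilde{d}_i = 1 + \sum_j A_{ij}$$

Each contribution is scaled by the degrees of both endpoints, so a neighbor with many connections contributes less to any single node than a neighbor with few.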
Example
Step 1: Define Inputs
Adjacency Matrix (3 nodes):
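Node Features ($X$):
Step 2: Compute Normalized Adjacency Matrix
Degree Matrix $\tilde{D}$ (row sums of $\tilde{A} = A + I$):
Normalized $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ (values rounded for simplicity):
Step 3: Apply GCN Layer
Assume a weight matrix $W^{(0)}$ and ReLU activation:
After the matrix multiplications and ReLU, the output features are $H^{(1)} = \mathrm{ReLU}(\hat{A} X W^{(0)})$.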
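The same three steps can be traced with concrete numbers in a minimal NumPy sketch. The inputs are illustrative values chosen for this sketch: a 3-node path graph in which node 2 is linked to nodes 1 and 3, small integer features, and $W^{(0)} = I$ so the effect of the normalization is easy to see.

```python
import numpy as np

# Step 1: illustrative inputs (3-node path graph, F = 2 features per node).
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
X = np.array([[1, 0],
              [0, 1],
              [1, 1]], dtype=float)
W = np.eye(2)  # identity weights, chosen for readability

# Step 2: add self-loops and normalize symmetrically.
A_tilde = A + np.eye(A.shape[0])
deg = A_tilde.sum(axis=1)                  # degrees of A_tilde
D_inv_sqrt = np.diag(deg ** -0.5)          # D_tilde^{-1/2}
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt

# Step 3: one GCN layer with ReLU.
H1 = np.maximum(A_hat @ X @ W, 0.0)

print(np.round(A_hat, 4))
print(np.round(H1, 4))
```

With these inputs, $\tilde{D} = \mathrm{diag}(2, 3, 2)$, $\hat{A}_{12} = 1/\sqrt{6} \approx 0.4082$, and the second row of $H^{(1)}$ is $[2/\sqrt{6},\ 1/\sqrt{6} + 1/3] \approx [0.8165, 0.7416]$.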
GAT
A Graph Attention Network (GAT) is a neural network architecture designed to process graph-structured data, similar to GCNs. However, GAT introduces an attention mechanism that weighs the importance of neighboring nodes dynamically. This allows the model to focus on the most relevant neighbors during feature aggregation, making it more flexible than a GCN, whose neighbor weights are fixed by the graph structure.
Key Components:
Graph Structure:
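- Adjacency matrix $A \in \mathbb{R}^{N \times N}$, where $N$ is the number of nodes.
- Node feature matrix $X \in \mathbb{R}^{N \times F}$, where $F$ is the feature dimension.
Attention Mechanism: Computes attention coefficients between nodes to weigh their contributions during aggregation.
Attention Mechanism
For a node $i$ and its neighbor $j$, the attention coefficient $e_{ij}$ is computed as:
$$e_{ij} = \mathrm{LeakyReLU}\left(\mathbf{a}^\top \left[W h_i \,\|\, W h_j\right]\right)$$
- $h_i, h_j$: Feature vectors of nodes $i$ and $j$.
- $W$: Shared weight matrix for feature transformation.
- $\mathbf{a}$: Weight vector for the attention mechanism.
- $\|$: Concatenation operation.
- LeakyReLU: Nonlinear activation function.
Normalized Attention Coefficients
The attention coefficients are normalized using the softmax function:
$$\alpha_{ij} = \mathrm{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$$
- $\mathcal{N}(i)$: Set of neighbors of node $i$.
Feature Aggregation
The output feature for node $i$ is computed as a weighted sum of its neighbors' transformed features:
$$h_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j\right)$$
- $\sigma$: Nonlinear activation function (e.g., ReLU).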
Example
Step 1: Define Inputs
Adjacency Matrix (3 nodes):
Node Features ($X$):
Step 2: Compute Attention Coefficients
Assume a weight matrix $W$ and an attention vector $\mathbf{a}$:
For node 1 and its neighbors (nodes 2 and 3):
Normalize using softmax:
Step 3: Aggregate Features
Compute the output feature for node 1:
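These steps can be traced for node 1 in a minimal single-head NumPy sketch. Everything numeric here is an illustrative assumption: node 1 (index 0) is linked to nodes 2 and 3, the features are small integers, $W = I$, $\mathbf{a}$ is all ones, the LeakyReLU negative slope is 0.2, and $\sigma$ is ReLU. The neighborhood is read from the adjacency row, so it excludes node 1 itself unless a self-loop is added to `A`.

```python
import numpy as np

def leaky_relu(x, negative_slope=0.2):
    return np.where(x > 0, x, negative_slope * x)

def gat_node_update(i, A, X, W, a):
    """Single-head GAT update for node i (0-indexed), with ReLU as sigma."""
    Wh = X @ W.T                      # row j holds W h_j
    neighbors = np.flatnonzero(A[i])  # N(i), read from the adjacency row
    # e_ij = LeakyReLU(a^T [W h_i || W h_j]) for each neighbor j
    e = np.array([leaky_relu(a @ np.concatenate([Wh[i], Wh[j]])) for j in neighbors])
    # Softmax over the neighborhood (subtracting the max for numerical stability).
    alpha = np.exp(e - e.max())
    alpha /= alpha.sum()
    # h_i' = sigma(sum_j alpha_ij W h_j)
    h_new = np.maximum(alpha @ Wh[neighbors], 0.0)
    return neighbors, e, alpha, h_new


A = np.array([[0, 1, 1],
              [1, 0, 0],
              [1, 0, 0]])           # node 1 is linked to nodes 2 and 3
X = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
W = np.eye(2)
a = np.ones(4)

neighbors, e, alpha, h1_new = gat_node_update(0, A, X, W, a)
print("e:", e)
print("alpha:", np.round(alpha, 4))
print("h_1':", np.round(h1_new, 4))
```

With these inputs, $e_{12} = 2$ and $e_{13} = 3$, so $\alpha_{12} = 1/(1+e) \approx 0.2689$ and $\alpha_{13} \approx 0.7311$, giving $h_1' \approx [0.7311, 1.0]$: node 3 receives most of the attention because its transformed features score higher under $\mathbf{a}$.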
GraphSAGE
GraphSAGE is a framework for inductive representation learning on large graphs. Unlike GCNs, which require the entire graph during training, GraphSAGE generates node embeddings by sampling and aggregating features from a node's local neighborhood. This makes it scalable to large graphs and capable of generalizing to unseen nodes.
Key Components:
Graph Structure:
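- Adjacency matrix $A \in \mathbb{R}^{N \times N}$, where $N$ is the number of nodes.
- Node feature matrix $X \in \mathbb{R}^{N \times F}$, where $F$ is the feature dimension.
Neighborhood Sampling: For each node, a fixed-size subset of neighbors is sampled to reduce computational complexity.
Aggregation Functions: Combines features from a node's neighbors. Common choices include mean, LSTM, or pooling aggregators.
Layer-wise Propagation Rule
For each node $v$, the embedding at the $k$-th layer (starting from $h_v^{(0)} = x_v$) is computed as:
Aggregate Neighbor Features:
$$h_{\mathcal{N}(v)}^{(k)} = \mathrm{AGGREGATE}_k\left(\left\{h_u^{(k-1)} : u \in \mathcal{N}(v)\right\}\right)$$
- $\mathcal{N}(v)$: Sampled neighbors of node $v$.
- $\mathrm{AGGREGATE}_k$: Aggregation function (e.g., mean, LSTM, max-pooling).
Combine Features:
$$h_v^{(k)} = \sigma\left(W^{(k)} \cdot \left[h_v^{(k-1)} \,\|\, h_{\mathcal{N}(v)}^{(k)}\right]\right)$$
- $W^{(k)}$: Trainable weight matrix.
- $\|$: Concatenation operation.
- $\sigma$: Nonlinear activation function (e.g., ReLU).
Output Embedding
After $K$ layers, the final embedding for node $v$ is:
$$z_v = h_v^{(K)}$$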
data:image/s3,"s3://crabby-images/0e64a/0e64a8dc915773fae943337a98882ecb211d06ce" alt="GraphSAGE"
Summary
In summary, GNNs differ in which components exchange information with each other, the aggregation function, and the update function.
data:image/s3,"s3://crabby-images/b48b9/b48b9e47832a59a599389d801fded72c48704d37" alt="Types of GNN"
Code
Here we write code that randomly generates several graphs along with their node features, and we use networkx to visualize them. To make the calculation easier to follow, we initialize the weight matrix to the identity matrix.
data:image/s3,"s3://crabby-images/c2e42/c2e42afa92a91cda3b4136ae6bd1b95c5f600ff2" alt="input graph"
data:image/s3,"s3://crabby-images/20722/207221dd9425edd18a138d41b2de724a3acf20fb" alt="result"
```python
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn as nn


class GCNs(nn.Module):
    def __init__(self, num_feature_in, num_feature_out):
        super().__init__()
        # For real training, use a random initialization instead:
        # self.W = nn.Parameter(torch.randn(num_feature_in, num_feature_out))
        # Identity weights make the effect of the normalized adjacency easy to follow.
        self.W = nn.Parameter(torch.eye(num_feature_in, num_feature_out))

    def forward(self, adj_matrix, node_feature):
        """Compute ReLU(D^-1/2 (A + I) D^-1/2 X W) for a batch of graphs.

        adj_matrix: (batch, N, N) or (N, N); node_feature: (batch, N, F) or (N, F).
        """
        # If the input is not batched, add a batch dimension.
        if adj_matrix.dim() == 2:
            adj_matrix = adj_matrix.unsqueeze(0)
        if node_feature.dim() == 2:
            node_feature = node_feature.unsqueeze(0)
        # Make sure the inputs are float tensors.
        adj_matrix = adj_matrix.float()
        node_feature = node_feature.float()

        # Add self-loops: A_tilde = A + I.
        num_nodes = adj_matrix.shape[1]
        A = adj_matrix + torch.eye(num_nodes, device=adj_matrix.device)

        # D_tilde^{-1/2} from the row sums (degrees) of A_tilde.
        deg_inv_sqrt = A.sum(dim=-1).pow(-0.5)
        deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0  # guard against zero degree
        D_inv_sqrt = torch.diag_embed(deg_inv_sqrt)

        # A_hat = D^-1/2 A_tilde D^-1/2, then H = ReLU(A_hat X W).
        A_hat = torch.bmm(torch.bmm(D_inv_sqrt, A), D_inv_sqrt)
        H_t1 = torch.bmm(A_hat, node_feature)
        return torch.relu(torch.matmul(H_t1, self.W))


# Input
batch_size = 4
num_node = 6
num_feature = 2

# Random symmetric adjacency matrices: sample the upper triangle, mirror it,
# then zero the diagonal (self-loops are added inside the layer).
map_input_n = torch.triu(torch.randint(0, 2, (batch_size, num_node, num_node)))
map_input_n = map_input_n + map_input_n.transpose(-2, -1)
map_input_n.diagonal(dim1=1, dim2=2).zero_()
print(map_input_n)

# Random integer node features in {0, 1, 2}.
map_input_feature_n = torch.randint(0, 3, (batch_size, num_node, num_feature))

# Visualize: one networkx graph per adjacency matrix.
graph = [nx.Graph(a.numpy()) for a in map_input_n]  # list of networkx graphs
fig, axe = plt.subplots(2, 2)  # 2 x 2 grid matches batch_size = 4
axe = axe.flatten()  # flatten for easier indexing
for i in range(batch_size):
    ax = axe[i]
    for j in range(num_node):  # attach the feature vector to every node
        graph[i].nodes[j]['feature_vector'] = map_input_feature_n[i][j].numpy()
    nx.draw(graph[i], ax=ax, with_labels=True, font_size=12, font_color='black',
            node_color='lightblue', edge_color='gray')
    ax.set_title(f"Graph {i + 1}")
plt.tight_layout()
plt.show()

# Output
net = GCNs(num_feature, num_feature)
with torch.no_grad():
    print(f"Your input is: \n {map_input_feature_n}")
    output = net(map_input_n, map_input_feature_n)
    print(f"Your output is: \n {output}")
```