
Graph Neural Networks (GNNs), Explained

Abhay Abhay
Photo by Jakub Żerdzicki on Unsplash

Most of the data we actually care about doesn’t fit neatly into a spreadsheet. Your social circle isn’t a list of names; it’s a web of who-knows-whom. A molecule isn’t a string of characters; it’s atoms wired together by bonds. A recommendation isn’t a row in a table; it’s the tangle of which users liked which things. All of these are graphs — nodes connected by edges — and for years our best neural networks politely ignored that structure, flattening everything into vectors and hoping for the best.

Graph Neural Networks (GNNs) are the architecture that finally stopped pretending. Instead of forcing your messy, beautiful network into a grid, a GNN learns directly on the graph itself. Here’s how.

The one idea that powers all of it: message passing

If you remember nothing else, remember this: a GNN is just nodes gossiping with their neighbors.

The mechanism is called message passing, and it’s the single operation that separates a GNN from every other neural network. It runs in three steps, repeated layer by layer:

  1. Message: every node sends a message to its neighbors, typically its current feature vector or a learned transformation of it.
  2. Aggregate: each node gathers the messages from its neighbors and combines them with a simple, order-independent operation (sum, mean, or max). Order independence matters because a node's neighbors have no natural sequence; your friends don't come in a fixed order.
  3. Update: the node feeds its own features plus the aggregated neighbor info through a small neural network to produce a new representation.
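Here is the three-step loop as a dependency-free Python sketch. It assumes plain lists for features, a mean aggregator, and a hypothetical update network made of one linear map for the node's own features (`W_self`), one for the aggregate (`W_neigh`), and a ReLU. It is meant to show the data flow, not to reproduce any library's layer.

```python
def relu(v):
    # Element-wise ReLU on a plain list.
    return [max(0.0, x) for x in v]

def matvec(W, v):
    # Multiply matrix W (a list of rows) by vector v.
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def message_passing_round(features, neighbors, W_self, W_neigh):
    """One round of message passing: message, aggregate (mean), update (linear + ReLU)."""
    new_features = {}
    for node, h in features.items():
        # 1. Message: each neighbor sends its current feature vector.
        msgs = [features[nb] for nb in neighbors[node]]
        # 2. Aggregate: element-wise mean, which does not depend on neighbor order.
        if msgs:
            agg = [sum(col) / len(msgs) for col in zip(*msgs)]
        else:
            agg = [0.0] * len(h)
        # 3. Update: combine the node's own features with the aggregate.
        combined = [a + b for a, b in zip(matvec(W_self, h), matvec(W_neigh, agg))]
        new_features[node] = relu(combined)
    return new_features

# Toy path graph a - b - c with 2-dimensional features (hypothetical data).
features = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
neighbors = {"a": ["b"], "b": ["a", "c"], "c": ["b"]}
W_self = [[1.0, 0.0], [0.0, 1.0]]   # hypothetical weights: keep own features as-is
W_neigh = [[0.5, 0.0], [0.0, 0.5]]  # hypothetical weights: halve the neighbor summary

out = message_passing_round(features, neighbors, W_self, W_neigh)
print(out)
```

Because the mean ignores order, shuffling a node's neighbor list leaves the output unchanged, which is the order-independence point from step 2.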

Do that once, and every node knows about its immediate neighbors. Do it twice, and it knows about neighbors-of-neighbors. Stack a few layers and information ripples outward across the graph, exactly the way a rumor travels through an office — except here the rumor is a learned embedding and nobody gets fired.
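To see the ripple effect without any learning, the sketch below swaps embeddings for sets of node IDs and swaps aggregate-plus-update for set union. That simplification is purely illustrative: after k rounds, each node's set is exactly the nodes within k hops, which is the most a k-layer GNN can possibly know about.

```python
def receptive_fields(neighbors, num_layers):
    """For each node, the set of nodes that can influence it after num_layers rounds."""
    # Before any message passing, each node only "knows" itself.
    known = {node: {node} for node in neighbors}
    for _ in range(num_layers):
        # Each round, a node absorbs everything its neighbors knew last round.
        known = {
            node: known[node].union(*(known[nb] for nb in neighbors[node]))
            for node in neighbors
        }
    return known

# Hypothetical path graph: 0 - 1 - 2 - 3 - 4
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
for k in range(1, 4):
    print(k, sorted(receptive_fields(path, k)[0]))
```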

That’s genuinely the whole trick. Every famous GNN variant is a remix of “let nodes talk to their neighbors.”

GCN and GraphSAGE in one breath each

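  • GCN (Graph Convolutional Network): the classic. It sums each node's features with its neighbors' (self-loops included), scaling every connection by the degrees of the two nodes it joins so that well-connected nodes don't drown out everyone else, then applies a shared linear transform and a nonlinearity. Elegant, but in its original form it trains on the whole graph at once.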
  • GraphSAGE: built for graphs too big to swallow whole. Instead of listening to every neighbor, it samples a fixed-size handful and summarizes them, so the work per node stays bounded no matter how popular that node is. That makes it scale to billions of edges: the difference between reading every reply in a thread and skimming a representative few.
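To make the two bullets concrete, here is a dependency-free sketch of each neighborhood step on hypothetical toy data. `gcn_propagate` applies the symmetric normalization from the original GCN formulation, D^-1/2 (A + I) D^-1/2 H, where D is the degree matrix of A + I, but leaves out the learned weight matrix and nonlinearity. `sage_mean` mimics GraphSAGE's mean aggregator by sampling at most `sample_size` neighbors and concatenating their average with the node's own features, again without the learned weights. Neither function is PyTorch Geometric's implementation.

```python
import math
import random

def gcn_propagate(adj, H):
    """GCN neighborhood step D^-1/2 (A + I) D^-1/2 H, without weights or nonlinearity.

    adj: dense 0/1 adjacency matrix (list of rows); H: node features (list of rows).
    """
    n = len(adj)
    # Add self-loops so each node keeps part of its own signal.
    a_tilde = [[adj[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in a_tilde]
    # Each connection (i, j) is scaled by 1 / sqrt(deg_i * deg_j): symmetric normalization.
    return [
        [sum(a_tilde[i][j] * H[j][f] / math.sqrt(deg[i] * deg[j]) for j in range(n))
         for f in range(len(H[0]))]
        for i in range(n)
    ]

def sage_mean(neighbors, features, node, sample_size, rng):
    """GraphSAGE-style mean aggregation over a fixed-size neighbor sample.

    Returns concat(own features, mean of sampled neighbors); the learned weight
    matrix and nonlinearity that GraphSAGE applies afterwards are omitted.
    """
    nbrs = neighbors[node]
    dim = len(features[node])
    if not nbrs:
        return features[node] + [0.0] * dim
    # Cap the work per node: never look at more than sample_size neighbors.
    sampled = rng.sample(nbrs, min(sample_size, len(nbrs)))
    mean = [sum(features[nb][f] for nb in sampled) / len(sampled) for f in range(dim)]
    return features[node] + mean

# Hypothetical path graph 0 - 1 - 2 with a signal on node 0 only.
path_adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
print(gcn_propagate(path_adj, [[1.0], [0.0], [0.0]]))

# Hypothetical hub with five neighbors, of which only two are sampled.
hub = {"hub": ["n1", "n2", "n3", "n4", "n5"]}
feats = {"hub": [0.0], "n1": [1.0], "n2": [2.0], "n3": [3.0], "n4": [4.0], "n5": [5.0]}
print(sage_mean(hub, feats, "hub", sample_size=2, rng=random.Random(0)))
```

The design difference shows up in the inputs: the GCN step touches the full adjacency matrix, while the GraphSAGE step only ever reads `sample_size` neighbor rows per node.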

There are flashier cousins too (GAT, the Graph Attention Network, which uses attention to learn how much weight each neighbor deserves), but GCN and GraphSAGE are the two you'll meet first.
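For the attention idea, here is a sketch of how a GAT-style layer weights neighbors. Each neighbor j of node i gets a raw score LeakyReLU(a · concat(h_i, h_j)), and a softmax over the neighborhood turns those scores into weights that sum to 1. The sketch assumes the shared linear transform W has already been applied to the features, uses a hypothetical attention vector `a`, and lists the node as its own neighbor (GAT's self-loop).

```python
import math

def leaky_relu(x, slope=0.2):
    # GAT applies LeakyReLU (negative slope 0.2) to the raw attention scores.
    return x if x > 0 else slope * x

def gat_attention(h, neighbors, node, a):
    """Attention weights of `node` over its neighbors (h assumed already multiplied by W)."""
    scores = {}
    for nb in neighbors[node]:
        concat = h[node] + h[nb]  # list concatenation: [h_i, h_j]
        scores[nb] = leaky_relu(sum(ai * ci for ai, ci in zip(a, concat)))
    # Numerically stable softmax over the neighborhood.
    top = max(scores.values())
    exps = {nb: math.exp(s - top) for nb, s in scores.items()}
    total = sum(exps.values())
    return {nb: e / total for nb, e in exps.items()}

def gat_aggregate(h, neighbors, node, a):
    """Attention-weighted sum of neighbor features."""
    alpha = gat_attention(h, neighbors, node, a)
    return [sum(alpha[nb] * h[nb][f] for nb in alpha) for f in range(len(h[node]))]

# Hypothetical example: node 0 attends over itself and nodes 1 and 2.
h = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]}
neighbors = {0: [0, 1, 2]}
a = [0.0, 0.0, 1.0, 0.0]  # hypothetical: scores each neighbor by the first entry of h_j
alpha = gat_attention(h, neighbors, 0, a)
print(alpha)
print(gat_aggregate(h, neighbors, 0, a))
```

Unlike the fixed degree-based weights in GCN, these weights depend on the features themselves, so a layer can learn to lean on some neighbors and mostly ignore others.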

What you can actually do with the embeddings

Once message passing has produced rich vectors, you pick a task by choosing what you predict on:

  • Node-level: classify a single node. Is this account a bot? Is this paper about biology?
  • Edge-level: predict whether a connection should exist. This is link prediction, and it’s the quiet engine behind “People You May Know” and product recommendations.
  • Graph-level: summarize the entire graph into one prediction. Is this molecule toxic? Will this protein bind? This is where GNNs have become genuinely useful in drug discovery.

Same machinery, three altitudes. The graph doesn’t change; only the read-out does.
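A sketch of the three altitudes as read-outs over the same hypothetical node embeddings: a linear classifier per node, a dot-product scorer for a candidate edge, and mean pooling for the whole graph. These are common choices rather than the only ones; sum or max pooling and small MLP decoders are also widely used.

```python
import math

# Hypothetical embeddings produced by some stack of message-passing layers.
z = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}

def dot(u, v):
    return sum(x * y for x, y in zip(u, v))

def node_logits(z_v, W):
    """Node-level: one score per class from a linear layer (W is hypothetical)."""
    return [dot(row, z_v) for row in W]

def link_probability(z_u, z_v):
    """Edge-level: dot-product decoder squashed through a sigmoid."""
    return 1.0 / (1.0 + math.exp(-dot(z_u, z_v)))

def mean_pool(embeddings):
    """Graph-level: average every node embedding into one graph vector."""
    rows = list(embeddings.values())
    return [sum(col) / len(rows) for col in zip(*rows)]

W = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical 2-class classifier weights
print(node_logits(z["a"], W))
print(link_probability(z["a"], z["b"]), link_probability(z["a"], z["c"]))
print(mean_pool(z))
```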

A taste in code

With PyTorch Geometric, a two-layer GCN for node classification is almost suspiciously short:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)   # round 1 of gossip
        self.conv2 = GCNConv(hidden_dim, num_classes)  # round 2

    def forward(self, x, edge_index):
        # x: node features [N, in_dim]
        # edge_index: the graph's connections [2, num_edges]
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # log-probabilities over classes, one row per node

Each GCNConv is one full round of message passing. Notice you never tell it the graph's size or hand it an adjacency matrix: you pass the edge list (edge_index) alongside the node features, and the layer works out who talks to whom. That flexibility is the whole point.
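If your graph starts life as an adjacency list, turning it into the two-row edge list that `edge_index` expects is a short loop. The sketch below assumes an undirected graph stored as a Python dict and, following the usual PyTorch Geometric convention, lists every undirected edge once in each direction, with row 0 holding the sending node and row 1 the receiving node. In PyG you would then wrap the result with `torch.tensor(edge_index, dtype=torch.long)`.

```python
def to_edge_index(adjacency):
    """Flatten an adjacency list into [sources, targets], shape [2, num_edges]."""
    sources, targets = [], []
    for u, nbrs in adjacency.items():
        for v in nbrs:
            sources.append(u)  # row 0: the node sending the message
            targets.append(v)  # row 1: the node receiving it
    return [sources, targets]

# Hypothetical friendship graph 0 - 1 - 2, each friendship listed from both ends.
friends = {0: [1], 1: [0, 2], 2: [1]}
edge_index = to_edge_index(friends)
print(edge_index)
```

Two friendships become four columns, so under this convention `num_edges` in `[2, num_edges]` counts each undirected edge twice.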

Where they show up in the wild

GNNs power friend and content recommendations at social-scale platforms, predict molecular and protein properties in pharma pipelines, flag fraud rings in payment networks (fraudsters betray themselves through their connections), forecast traffic on road graphs, and even help with network packet routing and chip layout. Anywhere relationships carry the signal, a GNN can learn from them.

The takeaway

Reach for a GNN when the connections in your data matter as much as the data points themselves — and don’t overthink the architecture. Your starting recipe: model the problem as a graph (decide what’s a node and what’s an edge), pick your task altitude (node, edge, or graph), and stack two or three message-passing layers in PyTorch Geometric — GCN if the graph fits in memory, GraphSAGE if it doesn’t. Resist going deeper than three or four layers; pile on too many and every node’s embedding blurs into the same mush (a real failure mode called over-smoothing). Start shallow, let the neighbors gossip, and add depth only when the metrics ask for it.
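Over-smoothing is easy to reproduce in a deliberately stripped-down setting. The sketch below assumes no learned weights and no nonlinearity: it repeatedly replaces each node's scalar value with the mean of itself and its neighbors on a small hypothetical graph. Real layers have weights that change the picture somewhat, but repeated neighborhood averaging is the core of the effect.

```python
def mean_smooth(values, neighbors, rounds):
    """Repeated mean aggregation with self-loops (no weights, no nonlinearity)."""
    for _ in range(rounds):
        values = {
            node: (values[node] + sum(values[nb] for nb in neighbors[node]))
            / (1 + len(neighbors[node]))
            for node in neighbors
        }
    return values

def spread(values):
    # Gap between the largest and smallest node value: 0 means all nodes look alike.
    return max(values.values()) - min(values.values())

# Hypothetical 5-node graph: a triangle (0, 1, 2) with a tail 2 - 3 - 4.
graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3]}
start = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: -1.0}
for rounds in (0, 1, 5, 20, 100):
    print(rounds, round(spread(mean_smooth(start, graph, rounds)), 6))
```

On a connected graph with self-loops this averaging drives every node to the same value (here the degree-weighted average of the starting values, 1/15, with degrees counted including the self-loop), so after enough rounds the nodes are indistinguishable.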


Sources: Message Passing in GNNs (Kumo.ai / PyG), Graph Neural Network glossary (TigerGraph), GCN vs SAGE shootout.
