
Transfer Learning, Explained

Abhay · 4 min read
Photo by Steven Lelham on Unsplash

Suppose you want to build a classifier that tells apart 200 species of moths, and you have exactly 1,400 photos: about seven per species. Train a deep convolutional network from scratch on that, and it will memorise your moths, then fall over in a stiff breeze. The honest truth is that big neural networks are hungry; they want millions of images, not a shoebox of them. So here’s the trick the whole field leans on: don’t start from scratch. Borrow a model that has already learned to see on a giant dataset, and re-point it at your problem. That’s transfer learning, and it’s the closest thing deep learning has to a free lunch.

The relay-race idea

Think of it as a relay race. Some research lab already ran the exhausting first leg — they trained a network on ImageNet, roughly 1.2 million labelled photos across 1,000 categories, burning serious GPU time. You don’t re-run that leg. You grab the baton (the trained weights) and sprint the short final stretch that’s specific to your task.

Why does a model trained on toasters, terriers, and traffic lights help with moths? Because of how convolutional networks organise what they learn. The early layers pick up generic, almost universal visual primitives — edges, corners, colour blobs, textures. The middle layers compose those into shapes and motifs. Only the final layers become truly task-specific (“this particular combination of fuzz and antennae means Luna moth”). Edges are edges whether you’re looking at a cat or a moth, so those early, hard-won features transfer beautifully. You’re really only missing the last leg: the specialised bit at the top.

Two ways to reuse a model

There are two flavours, and choosing between them is most of the skill.

Feature extraction (freeze). You take the pretrained network, lop off its original classifier head, and freeze every remaining layer so that its weights won’t change during training. The frozen stack becomes a fixed “feature factory”: feed an image in, get a rich numerical fingerprint out. You then train one small new head on top of those fingerprints. It’s fast, light, and resistant to overfitting because only the head’s parameters are trained, a small fraction of the full network. This is your default when data is scarce or your images look broadly like ImageNet’s.
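To make the “feature factory” idea concrete without a GPU, here is a minimal sketch in plain Python. It is a toy, not the Keras workflow: a fixed random projection followed by a ReLU stands in for the frozen backbone (a hypothetical substitute for MobileNetV2), and a logistic-regression head is the only part that trains. Because the backbone never changes, every image’s features can be computed once and cached, which is a large part of why feature extraction is cheap.

```python
import math
import random

rng = random.Random(0)

# Hypothetical stand-in for a pretrained backbone: a fixed random projection + ReLU.
# In the article this role is played by MobileNetV2 with include_top=False.
IN_DIM, FEAT_DIM = 2, 16
BACKBONE_W = [[rng.gauss(0.0, 1.0) for _ in range(IN_DIM)] for _ in range(FEAT_DIM)]


def backbone(x):
    """Frozen feature extractor: its weights are never updated."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in BACKBONE_W]


def make_data(n):
    """Toy two-class data: class 1 around (+2, +2), class 0 around (-2, -2)."""
    data = []
    for _ in range(n):
        label = rng.randint(0, 1)
        centre = 2.0 if label else -2.0
        data.append(([rng.gauss(centre, 1.0), rng.gauss(centre, 1.0)], label))
    return data


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_head(features, labels, lr=0.01, epochs=50):
    """Train only the new head (logistic regression) on cached features with SGD."""
    w = [0.0] * len(features[0])
    b = 0.0
    for _ in range(epochs):
        for f, y in zip(features, labels):
            p = sigmoid(sum(wi * fi for wi, fi in zip(w, f)) + b)
            g = p - y  # derivative of binary cross-entropy w.r.t. the logit
            w = [wi - lr * g * fi for wi, fi in zip(w, f)]
            b -= lr * g
    return w, b


def predict(w, b, f):
    return 1 if sum(wi * fi for wi, fi in zip(w, f)) + b >= 0 else 0


train, test = make_data(200), make_data(100)
frozen_before = [row[:] for row in BACKBONE_W]

# Run every example through the frozen backbone once and cache the "fingerprints".
train_feats = [backbone(x) for x, _ in train]
test_feats = [backbone(x) for x, _ in test]

w, b = train_head(train_feats, [y for _, y in train])
accuracy = sum(predict(w, b, f) == y for f, (_, y) in zip(test_feats, test)) / len(test)
print(f"head-only test accuracy: {accuracy:.2f}")
```

The only trainable state is `w` and `b`; the backbone weights end the run exactly as they started.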

Fine-tuning (unfreeze). Here you unfreeze some of the upper layers and let them keep learning on your data at a tiny learning rate. This nudges those middle-to-high features away from “generic photo” and towards “my specific domain.” It can squeeze out real accuracy gains, but it’s slower, needs more data, and will happily overfit if you’re careless. The cardinal rule: keep the learning rate very low, and only unfreeze once the new head has been trained. Unfreeze while the head is still randomly initialised, or fine-tune with a large learning rate, and the resulting big gradient updates will wreck the delicate pretrained features you came for: catastrophic forgetting, live and in colour.
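Why does the learning rate matter so much? With plain gradient descent each update is the learning rate times the gradient, so the step size directly limits how far pretrained weights can wander. The toy below uses plain SGD rather than Adam for simplicity (Adam’s step size also scales with the learning rate). It starts from a hypothetical “pretrained” weight vector and runs 1,000 steps on a simple quadratic loss whose optimum sits elsewhere. It does not model forgetting itself; it only shows how much less the weights drift at 1e-5 than at 1e-3.

```python
import math


def sgd_drift(w_pretrained, w_target, lr, steps):
    """Run plain SGD on the toy loss 0.5 * ||w - w_target||^2, starting from the
    pretrained weights, and return how far the weights moved (L2 distance)."""
    w = list(w_pretrained)
    for _ in range(steps):
        grad = [wi - ti for wi, ti in zip(w, w_target)]  # gradient of the quadratic loss
        w = [wi - lr * gi for wi, gi in zip(w, grad)]
    return math.dist(w, w_pretrained)


w_pretrained = [1.0, -2.0, 0.5]  # hypothetical "pretrained" values
w_target = [-1.0, 1.0, 2.0]      # hypothetical optimum for the new task

for lr in (1e-3, 1e-5):
    print(f"lr={lr:g}: weights moved {sgd_drift(w_pretrained, w_target, lr, steps=1000):.4f}")
```

On this toy the 1e-5 run moves the weights roughly 60 times less than the 1e-3 run over the same number of steps.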

A practical recipe combines both: extract first, then fine-tune. Train the new head while everything’s frozen, then unfreeze the top block and continue at a crawl.
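The two-phase recipe is mostly bookkeeping about which layers may change and at what learning rate. Here is a minimal sketch in plain Python, assuming a hypothetical `Layer` class in place of real framework layers; in Keras you would flip each layer’s `trainable` attribute and recompile, as the example below does. The layer names and the choice of three top layers are illustrative only.

```python
from dataclasses import dataclass


@dataclass
class Layer:
    """Hypothetical stand-in for a framework layer with a trainable flag."""
    name: str
    trainable: bool = True


def unfreeze_top(layers, n_top):
    """Freeze every layer, then unfreeze only the last n_top (the 'top block')."""
    cutoff = len(layers) - n_top
    for i, layer in enumerate(layers):
        layer.trainable = i >= cutoff
    return layers


# Hypothetical two-phase plan: (phase name, base layers to unfreeze, learning rate).
# 1e-3 is the default learning rate of Keras's Adam optimizer.
PHASES = [
    ("feature extraction", 0, 1e-3),
    ("fine-tuning", 3, 1e-5),
]

base_layers = [Layer(f"block_{i}") for i in range(1, 11)]  # a hypothetical 10-layer base

for name, n_top, lr in PHASES:
    unfreeze_top(base_layers, n_top)
    trainable = [layer.name for layer in base_layers if layer.trainable]
    print(f"{name}: lr={lr:g}, trainable base layers={trainable}")
    # ...train the new head (plus any unfrozen base layers) here...
```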

A vision example in Keras

Here’s the canonical pattern with an ImageNet backbone. Note the two details that trip everyone up: keeping the base in inference mode (training=False) so its BatchNorm statistics survive, and recompiling after you flip trainable.

import keras

# train_ds / val_ds: assumed to be tf.data datasets of (160x160 RGB image, integer label)
# batches, e.g. built with keras.utils.image_dataset_from_directory(..., image_size=(160, 160))

# 1. Load a pretrained backbone, minus its 1000-class head
base = keras.applications.MobileNetV2(
    weights="imagenet", include_top=False, input_shape=(160, 160, 3)
)
base.trainable = False  # FREEZE: feature-extraction mode

# 2. Bolt on a small task-specific head
inputs = keras.Input(shape=(160, 160, 3))
x = keras.layers.Rescaling(1 / 127.5, offset=-1)(inputs)  # MobileNetV2 expects pixels in [-1, 1]
x = base(x, training=False)                 # inference mode -> protects BatchNorm
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)
outputs = keras.layers.Dense(200, activation="softmax")(x)  # 200 moth species
model = keras.Model(inputs, outputs)

model.compile(optimizer="adam",  # Adam's default learning rate is 1e-3
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.fit(train_ds, validation_data=val_ds, epochs=10)

# 3. Optional fine-tune: UNFREEZE only the top of the base, crawl with a tiny LR
base.trainable = True
for layer in base.layers[:-30]:   # keep all but the last 30 layers frozen (30 is a judgment call)
    layer.trainable = False
model.compile(optimizer=keras.optimizers.Adam(1e-5),   # 100x smaller than the default
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.fit(train_ds, validation_data=val_ds, epochs=5)
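Why does training=False protect the backbone? A BatchNormalization layer keeps running averages of the mean and variance it saw during pretraining. In training mode it normalises with the current batch’s statistics and nudges those averages towards them; in inference mode it uses the stored averages and leaves them alone. The plain-Python toy below mimics that update rule with Keras’s default momentum (0.99) and epsilon (1e-3), omits the learnable scale and offset for brevity, and uses hypothetical “pretrained” statistics and a hypothetical batch of new-domain activations.

```python
import statistics


class ToyBatchNorm:
    """One-feature BatchNorm with Keras-style moving averages (no gamma/beta)."""

    def __init__(self, moving_mean, moving_var, momentum=0.99, eps=1e-3):
        self.moving_mean = moving_mean
        self.moving_var = moving_var
        self.momentum = momentum
        self.eps = eps

    def __call__(self, batch, training):
        if training:
            mean = statistics.fmean(batch)
            var = statistics.pvariance(batch, mu=mean)
            # Training mode drags the stored (pretrained) statistics towards this batch.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference mode: use the stored statistics and leave them untouched.
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in batch]


new_domain_batch = [4.0, 5.0, 6.0]  # hypothetical activations from moth photos
protected = ToyBatchNorm(moving_mean=0.0, moving_var=1.0)  # hypothetical pretrained stats
exposed = ToyBatchNorm(moving_mean=0.0, moving_var=1.0)

for _ in range(100):
    protected(new_domain_batch, training=False)
    exposed(new_domain_batch, training=True)

print(f"inference mode: moving_mean={protected.moving_mean:.2f}")
print(f"training mode:  moving_mean={exposed.moving_mean:.2f}")
```

In the Keras code, calling `base(x, training=False)` keeps every BatchNormalization layer on the inference branch even after `base.trainable = True`, which is what lets the pretrained statistics survive fine-tuning.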

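Swap MobileNetV2 for ResNet50, Xception, or one of the EfficientNetV2 variants and the shape of the code stays the same; mainly the input size and the input preprocessing change, since each backbone expects its own. That’s the beauty of it.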

When it helps (and when it doesn’t)

Transfer learning shines when your dataset is small-to-medium and your domain isn’t wildly alien to the source. The further your images drift from natural photos — think raw medical scans or satellite imagery — the less the high-level features transfer, and the more fine-tuning (or more data) you’ll need. If your dataset is genuinely huge and unusual, training from scratch can eventually win. But that “eventually” costs a lot of GPU-hours, and most of us don’t have an ImageNet of our own lying around.
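The guidance in this section boils down to a couple of branches. Here is a sketch in plain Python; the function name and both cut-offs (`enough_to_fine_tune`, `enormous`) are hypothetical placeholders chosen for illustration, not numbers from the sources, and the right values depend on your task and model.

```python
def choose_strategy(n_images: int, similar_to_imagenet: bool,
                    enough_to_fine_tune: int = 5_000,
                    enormous: int = 1_000_000) -> str:
    """Rule of thumb from this article. Both thresholds are hypothetical."""
    if n_images >= enormous and not similar_to_imagenet:
        return "train from scratch"
    if n_images >= enough_to_fine_tune:
        return "feature extraction, then fine-tune the top layers"
    return "feature extraction"


print(choose_strategy(1_400, similar_to_imagenet=True))  # the moth example -> feature extraction
```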

The takeaway

Next time you face a vision task with limited data, don’t train from zero; start from a pretrained ImageNet backbone. Default to feature extraction (freeze the base, train a new head). If you have enough data and want more accuracy, unfreeze the top layers and fine-tune at a learning rate roughly 10–100x smaller than usual, keeping BatchNorm in inference mode. Reach for from-scratch training only when your dataset is both enormous and unlike anything in ImageNet. You’ll get further, faster, on a fraction of the data, and your GPU bill will thank you.

Sources: Label Your Data — Transfer Learning vs Fine-Tuning, Keras Transfer Learning Guide, Codefinity — Fine-Tuning vs Feature Extraction.
