Research/Blog
Metric-based Meta Learning
- July 23, 2020
- Posted by: Bhavesh Laddagiri
- Category: Artificial General Intelligence (AGI), Reinforcement Learning
#CellStratAILab #disrupt4.0 #WeCreateAISuperstars #WhereLearningNeverStops
Recently, I presented a session on Metric-based Meta Learning at the CellStrat AI Lab (where I am an AI Researcher).
Metric-based Meta Learning can be viewed as a step towards Artificial General Intelligence (AGI), because Meta Learning helps us build systems that generalize well from relatively little data.
Introduction
Meta Learning (aka Learning to Learn) is about designing models which can learn and adapt to new tasks quickly and efficiently without much fine-tuning. We want models which generalize well to new tasks, data or environments that differ from the ones seen in training. For example:
- A game playing bot trained to play PUBG should be able to generalize to Call of Duty
- A robot trained to walk on flat soil should be able to walk on ragged terrain
- An image classifier trained on dogs should be able to recognize cats given a few images of what a cat looks like
Types of Meta-Learning :-
Meta Learning can be approached in different ways :
- Metric-Based – Learn an efficient distance function for similarity
- Model-Based – Learn to utilize internal/external memory for adapting (e.g. Memory-Augmented Neural Networks, MANNs)
- Optimization-Based – Optimize the model parameters explicitly for learning quickly
A simple Meta-Learning System
Here we will discuss a simple meta learning system.
- Normal Training – How does a normal neural network learn
- Framing the Problem – Transform a normal problem into a meta-learning problem
- Meta Training – How can a neural network learn to learn efficiently
Normal Training :-
Two meta factors affecting the training process are the optimizer and the initial parameters.

Let’s break down the illustration above.
- We start by initializing our model f_θ with parameters θ_0.
- Let’s say we have an input x. That input x is forward passed through our model f_θ to get some prediction and then we compare it with our label to get the loss.
- This loss is then backpropagated to compute the gradients.
- These gradients are then used by an optimizer like SGD to optimize the parameters of our model such that the loss at the next time step is lower than at the previous one.
- We can incrementally number our model parameters after every optimization step as θ_0, θ_1, θ_2, and so on.
- This process is repeated until we reach θ_n, a set of parameters which solves our problem with minimal loss (a minimal code sketch of this loop follows below).
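The steps above are just the standard supervised training loop. Here is a minimal PyTorch-style sketch of it; the model, data and hyperparameters are illustrative assumptions, not part of the original slides:

```python
import torch
import torch.nn as nn

# Illustrative model f_theta and toy data (assumptions for this sketch)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # the optimizer is one meta factor
criterion = nn.MSELoss()

x = torch.randn(64, 10)          # input batch
y = torch.randn(64, 1)           # labels

for step in range(100):          # theta_0 -> theta_1 -> ... -> theta_n
    pred = model(x)              # forward pass through f_theta
    loss = criterion(pred, y)    # compare prediction with the label
    optimizer.zero_grad()
    loss.backward()              # backpropagate to compute gradients
    optimizer.step()             # SGD uses the gradients to update theta
```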
Notice that the way the model is initialized and the optimizer used play a key role in how fast and how efficiently the model learns. Meta-learning is all about learning such meta-parameters: the initial weights, the optimization strategy and other hyperparameters.
Framing the Problem :-
The aim of meta-learning is to learn to generalize well on multiple tasks with little data.
So what we can do is split the dataset into multiple sub-problems and ask our model to perform well on all of those tasks. It’s like creating many small datasets from one big dataset to simulate solving multiple tasks.
Framing a meta-learning problem in the right way can force the neural network to learn features such that they generalize well on multiple unseen but related tasks.
Meta Training :-

Let’s break down the illustration above :
- As usual, we start by creating our main model f_θ which solves the task. This model is called the learner.
- We pass an input, calculate loss and compute gradients.
- But now, instead of a hand-designed optimizer like SGD, another neural network g_ϕ takes those gradients and updates the weights of our main model f_θ. This neural optimizer is called the meta-learner.
- Our meta-learner (g_ϕ) optimizes the parameters of our learner (f_θ) for some defined number of steps constituting an episode.
- After one episode, a meta-loss is calculated to evaluate how well our meta-learner updated our learner. This could be as simple as the sum of the learner’s losses over the episode.
- This meta-loss is backpropagated through the entire unrolled episode, and a parent optimizer (the meta-optimizer) uses the gradients of the meta-loss to update the parameters of our meta-learner (g_ϕ), which in turn is responsible for the weights of the learner (f_θ). A minimal sketch of this loop follows.
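A minimal toy sketch of this learner/meta-learner loop, assuming a tiny linear learner on synthetic data (all names, sizes and the 5-step episode length are illustrative assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Learner f_theta: a single linear map kept as a plain tensor so its weights
# can stay inside the autograd graph across a whole episode.
theta_0 = torch.randn(1, 3, requires_grad=True) * 0.1
# Meta-learner g_phi: maps the learner's gradient to a parameter update.
g_phi = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 3))
meta_optimizer = torch.optim.Adam(g_phi.parameters(), lr=1e-3)

# Toy task data (assumed): learn a fixed linear mapping.
x = torch.randn(32, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])

for meta_step in range(200):                     # outer loop trains the meta-learner
    theta = theta_0.detach().clone().requires_grad_(True)   # reset the learner each episode
    meta_loss = 0.0
    for t in range(5):                           # one episode = 5 learner updates
        loss = F.mse_loss(x @ theta.t(), y)      # learner forward pass and loss
        grad, = torch.autograd.grad(loss, theta, create_graph=True)
        theta = theta + g_phi(grad)              # g_phi proposes the update instead of SGD
        meta_loss = meta_loss + loss             # meta-loss: sum of the learner's losses
    meta_optimizer.zero_grad()
    meta_loss.backward()                         # backprop through the whole unrolled episode
    meta_optimizer.step()                        # update g_phi with the meta-optimizer
```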
Applications of Meta Learning :-
- Hyperparameter optimization, i.e. meta-learning the optimal hyperparameters. Techniques like genetic algorithms can also be used to optimize neural network hyperparameters at the meta level.
- Neural Architecture Search
- Few-shot learning and generalization across multiple tasks with little/no fine-tuning.
- Research in Meta-Learning is also driving progress towards AGI, together with advances in hybrid AI, reasoning and RL.
Metric-based Meta Learning
The core idea here is similar to nearest-neighbour algorithms (k-NN, k-means) and kernel density estimation: the predicted probability distribution for a given input is a weighted sum of the (one-hot encoded) labels of known samples, where the weights are produced by a kernel function that measures the similarity between two samples.

Metric Learning is all about learning to measure the similarity between an input image and the images in a small database (aka the support set). In the next sections, we will look at the few-shot classification problem, which is solved with this approach.
Few Shot Classification Problem
Supervised Learning :-
We have a dataset D = {(x, y)} containing both the inputs (feature vectors) and the true labels. The optimal model parameters θ^∗ are defined by

θ^∗ = argmax_θ E_{(x,y)∼D} [ P_θ(y | x) ]

i.e. find the parameters which maximize the expectation of the probability of y given x over all the data points.
In the next section, we will see how we can transform the dataset into mini-datasets consisting of multiple related tasks as one data point.
Few-shot Classification :-
Few-Shot Classification is an instance of meta-learning in a supervised context.
Let’s take a simple image classification dataset D with images and their corresponding labels. Each data point in D is a single image with one label, and the goal is to predict the label of an image after training on the entire dataset.
To frame it as a meta-learning problem for generalization, we can transform this dataset into a collection of mini-datasets having only a handful of images per class. We train our model on all these mini-datasets with the goal of learning to classify with little training data. Each forward pass on one mini-dataset (task) is called an episode.
In notational terms, this translates to the dataset D being transformed into mini-datasets B, each containing a support set S and a query set Q.
The support set S contains a small number of training examples for each class, and the query set Q contains the testing examples to classify, given the training examples in the support set S.
It is modelled as a k-shot n-way problem, where k is the number of training examples per class and n is the total number of classes in the task/mini-dataset B. The support set S therefore contains k examples for each of the n classes.
Transforming the data this way lets us train explicitly for learning with little data.
At test time the model is given only a mini-dataset B, so we train it the same way it will be tested; this is often called “training in the same way as testing”.
In practice, we measure the similarity between the support-set images and the query-set images, which makes this a metric-based meta-learning problem. A sketch of how such episodes can be sampled follows.
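A minimal sketch of sampling one k-shot n-way episode; the `dataset` structure (a list of (image, label) pairs) and all names are illustrative assumptions:

```python
import random
from collections import defaultdict

def sample_episode(dataset, n_way=5, k_shot=1, q_queries=5):
    """Sample one k-shot n-way episode: a support set S and a query set Q."""
    by_class = defaultdict(list)
    for image, label in dataset:
        by_class[label].append(image)
    classes = random.sample(list(by_class), n_way)        # pick n classes for this task
    support, query = [], []
    for episode_label, cls in enumerate(classes):         # relabel classes 0..n-1 per episode
        images = random.sample(by_class[cls], k_shot + q_queries)
        support += [(img, episode_label) for img in images[:k_shot]]
        query   += [(img, episode_label) for img in images[k_shot:]]
    return support, query
```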

In its simplest form, the goal is to find parameters such that

θ^∗ = argmax_θ E_{B∼D} [ E_{(x,y)∼B} [ P_θ(y | x, S_B) ] ]

i.e. the optimal parameters θ^∗ maximize the expectation over all sub-tasks B drawn from the dataset D, where within each sub-task B the probability of the correct label, given the support set S_B, is maximized.
Convolutional Siamese Networks
Convolutional Siamese Networks were initially proposed by Koch, Zemel & Salakhutdinov (2015).
It consists of two twin input networks with a joint output which predicts the probability of the two images belonging to the same class, based on the distance between their feature vectors.
The Siamese neural network is used for one-shot image classification. It is trained and tested as follows –
- First, it is trained on pairs of input images as a verification problem: are the two images of the same class or not (0/1)?
- During the testing phase, it is used iteratively to compare the test image with a support set, usually containing one image for each class.
The loss is measured with binary cross-entropy, as the output is binary.

Breaking down the above illustration :
- The two input images are fed through a shared feature extractor (CNN) one by one to get the embeddings of each image.
- A distance function (Euclidean, cosine) is used to compare the embeddings and calculate the distance.
- The distance is then fed through a linear layer with a sigmoid to predict whether both input images belong to the same class or not.
- A binary prediction is made, with 1 meaning similar and 0 meaning not similar.
- The loss is calculated accordingly and backpropagated to update the feature extractor and the linear head.
Siamese Network – Pseudo Implementation :-
Siamese Model :

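The original post showed this code as an image; below is a hedged PyTorch-style reconstruction of the same idea (layer sizes and names are illustrative assumptions):

```python
import torch
import torch.nn as nn

class SiameseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared CNN feature extractor: both twin branches use the same weights
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 64, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.LazyLinear(256), nn.ReLU(),
        )
        # Linear head on the element-wise distance between the two embeddings
        self.head = nn.Linear(256, 1)

    def forward(self, img1, img2):
        emb1, emb2 = self.encoder(img1), self.encoder(img2)
        distance = torch.abs(emb1 - emb2)            # component-wise L1 distance
        return torch.sigmoid(self.head(distance))    # probability the pair is the same class
```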
Training :

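A matching sketch of the training step, continuing from the model above; `pair_loader` is an assumed DataLoader that yields image pairs with a 0/1 same-class label:

```python
model = SiameseNetwork()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCELoss()                   # binary cross-entropy on the same/not-same label

for img1, img2, same in pair_loader:       # pair_loader is an assumed source of labelled pairs
    prob = model(img1, img2)               # probability that the two images match
    loss = criterion(prob.squeeze(1), same.float())
    optimizer.zero_grad()
    loss.backward()                        # update both the encoder and the linear head
    optimizer.step()
```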
Relation Networks
Relation Network for k-Shot Classification was proposed by Sung et al., 2018.
Relation Network is similar to the Siamese Network, with a few differences:
- The relationship between the two inputs is not captured by a simple Euclidean or cosine distance, but rather by a neural network that takes the concatenation of the two feature vectors as input.
- The loss function is mean squared error instead of cross-entropy, as Relation Networks predict relation scores, which is conceptually closer to regression than to binary classification.
The relation score for a pair of inputs x_i and x_j is

r_{i,j} = g_ϕ( [ f_θ(x_i), f_θ(x_j) ] )

where f_θ is the embedding module, g_ϕ is the relation module and [·,·] denotes concatenation of the feature vectors. A small sketch of this pairing follows.
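A minimal sketch of the embedding and relation modules; the MLP stand-ins for the paper's CNNs and all sizes are assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim = 64
# f_theta: embedding module (a flattening MLP here as a stand-in for a CNN)
f_theta = nn.Sequential(nn.Flatten(), nn.LazyLinear(embed_dim), nn.ReLU())
# g_phi: relation module applied to the concatenated pair of embeddings
g_phi = nn.Sequential(nn.Linear(2 * embed_dim, 64), nn.ReLU(),
                      nn.Linear(64, 1), nn.Sigmoid())

def relation_score(x_i, x_j):
    pair = torch.cat([f_theta(x_i), f_theta(x_j)], dim=-1)   # [f_theta(x_i), f_theta(x_j)]
    return g_phi(pair)                                        # relation score r_ij in [0, 1]

# Training target: score 1 for matching pairs and 0 otherwise, with an MSE loss, e.g.
# loss = F.mse_loss(relation_score(x_i, x_j), match_label)
```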
Relation Network – Architecture :-


Relation Network – Architecture (Zero Shot) :-

In a zero-shot setting, where we want to classify images without any support image for the new class, we can pass the query image through a CNN like Inception or ResNet to get its embedding, and compare it with a class-specific semantic embedding (obtained externally, e.g. from attributes or word vectors) passed through a few hidden layers.
Matching Networks
Matching Networks were proposed by Vinyals et al., 2016. They work by embedding the images in support set and query set and then performing a generalized form of nearest neighbor classification.
The prediction for a query image x̂ is a weighted sum over the support labels:

ŷ = Σ_i a(x̂, x_i) · y_i

where x̂ is the query sample, x_i are the samples from the support set and y_i are their one-hot encoded labels. The attention a(x̂, x_i) is calculated by performing a simple softmax over the cosine similarity of the feature vectors:

a(x̂, x_i) = exp( cosine(f(x̂), g(x_i)) ) / Σ_j exp( cosine(f(x̂), g(x_j)) )
It was also the first paper to introduce “training as testing”.
Matching Networks – Architecture :-

Matching Networks – Simple Embedding :-
In the simple version of Matching Networks, the embedding function is a plain neural network that takes a single image as input and outputs a fixed-size vector.
- The support set images are encoded with feature extractor g and the query image is embedded with feature extractor f. Both are based on CNNs.
- The cosine similarity is calculated between the query image x̂ and the support set images x_i.
- A softmax is applied over these similarities, giving us the attention weights a(x̂, x_i).
- This attention weight is then applied to the one-hot labels y_i of the support set,

ŷ = Σ_i a(x̂, x_i) · y_i

giving us the predictions.
- Potentially f = g, i.e. the same feature extractor can be used for the support-set images and the query image, to keep things simple. A minimal sketch of this prediction step follows the list.
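A minimal sketch of the attention-weighted prediction, assuming the embeddings have already been computed by the extractors f and g (function name, shapes and arguments are illustrative):

```python
import torch
import torch.nn.functional as F

def matching_predict(query_emb, support_embs, support_labels, n_way):
    """query_emb: (D,), support_embs: (N, D), support_labels: (N,) class indices in [0, n_way)."""
    cos = F.cosine_similarity(query_emb.unsqueeze(0), support_embs, dim=-1)  # (N,)
    attention = F.softmax(cos, dim=0)                                        # a(x_hat, x_i)
    one_hot = F.one_hot(support_labels, n_way).float()                       # y_i as one-hot rows
    return attention @ one_hot                                               # y_hat = sum_i a * y_i
```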
Matching Networks – Full Context Embedding :-
Using a single embedding function for every image in the support set is not ideal: each embedding is computed independently of the other images in the support set, so it does not capture the contextual relationships among them, which can be a problem when two support-set images are very similar.
So the authors propose a Full Context Embedding (FCE) version where,
- The support set embedding function g takes the whole support set S as input so that the learned embedding can capture the relationship among the images in the support set. g is a bidirectional LSTM model.
- The query embedding function f takes the query image and runs an LSTM on the feature vector with read attention over the support set S.
- Now we have two separate feature extractors for the two different purposes, so in this case f ≠ g.
The support embedding g(x_i, S) encodes x_i with S as the context, using a bidirectional LSTM.
The query embedding f(x̂, S) encodes x̂ using an LSTM with read attention over S.
- First, the query image goes through a regular CNN feature extractor f′(x̂) to get the basic features.
- Then, an LSTM with a read-attention vector over the support set as part of its hidden state is run for K processing steps:

ĥ_k, c_k = LSTM( f′(x̂), [h_{k−1}, r_{k−1}], c_{k−1} )
h_k = ĥ_k + f′(x̂)
r_{k−1} = Σ_i a(h_{k−1}, g(x_i)) · g(x_i), where a(h_{k−1}, g(x_i)) is a softmax over h_{k−1}ᵀ g(x_i)

At the end of the K processing steps, the query embedding is f(x̂, S) = h_K.
Prototypical Networks
Prototypical Networks were proposed by Snell, Swersky & Zemel, 2017.
- Firstly, an embedding function f_θ encodes each input into an M-dimensional feature vector.
- Then, for every class c in the support set S, a prototype vector v_c is computed as the mean of the embeddings of the support images belonging to class c; it acts as a general representation of that class.
- Finally, a simple distance function is used to classify a query image against the class prototypes.
Prototypical Networks – Prototype :-

Prototypical Networks – Working :-
- The prototype v_c is calculated as a simple mean of the embeddings of the support-set images belonging to that class:

v_c = (1 / |S_c|) Σ_{(x_i, y_i) ∈ S_c} f_θ(x_i)

where S_c is the subset of the support set S with label c.
- The probability distribution over the classes given a query input x is defined as a softmax over the negative distances to the prototypes:

P(y = c | x) = exp( −d(f_θ(x), v_c) ) / Σ_{c′} exp( −d(f_θ(x), v_{c′}) )

where d is any distance function. The authors use squared Euclidean distance, which they found empirically better than cosine distance in this case. A minimal sketch of this classification step follows.
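A minimal sketch of the prototype computation and classification, assuming the embeddings have already been produced by f_θ (function name, shapes and arguments are illustrative):

```python
import torch

def prototypical_predict(support_embs, support_labels, query_embs, n_way):
    """support_embs: (N, D), support_labels: (N,), query_embs: (Q, D)."""
    # Prototype v_c: mean embedding of the support images belonging to class c
    prototypes = torch.stack([support_embs[support_labels == c].mean(dim=0)
                              for c in range(n_way)])              # (n_way, D)
    sq_dists = torch.cdist(query_embs, prototypes) ** 2            # squared Euclidean distances
    return torch.softmax(-sq_dists, dim=1)                         # P(y = c | x) for each query
```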
Prototypical Networks – Loss :-
The loss function used to train this network is the Negative Log-Likelihood (NLL) loss, in other words the negative logarithm of the softmax probability of the correct class.
Let’s say we have three classes (Cat, Dog and Horse) and we test our model on one image of each class; then the NLL loss for the correct class is as follows.
NLL(x, y) = −log p_θ(y | x)
This loss function penalizes the model heavily when it assigns a low probability to the correct class and rewards it when that probability is high, as shown in the small example below.
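A tiny worked example with assumed probabilities for the correct class of each test image:

```python
import math

# Assumed probabilities the model assigns to the correct class for one image each
for label, p_correct in [("Cat", 0.9), ("Dog", 0.6), ("Horse", 0.1)]:
    print(label, round(-math.log(p_correct), 3))
# Cat 0.105, Dog 0.511, Horse 2.303 -> low probability on the correct class gives a large loss
```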
Benchmark Results
Omniglot Dataset :-
The Omniglot dataset is like the MNIST of few-shot classification. It contains 1623 different characters from 50 different alphabets (e.g. Sanskrit, Latin, Armenian).

Applying Metric-based Meta Learning models to the Omniglot dataset :-

*Results with no fine-tuning
Summary
- Meta-Learning deals with creating models which either learn and optimize fast or models which generalize and adapt to different tasks easily.
- Few-Shot classification frames image classification as a multiple-task learning problem.
- Metric Learning is about learning to accurately measure similarity in a given support set and generalize to other datasets.
- Siamese Network is an image verification model. Relation Network is similar to Siamese but learns the distance mapping instead of using a predefined function like cosine.
- Matching Networks learn to match the query image with the correct image in the support set using distance-based attention.
- Prototypical Networks measure distance between the input image and a general class representation called the prototype.
References :-
- Siamese Network: Koch, Zemel & Salakhutdinov (2015), Siamese Neural Networks for One-shot Image Recognition – http://www.cs.toronto.edu/~rsalakhu/papers/oneshot1.pdf
- Matching Network: Vinyals et al. (2016), Matching Networks for One Shot Learning – https://arxiv.org/abs/1606.04080
- Relation Network: Sung et al. (2018), Learning to Compare: Relation Network for Few-Shot Learning – https://arxiv.org/abs/1711.06025
- Prototypical Networks: Snell, Swersky & Zemel (2017), Prototypical Networks for Few-shot Learning – https://arxiv.org/abs/1703.05175
- Lilian Weng, Meta-Learning: Learning to Learn Fast – https://lilianweng.github.io/lil-log/2018/11/30/meta-learning.html