Optimization-Based Meta Learning
- August 6, 2020
- Posted by: Bhavesh Laddagiri
- Category: Artificial General Intelligence (AGI)
Recently I presented a session on Optimization-based Meta Learning, Part 3 of the Meta Learning Series, at the CellStrat AI Lab. The previous parts can be found here – Part 1 (Metric-based Meta Learning) and Part 2 (Model-based Meta Learning).
Meta Learning, of course, refers to “Learning to Learn”.
Recap of Part 1 (Metric-based Meta Learning) :-
- Meta-Learning deals with creating models that either learn and optimize quickly or generalize and adapt easily to different tasks.
- Few-Shot classification frames image classification as a multi-task learning problem, with the goal of efficiently finding similar images.
- Metric Learning is about learning to accurately measure the similarity between support set images and query images.
- We discussed Siamese Networks, Relation Networks, Matching Networks and Prototypical Networks.
- We also implemented Prototypical Networks in PyTorch.
Recap of Part 2 (Model-based Meta Learning) :-
- Model-based Meta-Learning aims at designing architectures that are inherently capable of generalizing and learning from less data.
- Memory Augmented Neural Networks (MANNs) are neural networks that couple an external memory with attention to solve a problem.
- Memory Networks use scaled dot-product attention over the incoming data for inference.
- Neural Turing Machines have an external memory matrix on which a controller neural network performs read and write operations using attention mechanisms.
- We also implemented NTMs in PyTorch to solve simple algorithmic tasks.
Optimization-based Meta Learning
Currently, most deep learning models learn by backpropagating gradients and then performing gradient descent. But this method is not designed for learning from little data, and it needs many iterations to reach a minimum.
Optimization-based Meta-Learning aims to design algorithms that modify the training procedure so that models can learn from little data in just a few training steps.
Usually, this means learning a parameter initialization that can be fine-tuned with a few gradient updates. Some examples of such algorithms are –
- LSTM Meta-Learner
- Model-Agnostic Meta-Learning (MAML)
- Reptile
Transfer Learning vs Optimization-based Meta Learning :-
For transfer learning to work, we first pretrain a model on a very large dataset. The pretrained weights are then used as the initial parameters for fine-tuning on a new, medium-sized dataset.
However, transfer learning is not designed specifically for learning with less data. It just happens to work well on medium-sized datasets because of the pre-learned features.
Optimization-based meta learning aims to find a set of initial parameters that generalizes to a wide range of problems, so that for a new problem only a few gradient updates on a small dataset are needed for fine-tuning.
These algorithms are explicitly designed and trained to find initial parameters that can later be fine-tuned in a couple of training steps.
Model-Agnostic Meta-Learning (MAML) :-
It was introduced by Finn et al. (2017).
It is a meta-learning algorithm that explicitly learns initial parameters which can be adapted to a new task, drawn from the same distribution of tasks, by fine-tuning with one or a few gradient steps.
It is model-agnostic, i.e. it works with any model that is trained with gradient descent. The authors tested it on few-shot classification, regression and even reinforcement learning to demonstrate its flexibility.
MAML Algorithm :-
During meta-training, we optimize the initial parameters θ so that they lie close to the optimal parameters of all related tasks.
Then a few steps (usually one) of fine-tuning on a specific task should be enough to adapt to it well without overfitting.
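One meta-training iteration can be summarized with the two update rules of Finn et al. (2017), written here with this post's support/query naming as an illustration: α is the inner (fine-tuning) learning rate, β the meta learning rate, S_i and Q_i the support and query sets of task T_i, f_θ the model with parameters θ, and θ_i′ the task's fast weights. The paper allows several inner steps; a single step is shown.

```latex
\begin{aligned}
\theta_i' &= \theta - \alpha \,\nabla_\theta \mathcal{L}_{S_i}\!\left(f_\theta\right)
  && \text{inner loop: fine-tune on the support set of } T_i \\
\theta &\leftarrow \theta - \beta \,\nabla_\theta \sum_{T_i \sim p(T)} \mathcal{L}_{Q_i}\!\left(f_{\theta_i'}\right)
  && \text{outer loop: meta-update from the query-set losses}
\end{aligned}
```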
Defining the Problem :-
- We will apply MAML to a few-shot image classification problem. The dataset is split into multiple tasks.
- Each task is sampled as an N-way K-shot classification problem: K images for each of the N classes make up the Support Set (train).
- The Query Set (test) also contains a few images for each of the same N classes.
- The model is a vanilla CNN-based image classifier.
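The task structure described above can be sketched in plain Python. This is a minimal illustration, assuming the dataset is a dict mapping each class name to a list of examples (strings stand in for images here); the helper `sample_task` and its `q_queries` parameter (query examples per class) are hypothetical names, not from the original post. Labels are re-indexed to 0..N-1 inside each task, so the classifier's N outputs refer to different classes in different tasks.

```python
import random
from typing import Dict, List, Tuple, TypeVar

T = TypeVar("T")


def sample_task(
    dataset: Dict[str, List[T]],
    n_way: int,
    k_shot: int,
    q_queries: int,
    rng: random.Random,
) -> Tuple[List[Tuple[T, int]], List[Tuple[T, int]]]:
    """Hypothetical helper: sample one N-way K-shot task as (support, query).

    Each item is (example, task_label) with task_label in 0..n_way-1.
    """
    # Pick N distinct classes for this task.
    classes = rng.sample(sorted(dataset), n_way)
    support: List[Tuple[T, int]] = []
    query: List[Tuple[T, int]] = []
    for label, cls in enumerate(classes):
        # Draw K support and q_queries query examples with no overlap.
        examples = rng.sample(dataset[cls], k_shot + q_queries)
        support += [(x, label) for x in examples[:k_shot]]
        query += [(x, label) for x in examples[k_shot:]]
    return support, query


# Usage: a toy "dataset" of 10 classes with 20 placeholder images each.
rng = random.Random(0)
toy = {f"class{c}": [f"img{c}_{i}" for i in range(20)] for c in range(10)}
support, query = sample_task(toy, n_way=5, k_shot=1, q_queries=3, rng=rng)
print(len(support), len(query))  # 5 15
```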
In the last step, the meta-loss gradient is calculated with respect to the initial parameters θ, so backpropagation goes through the entire computation graph, including the task-specific fine-tuning step where we had already computed the gradient of the support set loss L_{S_i}.
This means taking the gradient of a gradient, which introduces second-order derivatives (computed in practice as Hessian-vector products).
Intuitively, this moves the initial parameters θ to a point from which every task can be fine-tuned well with a single gradient update.
In short, the meta-update of θ backpropagates through the inner-loop computation of the fast weights θ_i′, which is what leads to second-order derivatives.
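For a single task, writing the losses directly as functions of the parameters and using the one-step inner update θ_i′ = θ − α∇_θ L_{S_i}(θ), the chain rule makes the second-order term explicit:

```latex
\nabla_\theta \mathcal{L}_{Q_i}(\theta_i')
  = \left(\frac{\partial \theta_i'}{\partial \theta}\right)^{\!\top} \nabla_{\theta_i'} \mathcal{L}_{Q_i}(\theta_i')
  = \bigl(I - \alpha\, \nabla^2_\theta \mathcal{L}_{S_i}(\theta)\bigr)\, \nabla_{\theta_i'} \mathcal{L}_{Q_i}(\theta_i')
```

The Hessian ∇²_θ L_{S_i}(θ) is symmetric and only ever appears multiplied by a vector, so autodiff frameworks can evaluate this as a Hessian-vector product without building the full matrix. Replacing the bracketed factor with the identity I gives the first-order approximation.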
First-Order MAML :-
Since computing second-order derivatives is expensive, we can ignore them: we compute the gradient of the meta-loss with respect to the fast weights θ_i′ and use it directly as the meta-update on θ.
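A minimal plain-Python sketch contrasting the exact MAML meta-gradient with the first-order approximation. To keep the second-order term exact without an autodiff library, it assumes a toy setting that is not from the original post: each task is 1-D linear regression y = a·x with a task-specific slope a, the model is f_θ(x) = θ·x with a single scalar parameter, and the loss is mean squared error, so the gradient and the (constant) second derivative have closed forms. All function names and hyperparameters are hypothetical.

```python
import random
from typing import List, Tuple

Data = List[Tuple[float, float]]  # a task's samples as (x, y) pairs


def grad(theta: float, data: Data) -> float:
    """dL/dθ for L(θ) = mean((θx - y)^2), the MSE of the scalar model f_θ(x) = θx."""
    return sum(2 * x * (theta * x - y) for x, y in data) / len(data)


def hess(data: Data) -> float:
    """d²L/dθ² for the same loss; it does not depend on θ for this model."""
    return sum(2 * x * x for x, _ in data) / len(data)


def meta_gradient(theta: float, support: Data, query: Data,
                  alpha: float, first_order: bool) -> float:
    """Gradient of one task's query loss w.r.t. the initial parameter θ."""
    fast = theta - alpha * grad(theta, support)   # inner loop: one gradient step
    g_query = grad(fast, query)                   # query-loss gradient at fast weights
    if first_order:
        return g_query                            # FOMAML: Hessian term dropped
    return (1 - alpha * hess(support)) * g_query  # MAML: chain rule through inner step


def make_task(rng: random.Random, n: int = 10) -> Tuple[Data, Data]:
    """Hypothetical task sampler: slope a ~ U(-2, 2), n support + n query points."""
    a = rng.uniform(-2.0, 2.0)
    pts = [(x, a * x) for x in (rng.uniform(-1.0, 1.0) for _ in range(2 * n))]
    return pts[:n], pts[n:]


def train(first_order: bool, steps: int = 2000, alpha: float = 0.5,
          beta: float = 0.05, batch: int = 5, seed: int = 0) -> float:
    """Meta-train the scalar initialization θ and return it."""
    rng = random.Random(seed)
    theta = 3.0  # deliberately poor starting point
    for _ in range(steps):
        tasks = [make_task(rng) for _ in range(batch)]
        g = sum(meta_gradient(theta, s, q, alpha, first_order) for s, q in tasks)
        theta -= beta * g / batch  # outer loop: meta-update of θ
    return theta


print("MAML   θ:", round(train(first_order=False), 3))
print("FOMAML θ:", round(train(first_order=True), 3))
```

Because the toy task slopes are centered on zero, both variants should pull θ from 3.0 towards roughly 0; the exact printed values depend on the seed. The only difference between the two branches of `meta_gradient` is the factor (1 − α·h), which is the scalar form of the Hessian term MAML keeps and FOMAML drops.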
Applications of MAML :-
- Few-shot image classification
- Reinforcement learning
- Any deep learning model trained with gradient-based optimizers (Adam, SGD, etc.)
Reptile :-
Reptile was introduced by Nichol et al. (2018) at OpenAI. It is an extension of First-Order MAML and is closely related to it.
Reptile works by simply sampling a task, performing k steps of SGD on it, and then moving the initial parameters towards the resulting fast weights of that task. Repeating this over many tasks yields a good set of initial parameters.
But one might ask: isn’t this just transfer learning done many times? In the coming sections, we will see why it works.
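A minimal Reptile sketch on a toy problem, an assumption for illustration rather than the post's setup: 1-D linear regression tasks y = a·x with a scalar model f_θ(x) = θ·x. For each sampled task it takes k gradient steps from the current initialization (full-batch here, where the paper uses SGD on minibatches), then moves θ a fraction ε of the way towards the resulting fast weights, i.e. θ ← θ + ε(θ̃ − θ), following the serial form of Reptile in Nichol et al. (2018). Names and hyperparameter values are hypothetical.

```python
import random
from typing import List, Tuple

Data = List[Tuple[float, float]]  # a task's samples as (x, y) pairs


def grad(theta: float, data: Data) -> float:
    """dL/dθ for L(θ) = mean((θx - y)^2) with the scalar model f_θ(x) = θx."""
    return sum(2 * x * (theta * x - y) for x, y in data) / len(data)


def reptile(meta_iters: int = 3000, k: int = 5, inner_lr: float = 0.1,
            epsilon: float = 0.1, seed: int = 0) -> float:
    """Serial Reptile on toy regression tasks; returns the learned initialization θ."""
    rng = random.Random(seed)
    theta = 3.0  # deliberately poor starting point
    for _ in range(meta_iters):
        a = rng.uniform(-2.0, 2.0)  # sample a task (its true slope)
        data = [(x, a * x) for x in (rng.uniform(-1.0, 1.0) for _ in range(10))]
        fast = theta
        for _ in range(k):  # k gradient-descent steps on this task's data
            fast -= inner_lr * grad(fast, data)
        theta += epsilon * (fast - theta)  # move θ towards the task's fast weights
    return theta


print("Reptile θ:", round(reptile(), 3))
```

Note that no meta-loss gradient is ever computed: the difference θ̃ − θ itself acts as the update direction, which is why Reptile needs neither second-order derivatives nor a separate query set.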
MAML Summary :-
Optimization-based Meta Learning algorithms learn a general weight initialization that is easy to fine-tune on downstream tasks with a small number of gradient updates.
MAML is a model-agnostic algorithm that can be applied to any model trained with gradient descent. It uses second-order derivatives to update the initial model parameters, which each task then uses as its starting point for fine-tuning. Intuitively, this finds a set of parameters that lies close to the optimal parameters of the different tasks.
First-Order MAML (FOMAML) simply skips the second-order derivatives and updates the initial parameters using the meta-loss gradient taken with respect to the fast weights.
Like FOMAML, Reptile avoids second-order derivatives: it repeatedly performs SGD on a task and moves the initial parameters towards that task’s fine-tuned weights. Doing this on many tasks eventually yields initial parameters that generalize.
Where to use what :-
Metric-based Meta Learning – Use it for few-shot classification problems. The dataset should have enough classes to sample few-shot tasks during meta-training.
Model-based Meta Learning – Memory Networks are mainly used for NLP tasks, but they can be adapted to any sequence-based task. Memory Networks are shallow and compute-friendly.
Neural Turing Machines can also be extended to few-shot classification problems, and a vanilla NTM can be trained in a supervised way on algorithmic tasks.
Optimization-based Meta Learning – MAML can be used for few-shot classification or any similar problem where data is scarce. Because MAML is model-agnostic, it is flexible enough to fit most domains, but it requires more compute.
Meta learning is mostly used for few-shot classification-like tasks, and for these Reptile generally works well while using less compute and time than MAML.
If your problem is few-shot classification, choose Reptile; for any other domain, choose MAML, specifically its first-order variant.
References :-
- Finn et al. (2017), Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks – https://arxiv.org/abs/1703.03400
- Nichol et al. (2018), On First-Order Meta-Learning Algorithms (Reptile) – https://arxiv.org/abs/1803.02999