Model-based Meta Learning
- July 31, 2020
- Posted by: Bhavesh Laddagiri
- Category: Artificial General Intelligence (AGI)
#CellStratAILab #disrupt4.0 #WeCreateAISuperstars #WhereLearningNeverStops
Recently I presented Part 2 of the Meta Learning Workshop series at the CellStrat AI Lab. The topic was Model-based Meta Learning.
A Quick Recap
- Meta-Learning deals with creating models which either learn and optimize fast or models which generalize and adapt to different tasks easily.
- Few-Shot classification frames image classification as a multi-task learning problem, with the goal of classifying new images from only a few labelled examples per class.
- Metric Learning is about learning to accurately measure the similarity between support-set images and query images.
- We discussed Siamese Networks, Relation Networks, Matching Networks and Prototypical Networks
- We also implemented Prototypical Networks in PyTorch
Model-based Meta Learning
Model-based techniques are entirely different from metric-based techniques: they make no assumptions about the form of P_θ(y | x_i, S).
Instead, the focus is on making parameter updates and adaptation to new data faster, using specific model architectures designed to generalize.
In this article, we will look at models which are better at generalization and to some extent, reasoning too. We will specifically dive into Memory Augmented Neural Networks.
Memory Augmented Neural Networks (MANN) :-
A computer uses three fundamental mechanisms to operate – arithmetic operations, logical flow control and external memory. It has long-term storage as well as short-term storage in the form of working memory (RAM). As humans, we too use memory to retrieve facts from past experiences and combine different facts for inference and reasoning.
But the machine learning community has largely neglected the use of external memory for solving problems. (Note: RNNs have internal memory, which is different.)
Making use of external memory gives models extra adaptability to new tasks: they can simply accumulate new information in their working memory.
MANNs are a class of models which learn to do just that, using attention-based mechanisms to operate on their memory and infer from it.
Why MANN ?
RNNs are good at sequence-related tasks and at using previous information for inference, but their capabilities are limited when it comes to remembering long-term relations over large amounts of data (e.g. an entire Wikipedia article).
Transformers have made leaps in handling long sequences and generalize well to other tasks, but we all know that they have been pre-trained on huge datasets with a lot of computation power. So comparing MANNs to Transformers is not fair.
On the other hand, MANNs don’t need to remember large amounts of data, they just need to learn to operate on an external memory and infer from it. Usually, this is achieved by using attention-based mechanisms.
There are two types of MANNs :-
- Memory Networks (MemNN) (not widely considered as MANN)
- Neural Turing Machines (NTM)
Usually when people talk about MANNs, they refer to NTMs, as these have a more explicit external memory.
Memory Networks (MemNN) :-
A memory network can learn to perform inference using a long-term memory component where information is read and written. The basic components of a Memory Network are:-
- Memory m, which is an indexed matrix of objects (numbers, strings, words, sentences)
- Input feature map I, which is responsible for converting the inputs to an internal feature vector
- Generalization G, for updating old memories given a new input and adapting to new tasks
- Output feature map O, which predicts an output (as a feature vector) given the input and the current memory state
- Response R, which just transforms the feature vector to a human readable format (like words)
General Architecture :-
The I, G, O and R components can be anything from the ML literature: bag-of-words, embedding layers, neural networks, attention mechanisms, etc.
A generalized structure of the components is as follows:
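In the original Memory Networks paper (linked in the references below), the flow of an input x through the four components is:
- Convert the input to an internal feature representation: I(x)
- Update the memories given the new input: m_i = G(m_i, I(x), m), for all i
- Compute the output features from the input and the current memory: o = O(I(x), m)
- Decode the output features into the final response: r = R(o)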
End to end Memory Networks for Text :-
Memory Networks whose components are neural networks can be applied to question answering, where the input and output are textual, with a sense of recurrence across hops.
A few advantages of this approach are that the model is not very deep, easy and fast to train, and simple in design.
We will understand Memory Networks by applying them to the bAbI dataset from Facebook. This dataset contains a set of stories, each followed by a question and a one-word answer. These tasks range from answering questions about where a person/object is, to questions about the size or position of objects, and even path finding.
The (20) bAbI QA tasks :-
The aim of this dataset is to test the textual understanding and reasoning abilities of models. The dataset contains 20 different types of QA tasks which require a sense of reasoning and inference.
Basic Factoid QA :-
Let’s understand how we can use memory networks to solve tasks in the bAbI dataset. We will start with the single-supporting-fact task and work our way up. Let’s take an example problem and solve it.
Step 1: The I component – Convert each input sentence to a feature vector.
The same embedding layer can be applied to encode questions.
Step 2: The G component – Put all the sentence feature vectors into memory in the form of a matrix.
Step 3: The O_a component – Assign a weight to each sentence vector based on the question vector, i.e. take the dot product of each sentence vector with the question vector (the dot product measures similarity) and apply a softmax over the scores.
Step 4: The O_b & R component – Dot the generated weights with the sentence vectors, pass the resulting vector through a dense layer with softmax activation, and obtain the predicted probability of each word in the vocabulary.
This block is also called a hop.
Overall Structure :-
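As a rough sketch, here is what this single-hop structure could look like in PyTorch. The bag-of-words sentence encoding, layer names and dimensions below are my own simplifications, not the exact workshop code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHopMemoryNetwork(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        # I component: shared embedding for sentences and the question
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # R component: map the attended memory back to the vocabulary
        self.out = nn.Linear(embed_dim, vocab_size)

    def forward(self, story, question):
        # story:    (num_sentences, sentence_len) word indices
        # question: (question_len,) word indices
        # I + G: encode each sentence (bag-of-words sum) and stack into memory
        memory = self.embed(story).sum(dim=1)      # (num_sentences, embed_dim)
        q = self.embed(question).sum(dim=0)        # (embed_dim,)
        # O_a: dot product of question with each memory slot, then softmax
        weights = F.softmax(memory @ q, dim=0)     # (num_sentences,)
        # O_b: weighted sum of the sentence vectors
        o = weights @ memory                       # (embed_dim,)
        # R: dense layer + softmax gives a probability for each word
        return F.softmax(self.out(o), dim=0)
```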
But this architecture will only work for a single supporting fact because :
- Softmax only supports selecting 1 thing!
- For multiple supporting facts order is also important. Example –
- “John goes to the kitchen.
- John goes to the hallway.”
- “Where is John?” ⇐ Answer depends on which happened first.
To solve this for two supporting facts, we need two hops (O blocks) for getting information from two sentences (in other words two locations of memory).
In the next section, we will see what this means and how it resembles recurrence.
Two Hop structure :-
- In the first hop, we generate weights using the question embedding and dot them with the sentences to get an output vector S_x.
- Now, we use this vector S_x to generate the new weights instead of the question vector in the second hop.
- This gives us S_y which is now passed through the dense layer to get the prediction.
- This way, we can make use of multiple hops to access multiple sentences (memory locations), as in the sketch below.
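Continuing the hypothetical single-hop sketch above, the second hop simply reuses the output of the first hop in place of the question vector:

```python
# Hop 1: attend over memory with the question vector
w1 = F.softmax(memory @ q, dim=0)
s_x = w1 @ memory                      # output vector of hop 1
# Hop 2: attend again, but using s_x instead of the question
w2 = F.softmax(memory @ s_x, dim=0)
s_y = w2 @ memory                      # output vector of hop 2
# Final prediction from the second hop's output
answer_probs = F.softmax(self.out(s_y), dim=0)
```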
Further Improvements :-
- “Curriculum learning” over all the tasks in one training session, so that a single model generalizes across them.
- Randomly replacing known words with an unknown-word token during training, to teach the model to handle out-of-dictionary words at test time. Sort of like dropout.
- Add more hops based on the story and question.
- Refer to newer advancements in memory networks, e.g. the paper Self-Attentive Associative Memory https://arxiv.org/abs/2002.03519.
- Be creative and apply techniques from other domains to augment this technique.
Neural Turing Machine (NTM)
A Neural Turing Machine is a neural network coupled with an external memory with which it interacts using attention mechanisms. It was introduced by Graves et al. at DeepMind in 2014.
It is analogous to a Turing machine or von Neumann architecture, except that it is end-to-end differentiable, allowing it to be efficiently trained by gradient descent.
Due to the usage of an external memory, they are good at retaining long sequences and can generalize quite well.
Vanilla NTMs have been tested to learn algorithms like copy, copy-repeat, associative recall and sorting using supervised methods.
Successors like the Differentiable Neural Computer (DNC) can perform more sophisticated general tasks like route planning, answering logical questions etc.
- The NTM Architecture contains two basic components – Controller and Memory.
- The controller is a neural network which interacts with the external world with standard input and output vectors.
- But additionally, it also interacts with the memory bank using selective read and write operations. The network output neurons responsible for memory interaction are called heads (by analogy with the read/write heads of a Turing machine).
- The heads' interaction with memory is very sparse and uses focused attention mechanisms.
- The memory is defined by a matrix of size N×M where, N is the number of memory locations and M is the length of the vector at each location.
- Let M_t be the contents of the memory at time t and w_t be the vector of weights over the N locations. All weights are normalized and lie between 0 and 1 and sum up to 1.
- The Read Head returns a read vector r_t of length M, defined as a weighted combination of the row vectors M_t(i):
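Concretely, following the NTM paper:
r_t = Σ_i w_t(i) · M_t(i)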
The write operation is decomposed into two parts – erase and add.
- Erase: Given a weighting w_t emitted by the write head at time t, along with an erase vector e_t (of length M), the memory contents M_(t-1) from the previous time step are modified element-wise (see the equations after this list).
- If both the weighting w_t(i) and the erase vector are 1, the memory at that location is reset to zero. If either of them is 0, the previous contents remain unchanged.
- Add: Each write head also produces an add vector a_t which is used to perform the changes to the Memory after the erase step.
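Following the NTM paper, the two write steps can be written as:
- Erase: M̃_t(i) = M_(t-1)(i) ⊙ [1 − w_t(i) e_t], where ⊙ is element-wise multiplication
- Add: M_t(i) = M̃_t(i) + w_t(i) a_t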
Now, we know how the read and write operations are performed by the heads, but how are these weights and parameters produced in the first place? In the next section, we will understand the addressing mechanisms which generate these parameters.
There are two types of addressing in NTM – content-based and location-based.
Content-based addressing works by focusing attention on memory locations whose values are similar to a key emitted by the controller (like the Memory Network, where the input question is dotted with each sentence in memory to measure similarity).
However, using similarity to address the memory is not always optimal when we want to access variables which are important but not similar. For example, if the model is performing an arithmetic task like x×y, it needs to access both x and y, which sit at different locations and are not necessarily similar to each other.
To handle such cases, location-based addressing diverts its focus to other locations in the memory.
In NTM, both addressing mechanisms are used concurrently but the impact of each mechanism is decided by other parameters.
Overview of the Addressing Mechanism :-
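At every timestep the controller emits a key k_t, a key strength β_t, an interpolation gate g_t, a shift weighting s_t and a sharpening factor γ_t. The final weighting over memory is then produced by a four-stage pipeline:
content addressing (k_t, β_t) → w_t^c → interpolation with w_(t-1) (g_t) → w_t^g → convolutional shift (s_t) → w̃_t → sharpening (γ_t) → w_t
The resulting w_t is what the read and write heads use in the operations described above.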
Content Addressing :-
- For content-based addressing, the controller first generates a key vector k_t of length M, which is compared with every vector M_t(i) using a similarity measure K[·,·], which is cosine similarity in this case.
- The content-based addressing produces a content weighting w_t^c by applying a softmax on the product of the similarity and a positive key strength β_t.
As the key strength β_t increases, the focus of the head sharpens towards the most similar location in the memory.
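Putting it together, as in the NTM paper:
w_t^c(i) = exp(β_t · K[k_t, M_t(i)]) / Σ_j exp(β_t · K[k_t, M_t(j)]), where K[u, v] = (u · v) / (‖u‖ ‖v‖)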
Location Addressing :-
- Location-based addressing is designed to allow the head to move to a different location from the one given by content addressing.
- This is done by adding a rotational shift to the weights. For example, if the head is currently focusing on a single location, a rotation of 1 would shift the focus to the next location, while -1 would shift it in the opposite direction, and so forth.
- Location-based addressing is split into three steps – Interpolation, Convolutional Shift (Rotation) and Sharpening. In the next few sections, we will look at each step in detail.
Note: Content-based addressing is more general, because if we store location information also in the memory then separate location addressing is not required. But having a separate operation for location addressing proved to be useful in some cases.
Interpolation :-
- Before we can start the rotation, a scalar interpolation gate g_t (in the range (0, 1)) is used to control how much we want to use the content-addressed weights.
- It allows us to blend the weights of the previous time step w_(t-1) with the weights w_t^c produced by content addressing. This yields a gated weighting w_t^g.
- If g_t is 0, the content weighting w_t^c is ignored and the weights from the previous timestep w_(t-1) are used. Similarly, if g_t is 1, the previous timestep's weights are ignored and only the content weighting is applied.
- This interpolation helps us account for information from the past as well.
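The interpolation step is simply:
w_t^g = g_t · w_t^c + (1 − g_t) · w_(t-1)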
Convolution Shift and Sharpen :-
After interpolation, the head emits a shift weighting s_t which is used as a 1D convolution filter over the weights. In other words, this allows us to spread our focus from a single point to a range of points.
The length of s_t is in general 2n+1, where n is the maximum shift value. For example, if shifts by one position are allowed (n=1), then s_t corresponds to the degrees of shift -1, 0 and +1, possibly weighted as [0.3, 0.6, 0.1].
The general convolution operation takes place as follows,
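w̃_t(i) = Σ_(j=0)^(N-1) w_t^g(j) · s_t(i − j)
where the index (i − j) is taken modulo N, which is what makes the shift circular.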
Location Addressing – Convolution :-
- In order to apply the convolution filter, we need to do some form of padding by appending the last n* (n=1) elements to the front and the first n* (n=1) elements to the end, making the weight vector circular.
- By circular we mean that if we go one step past the last element (value 0.9 in the example), we reach the first element (value 0.01).
*n = the max shift value allowed
- Convolve s_t (our 1D filter) over the weights w_t^g: take a dot product at each position and then move forward one step.
- After the convolution, the initial focus in w_t^g, which was on the last element (0.9), is shifted in w̃_t to the first element (0.72).
Sometimes it's possible that the shifting makes the focus more diffused and less sharp. For example, if our 1D filter s_t is [0.1, 0.8, 0.1], there won't be any shift in focus and the weighting will just get a bit dispersed or blurred.
To combat that, a sharpening operation is applied, with γ_t ≥ 1:
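w_t(i) = w̃_t(i)^(γ_t) / Σ_j w̃_t(j)^(γ_t)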
The greater the γ_t the more focused the weights become.
The combined addressing system of content, interpolation and location can be applied in three primary ways –
- The location system can be turned off i.e. no shifting and only content system is used.
- The weight produced by content system can be chosen and shifted
- The weight from previous timestep can be chosen and shifted without any modification by content system
Controller Neural Network :-
The controller neural network is responsible for producing all these parameters for addressing, reading and writing. It can be either a normal feedforward network or an LSTM.
By design, this network outputs a vector and the addressing parameters are extracted from that vector using a linear layer for each parameter.
An LSTM controller has its own internal memory too, which gives it an added advantage over a feedforward network, as it can carry information across multiple timesteps.
On the other hand, a feedforward controller is more transparent: it is easier to look under the hood and interpret how it interacts with memory, which is difficult with an LSTM, whose internal state is hard to interpret.
However, a feedforward controller hits a bottleneck in the number of concurrent read and write operations it can perform.
Experiment Results – Copy Task :-
In the copy task, the network is presented with a sequence of random binary vectors as input, and the target is the same sequence.
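A minimal sketch of how such training examples could be generated. The extra delimiter channel and the dimensions below are my assumptions, not necessarily the exact setup used in the paper:

```python
import torch

def make_copy_example(seq_len, vec_dim):
    # Random binary vectors; one extra channel is reserved for a delimiter flag
    seq = torch.randint(0, 2, (seq_len, vec_dim)).float()
    inp = torch.zeros(seq_len + 1, vec_dim + 1)
    inp[:seq_len, :vec_dim] = seq
    inp[seq_len, vec_dim] = 1.0   # delimiter marking the end of the input
    return inp, seq               # target: reproduce the original sequence

x, y = make_copy_example(seq_len=8, vec_dim=8)
```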
Further Reading :-
For complete details of the experiments performed with the NTM, refer to the paper Neural Turing Machines https://arxiv.org/abs/1410.5401
One-Shot Learning with MANN https://arxiv.org/abs/1605.06065. This paper extends the NTM to few-shot classification (e.g. Omniglot); read it if you want to apply MANNs to few-shot problems.
NTM Implementation in PyTorch :-
We have 4 primary torch.nn.Module classes – (1) Memory (2) Controller (3) Head (4) NTM.
The functions defined in memory are –
- Read (weights)
- Write (weights, add-vector, erase-vector)
- Content Addressing (key, beta)
- Reset – for initializing the memory matrix
The functions defined in the controller are –
- Define the LSTM Cell and Output FC Layer
- Forward (x input, previous read vectors)
- Returns the hidden and cell states as output
- Output (read vectors)
- Returns the external output sequence (aka predicted sequence)
- Reset – for initializing the states of LSTM Cell
The functions defined in the head are –
- Define the linear layers for taking controller state as input and outputting the parameters like key vector, beta, gamma etc.
- Extract the NTM parameters from controller state
- Read or Write
- Reset – for initializing the learnable parameters of the linear layers
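To make the breakdown above more concrete, here is a condensed sketch of how the Memory module might look in PyTorch. Shapes and names are illustrative, not the exact workshop code:

```python
import torch
import torch.nn.functional as F

class Memory(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        self.N, self.M = N, M  # N memory locations, each a vector of length M
        self.register_buffer("mem_bias", torch.full((N, M), 1e-6))
        self.reset(batch_size=1)

    def reset(self, batch_size):
        # Initialise the memory matrix at the start of every sequence
        self.memory = self.mem_bias.unsqueeze(0).repeat(batch_size, 1, 1)

    def read(self, w):
        # r_t = sum_i w_t(i) * M_t(i);  w: (batch, N), memory: (batch, N, M)
        return torch.bmm(w.unsqueeze(1), self.memory).squeeze(1)

    def write(self, w, e, a):
        # Erase then add, as in the NTM paper; e, a: (batch, M)
        erase = torch.bmm(w.unsqueeze(-1), e.unsqueeze(1))  # (batch, N, M)
        add = torch.bmm(w.unsqueeze(-1), a.unsqueeze(1))    # (batch, N, M)
        self.memory = self.memory * (1 - erase) + add

    def content_addressing(self, k, beta):
        # Cosine similarity of the key with every row, scaled by beta: (batch, 1)
        k = k.unsqueeze(1).expand_as(self.memory)            # (batch, N, M)
        sim = F.cosine_similarity(self.memory, k, dim=-1)    # (batch, N)
        return F.softmax(beta * sim, dim=-1)                 # w_t^c
```

The Controller, Head and NTM modules would wrap this in the same spirit: the controller produces a state vector, each head's linear layers slice the addressing parameters (k_t, β_t, g_t, s_t, γ_t, plus e_t and a_t for write heads) out of that state, and the NTM module ties everything together over the input sequence.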
Summary :-
- Memory Augmented Neural Networks (MANN) use external memory coupled with attention to generalize and adapt to different tasks.
- Memory Networks use attention over the supporting information (e.g. stories) to infer on the input question.
- Multiple hops in Memory Networks let the model consider multiple sentences (memory locations) for inference, and the mechanism resembles recurrence.
- A Neural Turing Machine has a controller neural network which learns to read from and write to a memory matrix using selective attention in order to solve a task.
- NTMs tend to learn internal algorithms to solve the problem and thus generalize well.
References :-
- Memory Networks https://arxiv.org/abs/1410.3916
- End to End Memory Networks https://arxiv.org/abs/1503.08895
- bAbI https://research.fb.com/downloads/babi/
- Neural Turing Machine https://arxiv.org/abs/1410.5401
- One-Shot Learning with Memory-Augmented Neural Networks https://arxiv.org/abs/1605.06065
- Differentiable Neural Computers https://deepmind.com/blog/article/differentiable-neural-computers