An Explanation of Attention Based Encoder-Decoder Deep Learning Networks
In this article, I will give you a detailed understanding of Seq2Seq models built with LSTMs, along with their advantages, limitations, and the theory behind the attention mechanism. Before we dive into the details of this increasingly popular technique, let's look at some applications where these models are typically used.
- Text Summarization - This is one of the most interesting applications of the encoder-decoder technique, and it has been gaining popularity ever since the architecture was introduced by Google researchers back in 2014.
- Machine Translation - Ever wondered how Google Translate works? Well, here you go! This is one of the main algorithms in this field, and it does a pretty good job too.
- Image/Video Captioning - Because of this very method, computers are capable of describing pictures. For example, a computer-generated image caption could be "A picture of a girl throwing a Frisbee" or "A dog eating a bone".
- Speech Recognition - One of the more recent developments has been in the area of audio, where raw audio inputs are fed to the network during training and the network learns to decode the spoken message into text.
- Music Generation - This is definitely a thing! Although still at a very nascent stage, some of the music pieces generated by A.I. are almost indistinguishable from those composed by humans.
- Recommendation Engines - Something one might never have suspected, but indeed, such architectures are deployed in real-world recommendation systems at a lot of companies.
- Chatbots - This could be a question-answer type chatbot or something much more powerful like Siri or Alexa, which might have been trained over millions of data instances and is capable of performing a wide spectrum of tasks.
While those are only a handful of the examples that use encoder-decoder seq2seq models, they all have something in common: they are all based on sequential data, or data that flows in the form of a sequence. Be it language, user interactions, or even music, these all occur in a sequential pattern, one element after the other, at what are known as 'timesteps'. What sets these problems apart from a plain Recurrent Neural Network variant such as an LSTM/GRU is that the input and output lengths are not fixed and can be of very different sizes.
So how does a Seq2Seq model work?
There are two components of a sequence-to-sequence model - the encoder and the decoder.
Encoder - The encoder reads the entire input sequence, one word per timestep, processes it, and captures the contextual information of the input sequence in what is known as a context vector or thought vector. This vector is expected to contain a good summary of the entire input sentence. Every cell in the LSTM layer returns a hidden state (h_i) and a cell state (c_i). The last hidden state and cell state are used to initialize the decoder, which is the second component of this architecture.
Decoder - Just like the encoder, the decoder reads the entire target sequence, offset by one timestep, along with the last hidden state and cell state of the encoder, and predicts the next word in the target sequence. We add <start> to the beginning of the target sequence, indicating the start of a sentence, and <end> to the end of the sequence, indicating the end of the sentence.
Here's how this looks in practice:
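To make this concrete, here is a minimal Keras sketch of the training set-up described above. The vocabulary sizes, embedding dimension, and latent dimension are illustrative placeholders, not values from any particular dataset:

```python
from tensorflow.keras.layers import Input, LSTM, Embedding, Dense
from tensorflow.keras.models import Model

# Illustrative hyperparameters (placeholders, not tuned values)
src_vocab_size, tgt_vocab_size = 10000, 8000
embedding_dim, latent_dim = 128, 256

# --- Encoder: reads the source sequence and returns its final states ---
encoder_inputs = Input(shape=(None,))
enc_emb = Embedding(src_vocab_size, embedding_dim)(encoder_inputs)
encoder_lstm = LSTM(latent_dim, return_state=True)
_, state_h, state_c = encoder_lstm(enc_emb)        # keep only h and c
encoder_states = [state_h, state_c]                # the "thought vector"

# --- Decoder: reads the target sequence offset by one timestep,
#     initialized with the encoder's final states ---
decoder_inputs = Input(shape=(None,))
dec_emb_layer = Embedding(tgt_vocab_size, embedding_dim)
dec_emb = dec_emb_layer(decoder_inputs)
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(dec_emb, initial_state=encoder_states)
decoder_dense = Dense(tgt_vocab_size, activation="softmax")
decoder_outputs = decoder_dense(decoder_outputs)

# Trained with teacher forcing: decoder inputs begin with <start>,
# targets end with <end>
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")
```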
Inference - After this seq2seq model is trained, we now have to set up the inference decoder to generate predictions on unseen data. There are a few steps involved here:
1. First, we encode the entire input sequence and pass on the internal states to initialize the decoder.
2. Next, along with these internal states, we feed the <start> token as input to the first timestep of the LSTM decoder.
3. Run the decoder for one timestep with the current internal states.
4. The output (y_1) will be the word with the highest probability.
5. Pass this word as input to the next timestep and update the internal states with those of the decoder.
6. Repeat steps 3-5 until the decoder predicts the <end> token that we used in the training phase.
Here is how that looks for a 3-timestep LSTM decoder:
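In code, the inference set-up could look roughly like this. This sketch reuses the layers and variables from the training sketch earlier, and the tgt_word_index / tgt_index_word dictionaries are hypothetical token lookups:

```python
import numpy as np
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

# Encoder model: maps the source sequence to its internal states (step 1)
encoder_model = Model(encoder_inputs, encoder_states)

# Decoder model: takes one token plus the previous states and returns
# next-word probabilities along with the updated states
dec_state_h = Input(shape=(latent_dim,))
dec_state_c = Input(shape=(latent_dim,))
dec_emb2 = dec_emb_layer(decoder_inputs)
dec_out2, h2, c2 = decoder_lstm(dec_emb2, initial_state=[dec_state_h, dec_state_c])
dec_out2 = decoder_dense(dec_out2)
decoder_model = Model([decoder_inputs, dec_state_h, dec_state_c],
                      [dec_out2, h2, c2])

def greedy_decode(input_seq, max_len=20):
    # Steps 1-2: encode the source, then feed <start> with those states
    h, c = encoder_model.predict(input_seq, verbose=0)
    target = np.array([[tgt_word_index["<start>"]]])  # hypothetical lookup
    decoded = []
    for _ in range(max_len):
        # Steps 3-4: run one decoder step and pick the most probable word
        probs, h, c = decoder_model.predict([target, h, c], verbose=0)
        idx = int(np.argmax(probs[0, -1, :]))
        word = tgt_index_word[idx]                    # hypothetical lookup
        if word == "<end>":                           # step 6: stop condition
            break
        decoded.append(word)
        target = np.array([[idx]])                    # step 5: feed back
    return " ".join(decoded)
```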
Limitations of Encoder-Decoder Networks
The encoder-decoder network works fine for shorter sequences, but its performance starts to deteriorate as the input sequence grows long. This is because it becomes difficult for the encoder to compress all the contextual information of a long sequence into a single fixed-size vector. Even if the target sequence is short, the decoder suffers as a consequence of the long input sequence, and the overall predictions can become inaccurate. This is where "attention" comes into the picture.
How does the attention mechanism work?
Attention focuses on the most important parts of the sequence instead of treating the entire sequence as a whole. Rather than building a single context vector out of the encoder's last hidden state, attention creates shortcuts between the entire input sequence and the context vector, and the weights of these shortcut connections are customized for each output element. As a result, the context vector learns the alignment between the source sequence and the target sequence. Essentially, the context vector consumes 3 pieces of information:
- Encoder hidden states
- Decoder hidden states
- Alignment between source and target
Here is a visual representation of how this works in practice for a bidirectional RNN encoder:
Fig. 4. The encoder-decoder model with additive attention mechanism in Bahdanau et al., 2015.
As you can see, the decoder RNN's next word prediction is based on the hidden state from the previous timestep plus the encoder context vector, which is dynamically calculated for each timestep. For example, consider the following source and target sentences from an English-to-French machine translation task:
- Source (English): "What does the cat eat?"
- Target (French): "Que mange le chat?"
The first word "What" in the input sequence is connected to the first word "Que" in the target sequence. In addition, "eat" is connected to "mange", and "chat" means cat in French. So instead of looking at the entire sequence, we are now paying attention only to the specific parts of it that drive the prediction of each word in the target sentence. Notice that the word "does" will not get much importance; as a result, its hidden state will receive a low weight in the context vector calculation compared to words such as "cat" and "eat". Here is another pictorial visualization of how the attention mechanism works:
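To complement the visuals, here is a toy numeric illustration of this weighting. The alignment scores below are invented purely for illustration, not taken from a trained model:

```python
import numpy as np

source_words = ["What", "does", "the", "cat", "eat", "?"]
# Hypothetical alignment scores for the target word "mange" (invented numbers)
scores = np.array([0.5, -1.0, -0.5, 1.5, 3.0, 0.0])

weights = np.exp(scores) / np.exp(scores).sum()  # softmax normalization
for word, w in zip(source_words, weights):
    print(f"{word:>5s}: {w:.2f}")
# "eat" dominates while "does" contributes very little to the context vector
```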
For readers who might find the slides above a little complicated, let me go over the theory in a step-by-step manner. A note on notation: LinkedIn doesn't allow subscripts, so I will be using an underscore "_" instead.
- The encoder outputs the hidden state (h_i) for every time step i in the source sequence
- Similarly, the decoder outputs the hidden state (s_t) for every time step t in the target sequence
- We compute a score known as the alignment score (e_t,i), which measures how well the source word at timestep i is aligned with the target word at timestep t. It is computed from the source hidden state h_i and the target hidden state s_t using a score function:

e_t,i = score(s_t, h_i)

where e_t,i denotes the alignment score for target timestep t and source timestep i.
- We normalize the alignment scores using a softmax function to retrieve the attention weights (a_t,i):

a_t,i = exp(e_t,i) / Σ_j exp(e_t,j)

The softmax ensures that the weights lie between 0 and 1 and that they all add up to 1. Note that these scores use all the source hidden states, and in Bahdanau attention the decoder hidden state is taken from one timestep behind (s_t-1).
There are different types of attention mechanisms depending on the type of score function used. Some of the most popular ones are:
- Additive/concat (Bahdanau): score(s_t, h_i) = v^T tanh(W[s_t; h_i])
- Dot-product (Luong): score(s_t, h_i) = s_t^T h_i
- General (Luong): score(s_t, h_i) = s_t^T W h_i
- Scaled dot-product (used in Transformers): the dot product divided by the square root of the hidden state dimension
- We compute the weighted sum of the attention weights a_t,i and the encoder hidden states h_i to produce the attended context vector (c_t):

c_t = Σ_i a_t,i * h_i
If there are some words (such as 'does' from our example above) that do not add useful information, the weights for those words will be low in the calculation of the context vector.
- The attended context vector and the target hidden state of the decoder at timestep t are concatenated to produce an attended hidden vector (s~_t), which is then fed into a dense layer to produce y_t, the next predicted word in the decoder.
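Putting the steps above together, here is a small NumPy sketch of one decoder timestep with additive (Bahdanau-style) attention. The dimensions and randomly initialized weight matrices are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)
src_len, hidden = 6, 8                      # illustrative sizes

H = rng.normal(size=(src_len, hidden))      # encoder hidden states h_i
s_prev = rng.normal(size=(hidden,))         # decoder state s_(t-1)

# Additive score function: score(s, h) = v^T tanh(W_s s + W_h h)
W_s = rng.normal(size=(hidden, hidden))
W_h = rng.normal(size=(hidden, hidden))
v = rng.normal(size=(hidden,))

e = np.tanh(H @ W_h.T + s_prev @ W_s.T) @ v    # alignment scores e_t,i
a = np.exp(e) / np.exp(e).sum()                # attention weights a_t,i (softmax)
c_t = a @ H                                    # context vector c_t = Σ_i a_t,i * h_i

# The context vector would then be concatenated with the decoder state
# and fed through a dense layer to predict y_t (omitted here)
print("weights:", np.round(a, 3), "context shape:", c_t.shape)
```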
I know this was a heavy dose of the math behind attention, but let's now look at a real-world application of this technique.
Text Summarization in Python Keras with Attention Mechanism:
I put the encoder-decoder with attention mechanism concept to the test on a real-world Flipkart dataset that is publicly available on Kaggle.
Objective: In some e-commerce companies, going mobile-friendly or implementing SEO tactics often involves shortening product titles so that they fit the mobile screen. For smaller companies, doing this manually might work, although it would be very laborious. However, for larger companies which might have hundreds of thousands of product SKUs, this is no easy task. This is where the encoder-decoder concept can come to the rescue. We can train the model on a subset of labeled product data and later use it on unseen products to generate shortened product titles.
I tried replicating this application by using the product descriptions in this dataset to generate product titles. While a vanilla encoder-decoder network initially didn't yield highly accurate results, I later implemented a 3-layer stacked LSTM with a Bahdanau attention layer using the Keras API. This gave some very interesting results on the holdout dataset, which I have shared below. The complete notebook can be accessed here.
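The notebook contains the full implementation; the sketch below only outlines the general shape of such an architecture, using Keras' built-in AdditiveAttention layer (Bahdanau-style additive attention) in place of the notebook's attention layer, with placeholder vocabulary sizes and dimensions:

```python
from tensorflow.keras.layers import (Input, LSTM, Embedding, Dense,
                                     AdditiveAttention, Concatenate)
from tensorflow.keras.models import Model

src_vocab, tgt_vocab, emb_dim, latent = 20000, 12000, 128, 300  # placeholders

# --- 3-layer stacked LSTM encoder ---
enc_in = Input(shape=(None,))
x = Embedding(src_vocab, emb_dim)(enc_in)
x = LSTM(latent, return_sequences=True)(x)
x = LSTM(latent, return_sequences=True)(x)
enc_out, h, c = LSTM(latent, return_sequences=True, return_state=True)(x)

# --- Decoder initialized with the last encoder states ---
dec_in = Input(shape=(None,))
y = Embedding(tgt_vocab, emb_dim)(dec_in)
dec_out, _, _ = LSTM(latent, return_sequences=True,
                     return_state=True)(y, initial_state=[h, c])

# Additive (Bahdanau-style) attention over all encoder timesteps
context = AdditiveAttention()([dec_out, enc_out])   # query, value
merged = Concatenate()([dec_out, context])
out = Dense(tgt_vocab, activation="softmax")(merged)

model = Model([enc_in, dec_in], out)
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")
```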
Output (First 15 examples of holdout data):
1. Source Sentence: buy allure auto cm car mat hyundai sonata embera rs online allure auto cm car mat hyundai sonata embera best prices free shipping cash delivery genuine products day replacement guarantee
Original Target: allure auto cm car mat hyundai sonata embera
Predicted Target: allure auto cm car mat maruti gypsy

2. Source Sentence: buy handicrafts showpiece cm rs online handicrafts showpiece cm best prices free shipping cash delivery genuine products day replacement guarantee
Original Target: handicrafts showpiece cm
Predicted Target: handicrafts showpiece cm

3. Source Sentence: cotonex blue pink cotton kitchen linen set price rs reviews cotonex pure cotton glove specifications cotonex blue pink cotton kitchen linen set general brand cotonex design code kls material cotton style code kls pattern striped design stripe design color blue pink dimensions weight additional features fabric care machine washable dry clean bleach box number contents sales package pack sales package glove
Original Target: cotonex blue pink cotton kitchen linen set
Predicted Target: cotonex pink cotton kitchen linen set

4. Source Sentence: ploomz women push bra buy red ploomz women push bra rs online india shop online apparels huge collection branded clothes flipkart com
Original Target: ploomz women push up bra
Predicted Target: ploomz women push up bra

5. Source Sentence: hrx casual short sleeve printed women top buy petite blue grey milange hrx casual short sleeve printed women top rs online india shop online apparels huge collection branded clothes flipkart com
Original Target: hrx casual short sleeve printed women top
Predicted Target: allen solly casual short sleeve solid women top

6. Source Sentence: buy offspring solid single blanket yellow rs flipkart com genuine products free shipping cash delivery
Original Target: offspring solid single blanket yellow
Predicted Target: offspring solid single blanket blue

7. Source Sentence: orange orchid striped men polo neck shirt buy brown navy orange orchid striped men polo neck shirt rs online india shop online apparels huge collection branded clothes flipkart com
Original Target: orange and orchid striped men polo neck shirt
Predicted Target: orange and orchid striped men polo shirt

8. Source Sentence: faballey casual full sleeve solid women top buy blue faballey casual full sleeve solid women top rs online india shop online apparels huge collection branded clothes flipkart com
Original Target: faballey casual full sleeve solid women top
Predicted Target: faballey casual full sleeve solid women top

9. Source Sentence: aaliya festive full sleeve solid women top price rs light weight natural viscose fabric arrow shaped hand embroidery sleeve wear work party light weight natural viscose fabric arrow shaped hand embroidery sleeve wear work party
Original Target: aaliya festive full sleeve solid women top
Predicted Target: aaliya casual sleeveless embellished women top

10. Source Sentence: vision shell necklace buy vision shell necklace rs flipkart com genuine products day replacement guarantee free shipping cash delivery
Original Target: shell necklace
Predicted Target: trinketbag turquoise ivory necklace

11. Source Sentence: next steps striped round neck casual boy sweater price rs boys winter sweater boys winter sweater
Original Target: next steps striped round neck casual boy sweater
Predicted Target: zink london casual full sleeve solid women top

12. Source Sentence: key features hugme fashion full sleeve solid men jacket genuin leather bag ideal men women specifications hugme fashion full sleeve solid men jacket jacket details sleeve full sleeve fabric leather general details pattern solid ideal men additional details style code jk black
Original Target: hugme fashion full sleeve solid men jacket
Predicted Target: hugme fashion full sleeve solid men jacket

13. Source Sentence: specifications ball gb ddr dual core ram hard disk performance features processor speed ghz processor name intel processor type dual core number cores general brand ball operating system free graphics memory na gb model name gb ddr graphics intel system memory storage features memory technology ddr ram gb hard disk capacity gb warranty covered warranty parts product service type year domestic warranty warranty summary year domestic warranty covered parts product covered warranty burning physically damaged dimensions weight kg height cm width cm depth cm box sales package cpu driver cd power cable
Original Target: ball gb ddr with dual hard disk
Predicted Target: god business visiting card holder

14. Source Sentence: specifications oem bike centre stand general brand oem vehicle model name discover model number vehicle brand bajaj type centre stand material stainless steel color black dimensions weight box bike centre stand
Original Target: oem bike centre stand
Predicted Target: oem bike side stand

15. Source Sentence: buy plant container set rs online plant container set best prices free shipping cash delivery genuine products day replacement guarantee
Original Target: plant container set
Predicted Target: vgreen plant container set
While the model does a good job overall at predicting what the title of a product should be based on its description, we do see some inaccuracies, which in a real-world scenario could prove disastrous. There are a few ways to improve the model's performance:
- Using a beam search strategy instead of the greedy argmax strategy when decoding the target sequence (a sketch follows this list)
- Feeding more data to the network
- Implementing a bidirectional LSTM
- Trying out different attention mechanisms such as Luong attention
- Choosing a multi-headed Transformer approach over the attention-based encoder-decoder
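To expand on the first bullet, beam search keeps the k most probable partial sequences at each step instead of committing to the single argmax. Here is a rough, framework-agnostic sketch, where the step function stands in for one call to an inference decoder that returns log-probabilities (a hypothetical interface, not a specific library API):

```python
import numpy as np

def beam_search(step, init_state, start_id, end_id, beam_width=3, max_len=20):
    """Generic beam search.

    step(token_id, state) -> (log-prob vector over the vocab, new state)
    """
    beams = [([start_id], 0.0, init_state)]        # (tokens, total log-prob, state)
    for _ in range(max_len):
        candidates = []
        for tokens, score, state in beams:
            if tokens[-1] == end_id:               # finished beams carry over
                candidates.append((tokens, score, state))
                continue
            log_probs, new_state = step(tokens[-1], state)
            # Expand each beam with its beam_width best next words
            for idx in np.argsort(log_probs)[-beam_width:]:
                candidates.append((tokens + [int(idx)],
                                   score + float(log_probs[idx]),
                                   new_state))
        # Keep only the beam_width best partial sequences overall
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
        if all(t[-1] == end_id for t, _, _ in beams):
            break
    return beams[0][0]                             # highest-scoring sequence
```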
While there could still be a few incorrect predictions even after trying the techniques mentioned above, one way of minimizing them while still achieving the task at hand is to intervene manually whenever the model's prediction probability falls below a predetermined threshold.
If you have any suggestions or feedback on how to improve the model, please do comment below and I will get back to you. The full code repository is accessible here.
Hello Keshav. I'm a beginner and currently learning about Keras encoder-decoder model implementations. Can you please guide me on how to proceed with non-linguistic data, i.e. a simple numeric dataset, for prediction using encoder-decoder models, including attention?

Hi Sukhmani, I would recommend the examples from the Keras website. They have some applications of encoder-decoder for signal processing and speech recognition. In addition, you can also have a look at image captioning with transformers. I'm not sure if there are applications of it for simple numerical/structured datasets. Hope this helps.