Attention Mechanism
This article introduces a powerful technique in machine learning: the attention mechanism.
The core idea of the attention mechanism is to focus on the information that matters most. It allows a model to dynamically weigh the importance of different parts of the input rather than treating them all equally, learning to assign higher weights to the most relevant elements.
Before getting into the main text, we first review some preliminaries: hidden states, encoders, and decoders.
Preliminaries
1. RNNs and LSTM Networks
The content below is mainly based on the following articles; see them for more details.
[1] Recurrent neural network, Wikipedia
[2] Understanding LSTM Networks
RNNs (recurrent neural networks) are a class of artificial neural networks used to process sequential data. Unlike feedforward neural networks (FFNs), which process an input in a single pass, RNNs process data over multiple time steps and carry information from one step to the next. This makes them well suited to modelling and processing text, speech, and time series.
The following figure shows how an RNN processes a sequence.
![An unrolled recurrent neural network from [2]](https://github.com/RyanLee-ljx/RyanLee-ljx.github.io/blob/image/attention/rnn2.png?raw=true)
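An RNN can be viewed as a loop or, unrolled in time, as a chain of identical units. At each time step $t$, the unit is fed with the previous unit's state, i.e. the hidden state $h_{t-1}$, and the current input element $x_t$. The hidden state summarizes the features of the inputs seen so far and is used to compute the output. The update can be written as:
$$h_t = f(W_h h_{t-1} + W_x x_t + b_h)$$
Here $f$ is an activation function such as tanh, sigmoid, or ReLU, and $W_h$, $W_x$, $b_h$ are learnable parameters shared across all time steps.
By applying another transformation to $h_t$, for example a linear layer followed by softmax, we obtain the output $y_t$ at each time step:
$$y_t = \mathrm{softmax}(W_y h_t + b_y)$$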
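The recurrence can be sketched in plain Python. This is a minimal illustration, assuming a tanh activation, a 2-dimensional hidden state, 1-dimensional inputs, a softmax output layer without bias, and small hand-picked weights (`W_h`, `W_x`, `b`, `W_y` are hypothetical values, not trained parameters). At each step it computes $h_t = \tanh(W_h h_{t-1} + W_x x_t + b)$ and $y_t = \mathrm{softmax}(W_y h_t)$.

```python
import math


def matvec(W, v):
    """Multiply a matrix W (a list of rows) by a vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def softmax(z):
    """Turn a list of scores into probabilities that sum to 1."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def rnn_forward(xs, W_h, W_x, b, W_y, h0):
    """Run a vanilla RNN over the sequence xs.

    Returns the hidden state h_t and the output distribution y_t for every step.
    The same weights are reused at every time step.
    """
    h = h0
    hidden_states, outputs = [], []
    for x in xs:
        pre = [a + c + d for a, c, d in zip(matvec(W_h, h), matvec(W_x, x), b)]
        h = [math.tanh(p) for p in pre]           # h_t = tanh(W_h h_{t-1} + W_x x_t + b)
        hidden_states.append(h)
        outputs.append(softmax(matvec(W_y, h)))  # y_t = softmax(W_y h_t)
    return hidden_states, outputs


# Hypothetical toy weights: 2-dim hidden state, 1-dim input, 3 output classes.
W_h = [[0.5, 0.0], [0.0, 0.5]]
W_x = [[1.0], [-1.0]]
b = [0.0, 0.0]
W_y = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
xs = [[1.0], [0.0], [-1.0]]  # a sequence of three inputs

states, outputs = rnn_forward(xs, W_h, W_x, b, W_y, h0=[0.0, 0.0])
for t, (h, y) in enumerate(zip(states, outputs), start=1):
    print(t, [round(v, 3) for v in h], [round(p, 3) for p in y])
```

Because `W_h` feeds each hidden state back into the next step, the second hidden state is nonzero even though its input is 0. This is how the network carries information from earlier inputs forward.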
![Output of RNNs from [2]](https://github.com/RyanLee-ljx/RyanLee-ljx.github.io/blob/image/attention/rnn1.png?raw=true)
However, plain RNNs struggle to capture long-term dependencies, mainly because of the vanishing gradient problem. This motivates LSTM (Long Short-Term Memory) networks.
An LSTM keeps the overall chain structure of an RNN but changes the internal structure of each recurrent unit. Each unit contains a cell state and three gates: the forget gate layer, the input gate layer, and the output gate layer.
The key to LSTMs is the cell state, the horizontal line running through the top of the diagram. The cell state is kind of like a conveyor belt. It runs straight down the entire chain, with only some minor linear interactions. It’s very easy for information to just flow along it unchanged.
![cell state from [2]](https://github.com/RyanLee-ljx/RyanLee-ljx.github.io/blob/image/attention/LSTMcell.png?raw=true)
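The cell state is updated by the forget gate layer, which decides what information to discard from it, and the input gate layer, which decides what new information to store in it. The output gate layer then decides which parts of the cell state are output as the new hidden state.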
![Forget gate layer from [2]](https://github.com/RyanLee-ljx/RyanLee-ljx.github.io/blob/image/attention/LSTM3forget.png?raw=true)
![Input gate layer from [2]](https://github.com/RyanLee-ljx/RyanLee-ljx.github.io/blob/image/attention/LSTM3input.png?raw=true)
![Output gate layer from [2]](https://github.com/RyanLee-ljx/RyanLee-ljx.github.io/blob/image/attention/LSTM3output.png?raw=true)
data:image/s3,"s3://crabby-images/bfb29/bfb29b0f5562a046e36941844101d90f84e478d9" alt="Calculat new cell state"
data:image/s3,"s3://crabby-images/b7779/b7779a0dad1cb69e41d300b0da19ea24998c6668" alt="Overview of LSTM"
2. Encoder and Decoder
[4] Understanding the Essence of Attention through Encoder-Decoder (Seq2Seq) (in Chinese)
[5] What are Attention Mechanisms in Deep Learning?
The encoder-decoder framework, also called the Seq2Seq (sequence-to-sequence) framework, is a widely used design pattern in machine learning.
The encoder’s job is to capture the context and important information in the input sequence. Using RNNs, LSTMs, or GRUs, it transforms the input sequence into a high-dimensional representation that the decoder can use to generate the output sequence.
The decoder’s job is to generate the output sequence one token at a time. It uses the information from the encoder as well as its own hidden states to produce coherent and contextually accurate outputs.
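For example, consider a sentence pair $\langle \text{Source}, \text{Target} \rangle$, where
$$\text{Source} = \langle x_1, x_2, \ldots, x_m \rangle, \qquad \text{Target} = \langle y_1, y_2, \ldots, y_n \rangle$$
The encoder, as its name suggests, encodes the input sentence Source, transforming it into an intermediate semantic representation $C$:
$$C = \mathcal{F}(x_1, x_2, \ldots, x_m)$$
The decoder's task is to generate each target word $y_i$ based on the intermediate semantic representation $C$ of Source and the words generated so far:
$$y_i = \mathcal{G}(C, y_1, y_2, \ldots, y_{i-1})$$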
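The two equations $C = \mathcal{F}(x_1, \ldots, x_m)$ and $y_i = \mathcal{G}(C, y_1, \ldots, y_{i-1})$ can be read as a simple control loop. In the sketch below, `toy_encoder` and `toy_decoder_step` are hypothetical stand-ins for $\mathcal{F}$ and $\mathcal{G}$: the encoder squeezes the whole source into a single number, and the decoder spells that number out digit by digit. The point is the data flow, not the model: $C$ is computed once, and every $y_i$ depends only on $C$ and the previously generated tokens.

```python
def seq2seq_generate(source, encoder, decoder_step, eos="<eos>", max_len=20):
    """Greedy encoder-decoder loop.

    C   = encoder(source)                  # C   = F(x_1, ..., x_m)
    y_i = decoder_step(C, [y_1..y_{i-1}])  # y_i = G(C, y_1, ..., y_{i-1})
    """
    C = encoder(source)  # the whole source is summarized once
    target = []
    for _ in range(max_len):
        y = decoder_step(C, target)  # depends only on C and the history
        if y == eos:
            break
        target.append(y)
    return target


# Hypothetical toy F: squeeze the whole source into a single number.
def toy_encoder(source):
    return sum(source)


# Hypothetical toy G: spell out C digit by digit, then stop.
def toy_decoder_step(C, history):
    digits = str(C)
    return digits[len(history)] if len(history) < len(digits) else "<eos>"


print(seq2seq_generate([4, 5, 6], toy_encoder, toy_decoder_step))  # ['1', '5']
print(seq2seq_generate([6, 5, 4], toy_encoder, toy_decoder_step))  # ['1', '5']
```

Because `toy_encoder` maps `[4, 5, 6]` and `[6, 5, 4]` to the same $C$, the decoder cannot tell them apart: whatever $\mathcal{F}$ discards is lost for every decoding step. Attention relaxes this by letting the decoder look back at all encoder hidden states instead of a single fixed $C$.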
Basic Components
Key Components of Attention
The attention mechanism typically involves the following key elements:
Query, Key, and Value
- Query (Q): Represents what the model is currently focusing on (e.g., the word or token for which attention is being computed).
- Key (K): Represents the features of the input elements that the model compares against the query.
- Value (V): Represents the actual information or content of the input elements that the model uses for the output.
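In many models, including the optional `wq`, `wk`, `wv` linear layers in the PyTorch code at the end of this article, the queries, keys, and values are obtained by multiplying the inputs by three learned matrices. The sketch below assumes hand-picked, hypothetical matrices `W_q`, `W_k`, `W_v` and the row-vector convention (each row of `X` is one token), so each projection is the matrix product `X @ W`.

```python
def matmul(A, B):
    """Multiply matrices given as lists of rows: (n x d) @ (d x p)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


# Three input tokens, each a 2-dim vector (rows of X).
X = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]

# Hypothetical projection matrices (input dim 2 -> d_k = d_v = 2).
W_q = [[1.0, 0.0], [0.0, 1.0]]  # identity: queries equal the inputs
W_k = [[0.0, 1.0], [1.0, 0.0]]  # swaps the two features
W_v = [[2.0, 0.0], [0.0, 2.0]]  # scales the inputs by 2

Q = matmul(X, W_q)  # what each token is looking for
K = matmul(X, W_k)  # what each token offers for matching
V = matmul(X, W_v)  # what each token contributes if attended to

print("Q:", Q)
print("K:", K)
print("V:", V)
```

PyTorch's `nn.Linear(dim, dim)` performs the same kind of projection (plus a bias), with the weights learned during training instead of chosen by hand.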
Attention Score
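The attention score measures the similarity between a query and each key. A common approach is to compute the dot product between the query and the keys, optionally scaled by $1/\sqrt{d_k}$ to stabilize training (scaled dot-product attention):
$$e_{i,j} = \frac{q_i^\top k_j}{\sqrt{d_k}} \tag{1}$$
where $q_i$ is the $i$-th query, $k_j$ is the $j$-th key, and $d_k$ is the dimension of the key vectors. Stacking the queries and keys as the rows of matrices $Q$ and $K$ gives all scores at once as $QK^\top / \sqrt{d_k}$.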
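The scaled dot-product score $e_{i,j} = q_i^\top k_j / \sqrt{d_k}$ for one query against several keys fits in a few lines. This is a minimal sketch that reuses the first row of `Q` and the matrix `K` from the PyTorch example at the end of this article as inputs.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def scaled_dot_scores(q, keys):
    """Score of query q against every key k_j: q . k_j / sqrt(d_k)."""
    d_k = len(q)
    return [dot(q, k) / math.sqrt(d_k) for k in keys]


q = [1.0, 2.0]                                   # first row of Q in the later example
keys = [[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]]     # K from the later example
scores = scaled_dot_scores(q, keys)
print([round(s, 3) for s in scores])  # [1.768, 3.889, 6.01]
```

The raw dot products are 2.5, 5.5, and 8.5; dividing by $\sqrt{2}$ gives the first row of `score` printed by the PyTorch example. The scaling keeps dot products from growing with the vector dimension, which would otherwise push the softmax into a saturated region with very small gradients.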
Attention variants differ mainly in how they measure this similarity. Two classic examples are:
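Bahdanau (additive) attention:
$$e_{i,j} = a(s_{i-1}, h_j) = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$$
where $s_{i-1}$ is the previous decoder hidden state, $h_j$ is the $j$-th encoder hidden state, $v_a$, $W_a$, and $U_a$ are model parameters, and $a$ represents a fully connected (feedforward) layer.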
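A minimal sketch of the additive score $e_{i,j} = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$, assuming hand-picked, hypothetical parameters `W_a`, `U_a`, `v_a` (in practice they are learned jointly with the rest of the model).

```python
import math


def matvec(W, v):
    """Multiply a matrix W (a list of rows) by a vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def additive_score(s_prev, h_j, W_a, U_a, v_a):
    """Bahdanau score: v_a^T tanh(W_a s_{i-1} + U_a h_j)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W_a, s_prev), matvec(U_a, h_j))]
    return sum(v * x for v, x in zip(v_a, hidden))


# Hypothetical parameters: decoder state is 2-dim, encoder states are 3-dim,
# and both are mapped into a 2-dim attention space.
W_a = [[1.0, 0.0], [0.0, 1.0]]
U_a = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
v_a = [1.0, -1.0]

s_prev = [0.5, -0.5]
encoder_states = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
scores = [additive_score(s_prev, h, W_a, U_a, v_a) for h in encoder_states]
print([round(e, 3) for e in scores])  # [1.367, 0.0, 0.0]
```

Here the decoder state and the encoder states have different sizes. Because $W_a$ and $U_a$ map both into a common space, the additive score, unlike a plain dot product, does not require the query and keys to have the same dimension.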
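Luong (multiplicative) attention, which scores the current decoder hidden state $s_i$ against each encoder hidden state $h_j$:
$$e_{i,j} = s_i^\top h_j \quad \text{(dot)}, \qquad e_{i,j} = s_i^\top W_a h_j \quad \text{(general)}$$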
Both Bahdanau and Luong attention are soft attention mechanisms: the weights $\alpha_{i,j}$ are computed from the scores with a softmax function, so every input element receives a nonzero weight.
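Luong's multiplicative scores, the "dot" form $e_{i,j} = s_i^\top h_j$ and the "general" form $e_{i,j} = s_i^\top W_a h_j$, can be sketched as below. The matrix `W_a` is a hypothetical choice that makes the score depend only on the first feature.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(W, v):
    return [dot(row, v) for row in W]


def luong_dot(s_i, h_j):
    """Luong 'dot' score: s_i^T h_j (requires equal dimensions)."""
    return dot(s_i, h_j)


def luong_general(s_i, h_j, W_a):
    """Luong 'general' score: s_i^T W_a h_j."""
    return dot(s_i, matvec(W_a, h_j))


s_i = [1.0, 2.0]
encoder_states = [[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]]
W_a = [[1.0, 0.0], [0.0, 0.0]]  # hypothetical: only the first feature matters

print([luong_dot(s_i, h) for h in encoder_states])           # [2.5, 5.5, 8.5]
print([luong_general(s_i, h, W_a) for h in encoder_states])  # [0.5, 1.5, 2.5]
```

With `W_a` equal to the identity matrix, the general score reduces to the dot score, and the scaled dot-product score is the dot score divided by $\sqrt{d_k}$.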
Softmax
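The scores are passed through a softmax function to produce normalized attention weights. These weights are positive and sum to 1, indicating the importance of each input element relative to the others:
$$\alpha_{i,j} = \frac{\exp(e_{i,j})}{\sum_{l=1}^{m} \exp(e_{i,l})} \tag{2}$$
where $m$ is the number of keys.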
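The softmax $\alpha_{i,j} = \exp(e_{i,j}) / \sum_l \exp(e_{i,l})$ can be sketched in plain Python. This version subtracts the largest score before exponentiating, a common trick assumed here: the shift cancels in the ratio, so the weights are unchanged, but `math.exp` no longer overflows on large scores.

```python
import math


def softmax(scores):
    """Turn raw scores into weights that are positive and sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # shift by the max to avoid overflow
    total = sum(exps)
    return [e / total for e in exps]


weights = softmax([1.0, 2.0, 3.0])
print([round(w, 4) for w in weights])  # [0.09, 0.2447, 0.6652]
print(softmax([1000.0, 1000.0]))       # [0.5, 0.5]
```

Without the shift, `math.exp(1000.0)` would raise an `OverflowError` on the second call.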
Weighted Sum
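The attention weights are applied to the values $v_j$ to compute a weighted sum, which becomes the output of the attention mechanism:
$$c_i = \sum_{j=1}^{m} \alpha_{i,j} v_j \tag{3}$$
Applied to all queries at once, the whole computation is $\mathrm{softmax}(QK^\top / \sqrt{d_k})\,V$, with the softmax taken row by row; this is what the code at the end of this article computes.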
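The weighted sum $c_i = \sum_j \alpha_{i,j} v_j$ is a short loop over the value vectors. This minimal sketch uses the values `V` from the PyTorch example at the end of this article together with two hand-picked weight vectors.

```python
def weighted_sum(weights, values):
    """c = sum_j alpha_j * v_j, where each v_j is a vector."""
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]


values = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]
print(weighted_sum([1 / 3, 1 / 3, 1 / 3], values))  # uniform weights: the mean of the values
print(weighted_sum([0.0, 0.0, 1.0], values))        # all weight on one item: [50.0, 60.0]
```

Because the weights are nonnegative and sum to 1, the output is a convex combination of the values: it always lies between them and moves towards the values whose keys best match the query.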
Working Flow
Encoding Phase
The encoder processes the input (the source) $x_1, \ldots, x_m$ and transforms it into hidden states $h_1, \ldots, h_m$, which serve as the keys (and, in this basic setting, also as the values). Here $m$ is the length of the input.
This step is the RNN/LSTM encoding described in the preliminaries.
Attention Calculation
Compute the similarity, i.e. the attention score $e_{i,j}$, between the decoder's most recent hidden state $s_{i-1}$ (the query) and each encoder output $h_j$ (the keys) using (1), or using the Bahdanau or Luong score functions.
Normalize the scores into attention weights $\alpha_{i,j}$ using (2).
Compute the weighted sum of the values using (3). The result, $c_i$, is called the context vector; it is passed to the decoder to generate its next hidden state and output.
Since the values are the encoder hidden states here, the context vector $c_i = \sum_{j=1}^{m} \alpha_{i,j} h_j$ contains information from all encoder hidden states, weighted by their relevance to the current decoding step.
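Putting the three steps together, one attention step of the decoder can be sketched as follows. This assumes the scaled dot-product score, so the decoder state $s_{i-1}$ and the encoder states $h_j$ must have the same dimension; the state values are toy numbers chosen for illustration.

```python
import math


def attention_step(s_prev, encoder_states):
    """One decoding step of attention in the encoder-decoder setting.

    Query: the decoder hidden state s_{i-1}.
    Keys and values: the encoder hidden states h_1, ..., h_m.
    Returns the attention weights alpha_{i,j} and the context vector c_i.
    """
    d_k = len(s_prev)
    # Step 1: scaled dot-product score between the query and every key.
    scores = [sum(a * b for a, b in zip(s_prev, h)) / math.sqrt(d_k)
              for h in encoder_states]
    # Step 2: softmax over the m source positions.
    m = max(scores)
    exps = [math.exp(e - m) for e in scores]
    total = sum(exps)
    alphas = [x / total for x in exps]
    # Step 3: context vector as the weighted sum of the encoder hidden states.
    dim = len(encoder_states[0])
    context = [sum(a * h[d] for a, h in zip(alphas, encoder_states))
               for d in range(dim)]
    return alphas, context


# Toy encoder hidden states h_1, h_2, h_3 and decoder state s_{i-1}.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
s_prev = [2.0, 0.0]
alphas, context = attention_step(s_prev, encoder_states)
print([round(a, 3) for a in alphas], [round(c, 3) for c in context])
```

Most of the weight goes to $h_1$ and $h_3$, which have the largest dot product with the query. Swapping in the Bahdanau or Luong score would only change the line that computes `scores`.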
Decoding Phase
The decoder's next hidden state is determined by three components: its previous hidden state, the context vector, and the decoder input (e.g., the previously generated word in machine translation).
Specifically, the next hidden state is computed by feeding the previous hidden state $s_{i-1}$, the context vector $c_i$, and the decoder input $y_{i-1}$ into an RNN, LSTM, or GRU cell $f$:
$$s_i = f(s_{i-1}, c_i, y_{i-1})$$
Where:
- $s_i$ is the hidden state at the current time step,
- $s_{i-1}$ is the hidden state at the previous time step,
- $c_i$ is the context vector at the current time step,
- $y_{i-1}$ is the output from the previous time step (the decoder input).
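A sketch of $s_i = f(s_{i-1}, c_i, y_{i-1})$, assuming $f$ is a plain tanh RNN cell applied to the concatenation of its three inputs, and assuming the previous output word $y_{i-1}$ has already been mapped to a 1-dimensional embedding. The weights `W`, `b` and all input values are hypothetical.

```python
import math


def decoder_state_update(s_prev, c_i, y_prev_emb, W, b):
    """s_i = f(s_{i-1}, c_i, y_{i-1}), with f assumed to be a tanh RNN cell.

    The three inputs are concatenated into one vector and passed through a
    single affine layer followed by tanh. W has one row per hidden unit.
    """
    z = s_prev + c_i + y_prev_emb  # list concatenation [s_{i-1}; c_i; y_{i-1}]
    return [math.tanh(sum(w * x for w, x in zip(row, z)) + bj)
            for row, bj in zip(W, b)]


# Toy sizes: hidden 2, context 2, embedding 1, so each row of W has 5 entries.
W = [[0.5, 0.0, 1.0, 0.0, 0.2],
     [0.0, 0.5, 0.0, 1.0, -0.2]]
b = [0.0, 0.0]
s_i = decoder_state_update([0.1, -0.1], [0.6, 0.4], [1.0], W, b)
print([round(v, 3) for v in s_i])
```

With an LSTM or GRU, the single tanh layer would be replaced by the corresponding gated update, while the three inputs would typically still be combined into the cell's input in a similar way.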
The output of the Decoder is typically generated through the following steps:
Compute Logits Using Hidden State and Context Vector: The hidden state $s_i$ and the context vector $c_i$ (in some variants also the previous output $y_{i-1}$) are combined and passed through a fully connected layer to produce unnormalized scores (logits). In the simplest form this is a single linear transformation:
$$o_i = W_o [s_i; c_i] + b_o$$
Where:
- $W_o$ and $b_o$ are learnable parameters,
- $[s_i; c_i]$ represents the concatenation of the hidden state and the context vector.
Generate Probability Distribution via Softmax: The logits are passed through a Softmax function to produce a probability distribution over possible outputs.
Select the Output: Based on the probability distribution, the word with the highest probability is selected as the output for the current time step. Alternatively, sampling methods can be used to generate the output.
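The three output steps can be sketched as follows, assuming the simple linear form $o_i = W_o [s_i; c_i] + b_o$ for the logits and a hypothetical three-word vocabulary. `W_o`, `b_o`, and the state and context values are made up for illustration.

```python
import math
import random


def output_distribution(s_i, c_i, W_o, b_o):
    """Logits o_i = W_o [s_i; c_i] + b_o, followed by a softmax over the vocabulary."""
    z = s_i + c_i  # concatenation [s_i; c_i]
    logits = [sum(w * x for w, x in zip(row, z)) + bk for row, bk in zip(W_o, b_o)]
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


vocab = ["<eos>", "hello", "world"]  # hypothetical 3-word vocabulary
W_o = [[0.0, 0.0, 0.0, 0.0],
       [1.0, 0.0, 1.0, 0.0],
       [0.0, 1.0, 0.0, 1.0]]
b_o = [0.5, 0.0, 0.0]
probs = output_distribution([0.8, 0.1], [0.6, 0.2], W_o, b_o)

# Greedy selection: take the most probable word.
greedy = vocab[max(range(len(vocab)), key=probs.__getitem__)]
# Sampling: draw a word with probability proportional to probs.
sampled = random.choices(vocab, weights=probs, k=1)[0]
print(greedy, sampled, [round(p, 3) for p in probs])
```

Greedy selection always returns `"hello"` here, while `random.choices` can occasionally return a less likely word; that is the sampling alternative described above.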
The following figures illustrate this process.
data:image/s3,"s3://crabby-images/b86e2/b86e2ff7d436fae3a2e71d15087f0399434256c8" alt="illustration of Attention Mechanism(1)"
data:image/s3,"s3://crabby-images/faab1/faab184d33ba8a887b156d030558165973429168" alt="illustration of Attention Mechanism(2)"
Summary
The essence of the attention mechanism is to select the important information in the input by giving it larger weights, so that redundant information is down-weighted.
Computationally, attention produces a weighted sum of the values, where each weight is determined by the similarity between the query and the corresponding key.
Example & code
Tip
The following illustrations show how the matrix calculation proceeds.
data:image/s3,"s3://crabby-images/50a1f/50a1f5801777191ac3ffe157074d6e5a7308bcd4" alt=""
data:image/s3,"s3://crabby-images/0f54b/0f54b5b42bfe7ff6861e14264b1db4574dd79f52" alt="illustration"
Assume:
- Key (K): shape [3, 2] (three items, each of dimension 2)
- Value (V): shape [3, 2] (three items, each of dimension 2)
- Query (Q): shape [3, 2] (three items, each of dimension 2)
Code:
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # To apply a learned linear transformation to the inputs (q, k, v),
        # uncomment the following lines:
        # self.wq = nn.Linear(dim, dim)
        # self.wk = nn.Linear(dim, dim)
        # self.wv = nn.Linear(dim, dim)

    def forward(self, query, key, value):
        """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
        # If the inputs are not batched, add a batch dimension so that
        # torch.bmm can be used: [n, d] -> [1, n, d].
        if key.dim() == 2:
            key = key.unsqueeze(0)
        if query.dim() == 2:
            query = query.unsqueeze(0)
        if value.dim() == 2:
            value = value.unsqueeze(0)

        # Make sure the inputs are float tensors.
        key = key.float()
        query = query.float()
        value = value.float()

        # If the linear transformations are enabled in __init__, uncomment:
        # query = self.wq(query)
        # key = self.wk(key)
        # value = self.wv(value)

        # Scores: [batch, n_q, n_k], scaled by sqrt(d_k).
        d_k = key.shape[-1]
        score = torch.bmm(query, key.transpose(-2, -1)) / (d_k ** 0.5)
        print(f"score: \n {score} \n")

        # Softmax over the keys: each row of weights sums to 1.
        attention_weights = torch.nn.functional.softmax(score, dim=-1)
        print(f"attention weights: \n {attention_weights} \n")

        # Weighted sum of the values: [batch, n_q, d_v].
        output = torch.bmm(attention_weights, value)
        return output


# input: three queries, keys, and values, each of dimension 2
Q = torch.tensor([[1, 2],
                  [3, 4],
                  [5, 6]])
K = torch.tensor([[0.5, 1],
                  [1.5, 2],
                  [2.5, 3]])
V = torch.tensor([[10, 20],
                  [30, 40],
                  [50, 60]])

attention_layer = Attention(2)
output = attention_layer(Q, K, V)
print(f"output: \n{output}\n")
Results:
data:image/s3,"s3://crabby-images/dbfe5/dbfe50e2309d9143c354a48d3299a6fa708cc25f" alt="Results"