
Attention Mechanism


This article introduces a powerful technique in machine learning called the attention mechanism.

The core idea of the attention mechanism is to pay more attention to what matters. It allows the model to weigh the importance of different parts of the input dynamically rather than treating them equally: the model learns to assign higher weights to the most relevant elements.

Before stepping into the main text, we should first cover some preliminary knowledge (hidden states, encoders, and decoders).

Preliminaries

1. RNNs and LSTM Networks

The content introduced below is mainly drawn from the following three blogs/articles. For details, click the links to learn more.

[1] Recurrent neural network, Wikipedia

[2] Understanding LSTM Networks

[3] 如何从RNN起步,一步一步通俗理解LSTM (How to understand LSTM step by step, starting from RNN)

RNNs, short for Recurrent Neural Networks, are a class of artificial neural networks used to process sequential data. Unlike FFNs (feedforward neural networks), RNNs process data across multiple time steps rather than all at once, making them well suited for modelling and processing text, speech, and time series.

The following picture demonstrates the workflow of RNNs.

An unrolled recurrent neural network, from [2]

RNNs are made up of a chain of repeating units, i.e., a loop. Each unit is fed the previous unit's state, the hidden state h_{t-1}, together with the current input element x_t. The hidden state is an extracted feature representation of the input seen so far and can be used to compute the output. We can express this with the formula:

h_{t} = f(Uh_{t-1}, Wx_{t}, b)

Here f is an activation function such as sigmoid or ReLU.

By applying another function or transformation, such as softmax, to h_t, we can obtain the output y_t at each time step.

Output of RNNs, from [2]
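
To make the recurrence concrete, here is a minimal PyTorch sketch of a single step; the dimensions are arbitrary, and tanh and softmax stand in for the activation f and the output transformation (an illustrative assumption, not a fixed choice):

import torch

hidden_dim, input_dim, output_dim = 4, 3, 5
U = torch.randn(hidden_dim, hidden_dim)   # recurrent weights applied to h_{t-1}
W = torch.randn(hidden_dim, input_dim)    # input weights applied to x_t
V = torch.randn(output_dim, hidden_dim)   # output weights applied to h_t
b = torch.zeros(hidden_dim)

h_prev = torch.zeros(hidden_dim)          # previous hidden state h_{t-1}
x_t = torch.randn(input_dim)              # current input element x_t

h_t = torch.tanh(U @ h_prev + W @ x_t + b)   # new hidden state h_t
y_t = torch.softmax(V @ h_t, dim=-1)         # output y_t at this time step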

However, RNNs fall short on long-term dependencies because of the vanishing gradient problem. Here we therefore introduce the LSTM, short for Long Short-Term Memory network.

The LSTM follows the main structure of RNNs but adapts the inner structure of each recurrent unit. Each unit comprises one cell state and three 'gates': the 'forget gate layer', the 'input gate layer', and the 'output gate layer'.

The key to LSTMs is the cell state, the horizontal line running through the top of the diagram. The cell state is kind of like a conveyor belt. It runs straight down the entire chain, with only some minor linear interactions. It’s very easy for information to just flow along it unchanged.

Cell state, from [2]

It receives information from the 'forget gate layer' and 'input gate layer'.

Forget gate layer, from [2]
Input gate layer, from [2]
Output gate layer, from [2]
Calculating the new cell state
Overview of the LSTM
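
For reference, the gate and cell-state updates from [2] can be written as (σ is the sigmoid function and * denotes element-wise multiplication):

f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)
\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)
C_t = f_t * C_{t-1} + i_t * \tilde{C}_t
o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)
h_t = o_t * \tanh(C_t)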

2. Encoder and Decoder

[4] 从Encoder-Decoder(Seq2Seq)理解Attention的本质 (Understanding the essence of Attention from the Encoder-Decoder (Seq2Seq) framework)

[5] What are Attention Mechanisms in Deep Learning?

The Encoder-Decoder framework, also called the Seq2Seq framework, is a widely used design pattern in ML.

The encoder’s job is to capture the context and important information from the input sequence. Leveraging RNNs, LSTMs, or GRUs, it transforms the input sequence into a high-dimensional representation that can be used by the decoder to generate the output sequence.

The decoder’s job is to generate the output sequence one token at a time. It uses the information from the encoder as well as its own hidden states to produce coherent and contextually accurate outputs.

For example, here is a sentence pair (Source, Target):

Source = (x_{1}, x_{2}, x_{3}, ..., x_{n})

Target = (y_{1}, y_{2}, y_{3}, ..., y_{n})

The encoder, as its name suggests, encodes the input sentence Source, transforming the input sentence into an intermediate semantic representation C:

C = f(x_{1}, x_{2}, x_{3}, ..., x_{n})

For the decoder, its task is to generate the target sentence based on the intermediate semantic representation C of the sentence Source and the historical information generated previously.

y_{i} = g(C, y_{1}, y_{2}, y_{3}, ..., y_{i-1})
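
To make the framework concrete, here is a minimal PyTorch sketch of a GRU-based encoder-decoder; the vocabulary size, dimensions, and the assumption that token id 0 is the start symbol are illustrative choices, not part of any specific model:

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)

    def forward(self, src):                     # src: (batch, src_len) of token ids
        emb = self.embed(src)                   # (batch, src_len, emb_dim)
        outputs, h = self.rnn(emb)              # outputs: all hidden states; h: final state
        return outputs, h                       # h plays the role of the semantic code C

class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, prev_token, h):           # prev_token: (batch, 1)
        emb = self.embed(prev_token)            # (batch, 1, emb_dim)
        output, h = self.rnn(emb, h)            # one decoding step
        return self.out(output.squeeze(1)), h   # logits over the vocabulary

# toy usage: encode a source sentence, then decode greedily token by token
encoder, decoder = Encoder(100, 32, 64), Decoder(100, 32, 64)
src = torch.randint(0, 100, (1, 5))             # a source sentence of 5 token ids
enc_outputs, h = encoder(src)
prev = torch.zeros(1, 1, dtype=torch.long)      # assume id 0 is the <sos> token
for _ in range(3):
    logits, h = decoder(prev, h)
    prev = logits.argmax(dim=-1, keepdim=True)  # pick the most probable next token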

Basic Components

Key Components of Attention

The attention mechanism typically involves the following key elements:

Query, Key, and Value

  • Query (Q): Represents what the model is currently focusing on (e.g., the word or token for which attention is being computed).
  • Key (K): Represents the features of the input elements that the model compares against the query.
  • Value (V): Represents the actual information or content of the input elements that the model uses for the output.
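
In practice, Q, K, and V are usually produced by applying learned linear projections to input embeddings. A minimal sketch (the dimensions are arbitrary assumptions; here all three are projected from the same input, as in self-attention, whereas in the encoder-decoder setting described later the query comes from the decoder's hidden state):

import torch
import torch.nn as nn

d_model, seq_len = 8, 3
x = torch.randn(seq_len, d_model)     # input embeddings, one vector per token

w_q = nn.Linear(d_model, d_model)     # learned projection for queries
w_k = nn.Linear(d_model, d_model)     # learned projection for keys
w_v = nn.Linear(d_model, d_model)     # learned projection for values

Q, K, V = w_q(x), w_k(x), w_v(x)      # each of shape (seq_len, d_model)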

Attention Score

The attention score measures the similarity between the query and each key. A common approach is to compute the dot product between the query and the keys, optionally scaled by a factor to stabilize training:

\text{Score}(Q, K) = \frac{QK^T}{\sqrt{d_k}} \quad (1)

where d_k is the dimension of the key vectors.

Attention variants differ mainly in how they measure this similarity.

Bahdanau Attention:

e(h_i, s_j) = U\tanh(Vh_i + Ws_j)

where U, V, and W are model parameters; e(h_i, s_j) is essentially a small fully connected network.

Luong Attention:

e(h, s) = h^TWs

Both Bahdanau and Luong attention are soft attention mechanisms, computing the weights α_{i,j} with a softmax function.
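
A minimal sketch of the two scoring functions, with illustrative dimensions and parameter names (a rough correspondence, not the exact parameterization of either paper):

import torch
import torch.nn as nn

dim = 8
h = torch.randn(dim)    # an encoder hidden state h_i
s = torch.randn(dim)    # a decoder hidden state s_j

# Bahdanau (additive): e(h_i, s_j) = U tanh(V h_i + W s_j)
V_lin = nn.Linear(dim, dim, bias=False)
W_lin = nn.Linear(dim, dim, bias=False)
U_lin = nn.Linear(dim, 1, bias=False)
e_bahdanau = U_lin(torch.tanh(V_lin(h) + W_lin(s)))

# Luong (general): e(h, s) = h^T W s
W_gen = nn.Linear(dim, dim, bias=False)
e_luong = h @ W_gen(s)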

Softmax

The scores are passed through a softmax function to produce normalized attention weights. These weights sum to 1, indicating the importance of each input element relative to the others:

\text{Attention Weights} = \text{softmax}(\text{Score}(Q, K)) \quad (2)

Weighted Sum

The attention weights are applied to the values (V) to compute a weighted sum, which becomes the output of the attention mechanism:

\text{Output} = \text{Attention Weights} \cdot V \quad (3)
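
Equations (1)-(3) map directly onto a few tensor operations. Here is a minimal sketch with randomly generated Q, K, V of illustrative shapes (the full example at the end of the article wraps the same steps in a module):

import torch

Q, K, V = torch.randn(3, 8), torch.randn(3, 8), torch.randn(3, 8)

d_k = K.shape[-1]
score = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (1) scaled dot-product scores
weights = torch.softmax(score, dim=-1)         # (2) normalized attention weights
output = weights @ V                           # (3) weighted sum of the values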

Working Flow

Encoding Phase

The encoder processes the input (namely the source), transforming it into hidden states [s_1, s_2, ..., s_k, ..., s_T] (namely the keys). T is the length of the input.

This is the process we introduced in the preliminaries.

Attention Calculation

Calculate the similarity, namely the attention score, between the decoder's current hidden state h_t (namely the query) and each element of the encoder's output s_i, i = 1, 2, ..., T, using (1).

Derive the attention weight of each s_i using (2).

Calculate the weighted sum using (3). The result is called the context vector, which is passed to the decoder to generate its next hidden state and output.

The context vector computed by the attention mechanism thus contains weighted information from the encoder's hidden states.

Decoding Phase

The next hidden state of the decoder is determined by three main components: the current hidden state, the context vector, and the decoder input (e.g., the previously generated word in machine translation).

Specifically, the next hidden state h_t is generated by feeding the current hidden state h_{t-1}, the context vector c_t, and the decoder input y_{t-1} into an RNN, LSTM, or GRU. This can be expressed as:

h_t = \text{RNN}(h_{t-1}, [c_t, y_{t-1}])

Where:

  • h_t is the hidden state at the current time step,
  • h_{t-1} is the hidden state at the previous time step,
  • c_t is the context vector at the current time step,
  • y_{t-1} is the output from the previous time step (the decoder input).

The output of the Decoder is typically generated through the following steps:

  1. Compute Logits Using Hidden State and Context Vector: The hidden state h_t, the context vector c_t, and the output from the previous time step y_{t-1} are combined and passed through a fully connected layer (often a linear transformation followed by an activation function) to produce unnormalized scores (logits).

    \text{logits} = W_o [h_t; c_t; y_{t-1}] + b_o

    Where:

    • W_o and b_o are learnable parameters,
    • [h_t; c_t; y_{t-1}] represents the concatenation of the hidden state, the context vector, and the previous output.
  2. Generate Probability Distribution via Softmax: The logits are passed through a Softmax function to produce a probability distribution over possible outputs.

    p(y_t | y_{<t}, x) = \text{Softmax}(\text{logits})

  3. Select the Output: Based on the probability distribution, the word with the highest probability is selected as the output yty_t for the current time step. Alternatively, sampling methods can be used to generate the output.
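
Putting the attention calculation and the decoding phase together, here is a minimal sketch of one decoder step; the GRU cell, dot-product scoring, and layer names are illustrative assumptions rather than a specific published architecture:

import torch
import torch.nn as nn

hidden_dim, vocab_size, T = 8, 20, 5

enc_states = torch.randn(T, hidden_dim)   # encoder hidden states s_1..s_T (keys/values)
h_prev = torch.randn(1, hidden_dim)       # decoder hidden state h_{t-1} (query)
y_prev = torch.randn(1, hidden_dim)       # embedding of the previously generated token y_{t-1}

# attention: score h_{t-1} against every encoder state, then take the weighted sum
scores = (enc_states @ h_prev.squeeze(0)) / hidden_dim ** 0.5
weights = torch.softmax(scores, dim=-1)
context = (weights.unsqueeze(-1) * enc_states).sum(dim=0, keepdim=True)  # context vector c_t

# next hidden state: h_t = RNN(h_{t-1}, [c_t, y_{t-1}])
rnn_cell = nn.GRUCell(2 * hidden_dim, hidden_dim)
h_t = rnn_cell(torch.cat([context, y_prev], dim=-1), h_prev)

# logits = W_o [h_t; c_t; y_{t-1}] + b_o, then a softmax over the vocabulary
out_layer = nn.Linear(3 * hidden_dim, vocab_size)
logits = out_layer(torch.cat([h_t, context, y_prev], dim=-1))
p_t = torch.softmax(logits, dim=-1)
y_t = p_t.argmax(dim=-1)                  # greedy choice of the output token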

The following pictures can clearly illustrate this process.

Illustration of the attention mechanism (1)
Illustration of the attention mechanism (2)

Summary

The essence of the attention mechanism is that we select the important information from the input and give it more weight, so as to filter out redundant information.

Functionally, the attention mechanism computes a weighted sum over the values, where the weights are given by the similarity between the query and the keys.

\text{Attention}(\text{Source}, \text{Query}_t) = \sum_{i=1}^{N}\text{Similarity}(\text{Query}_t, \text{Key}_i)\cdot \text{Value}_i

Example & code

Tip

To better understand the matrix calculation, we can see the following illustration.

Illustration

Assume:

Key (K): shape [3, 2] (three items, each of dimension 2)

Value (V): shape [3, 2] (three items, each of dimension 2)

Query (Q): shape [3, 2] (three items, each of dimension 2)

Q = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, \quad K = \begin{bmatrix} 0.5 & 1 \\ 1.5 & 2 \\ 2.5 & 3 \end{bmatrix}, \quad V = \begin{bmatrix} 10 & 20 \\ 30 & 40 \\ 50 & 60 \end{bmatrix}

Code:

import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, dim):
        super(Attention, self).__init__()
        '''
        If you want to apply a linear transformation to the inputs (q, k, v),
        you can add the following code to your program:
        self.wk = torch.nn.Linear(dim, dim)
        self.wv = torch.nn.Linear(dim, dim)
        self.wq = torch.nn.Linear(dim, dim)
        '''

    def forward(self, query, key, value):
        '''
        If the input is not batched, we need to
        add a batch dimension to the input.
        '''
        if len(key.shape) == 2:
            key = torch.unsqueeze(key, dim=0)
        if len(query.shape) == 2:
            query = torch.unsqueeze(query, dim=0)
        if len(value.shape) == 2:
            value = torch.unsqueeze(value, dim=0)
        # make sure the inputs are float tensors
        key = key.float()
        query = query.float()
        value = value.float()
        '''
        If applying the linear transformations, add the following code:
        key = self.wk(key)
        value = self.wv(value)
        query = self.wq(query)
        '''
        d_k = key.shape[-1]
        score = torch.bmm(query, key.transpose(-2, -1))/(d_k ** 0.5) 
        print(f"score: \n {score} \n")
        attention_weights = torch.nn.functional.softmax(score, dim=-1)
        print(f"attention weights: \n {attention_weights} \n")
        output = torch.bmm(attention_weights, value)
        return output


# input
Q = torch.tensor([[1, 2],
                  [3, 4],
                  [5, 6]])

K = torch.tensor([[0.5, 1],
                  [1.5, 2],
                  [2.5, 3]])

V = torch.tensor([[10, 20],
                  [30, 40],
                  [50, 60]])

attention_layer = Attention(2)
output = attention_layer(Q, K, V)

print(f"output: \n{output}\n")

Results:

Results