.. _sec_multihead-attention:
Chú ý nhiều đầu
===============
Trong thực tế, với cùng một tập hợp các truy vấn, khóa và giá trị, chúng
tôi có thể muốn mô hình của chúng tôi kết hợp kiến thức từ các hành vi
khác nhau của cùng một cơ chế chú ý, chẳng hạn như nắm bắt các phụ thuộc
của các phạm vi khác nhau (ví dụ: phạm vi ngắn hơn so với phạm vi dài
hơn) trong một chuỗi. Do đó, có thể có lợi khi cho phép cơ chế chú ý của
chúng ta cùng sử dụng các không gian con đại diện khác nhau của các truy
vấn, khóa và giá trị.
Để kết thúc này, thay vì thực hiện một tập hợp chú ý duy nhất, các truy
vấn, khóa và giá trị có thể được chuyển đổi với :math:`h` các phép chiếu
tuyến tính học độc lập. Sau đó, các truy vấn, khóa và giá trị dự kiến
:math:`h` này được đưa vào sự chú ý chung song song. Cuối cùng,
:math:`h` đầu ra tập hợp chú ý được nối và chuyển đổi với một phép chiếu
tuyến tính khác đã học để tạo ra đầu ra cuối cùng. Thiết kế này được gọi
là \* sự chú ý đa đầu\ *, trong đó mỗi đầu ra tập hợp chú ý :math:`h` là
đầu * :cite:`Vaswani.Shazeer.Parmar.ea.2017`. Sử dụng các lớp được kết
nối hoàn toàn để thực hiện các biến đổi tuyến tính có thể học được,
:numref:`fig_multi-head-attention` mô tả sự chú ý nhiều đầu.
.. _fig_multi-head-attention:
.. figure:: ../img/multi-head-attention.svg
Multi-head attention, where multiple heads are concatenated then
linearly transformed.
Mô hình
-------
Trước khi cung cấp việc thực hiện sự chú ý nhiều đầu, chúng ta hãy chính
thức hóa mô hình này một cách toán học. Đưa ra một truy vấn
:math:`\mathbf{q} \in \mathbb{R}^{d_q}`, một khóa
:math:`\mathbf{k} \in \mathbb{R}^{d_k}`, và một giá trị
:math:`\mathbf{v} \in \mathbb{R}^{d_v}`, mỗi đầu chú ý
:math:`\mathbf{h}_i` (:math:`i = 1, \ldots, h`) được tính là
.. math:: \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},
trong đó các thông số có thể học được
:math:`\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}`,
:math:`\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}` và
:math:`\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}`, và :math:`f` là
sự chú ý tập hợp, chẳng hạn như sự chú ý của phụ gia và sự chú ý của sản
phẩm điểm thu nhỏ trong :numref:`sec_attention-scoring-functions`. Đầu
ra chú ý nhiều đầu là một chuyển đổi tuyến tính khác thông qua các tham
số có thể học được :math:`\mathbf W_o\in\mathbb R^{p_o\times h p_v}` của
nối :math:`h` đầu:
.. math:: \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.
Dựa trên thiết kế này, mỗi đầu có thể tham dự các phần khác nhau của đầu
vào. Các chức năng phức tạp hơn mức trung bình có trọng số đơn giản có
thể được thể hiện.
.. raw:: html
.. raw:: html
.. code:: python
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
.. raw:: html
.. raw:: html
.. code:: python
import math
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: html
.. raw:: html
.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
.. raw:: html
.. raw:: html
Thực hiện
---------
Trong triển khai của chúng tôi, chúng tôi chọn sự chú ý của sản phẩm
điểm thu nhỏ cho mỗi đầu của sự chú ý nhiều đầu. Để tránh tăng trưởng
đáng kể chi phí tính toán và chi phí tham số hóa, chúng tôi đặt
:math:`p_q = p_k = p_v = p_o / h`. Lưu ý rằng các đầu :math:`h` có thể
được tính toán song song nếu chúng ta đặt số lượng đầu ra của các biến
đổi tuyến tính cho truy vấn, khóa và giá trị thành
:math:`p_q h = p_k h = p_v h = p_o`. Trong việc triển khai sau đây,
:math:`p_o` được chỉ định thông qua đối số ``num_hiddens``.
.. raw:: html
.. raw:: html
.. code:: python
#@save
class MultiHeadAttention(nn.Block):
"""Multi-head attention."""
def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
**kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
def forward(self, queries, keys, values, valid_lens):
# Shape of `queries`, `keys`, or `values`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
# Shape of `valid_lens`:
# (`batch_size`,) or (`batch_size`, no. of queries)
# After transposing, shape of output `queries`, `keys`, or `values`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# On axis 0, copy the first item (scalar or vector) for
# `num_heads` times, then copy the next item, and so on
valid_lens = valid_lens.repeat(self.num_heads, axis=0)
# Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# Shape of `output_concat`:
# (`batch_size`, no. of queries, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
.. raw:: html
.. raw:: html
.. code:: python
#@save
class MultiHeadAttention(nn.Module):
"""Multi-head attention."""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# Shape of `queries`, `keys`, or `values`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
# Shape of `valid_lens`:
# (`batch_size`,) or (`batch_size`, no. of queries)
# After transposing, shape of output `queries`, `keys`, or `values`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# On axis 0, copy the first item (scalar or vector) for
# `num_heads` times, then copy the next item, and so on
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# Shape of `output_concat`:
# (`batch_size`, no. of queries, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
.. raw:: html
.. raw:: html
.. code:: python
#@save
class MultiHeadAttention(tf.keras.layers.Layer):
"""Multi-head attention."""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
def call(self, queries, keys, values, valid_lens, **kwargs):
# Shape of `queries`, `keys`, or `values`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
# Shape of `valid_lens`:
# (`batch_size`,) or (`batch_size`, no. of queries)
# After transposing, shape of output `queries`, `keys`, or `values`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# On axis 0, copy the first item (scalar or vector) for
# `num_heads` times, then copy the next item, and so on
valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)
# Shape of `output`: (`batch_size` * `num_heads`, no. of queries, `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens, **kwargs)
# Shape of `output_concat`: (`batch_size`, no. of queries, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
.. raw:: html
.. raw:: html
Để cho phép tính toán song song của nhiều đầu, lớp
``MultiHeadAttention`` trên sử dụng hai hàm chuyển vị như định nghĩa bên
dưới. Cụ thể, chức năng ``transpose_output`` đảo ngược hoạt động của hàm
``transpose_qkv``.
.. raw:: html
.. raw:: html
.. code:: python
#@save
def transpose_qkv(X, num_heads):
"""Transposition for parallel computation of multiple attention heads."""
# Shape of input `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
# Shape of output `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# Shape of output `X`:
# (`batch_size`, `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
X = X.transpose(0, 2, 1, 3)
# Shape of `output`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""Reverse the operation of `transpose_qkv`."""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.transpose(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
.. raw:: html
.. raw:: html
.. code:: python
#@save
def transpose_qkv(X, num_heads):
"""Transposition for parallel computation of multiple attention heads."""
# Shape of input `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
# Shape of output `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# Shape of output `X`:
# (`batch_size`, `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
X = X.permute(0, 2, 1, 3)
# Shape of `output`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""Reverse the operation of `transpose_qkv`."""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
.. raw:: html
.. raw:: html
.. code:: python
#@save
def transpose_qkv(X, num_heads):
"""Transposition for parallel computation of multiple attention heads."""
# Shape of input `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
# Shape of output `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_heads`,
# `num_hiddens` / `num_heads`)
X = tf.reshape(X, shape=(X.shape[0], X.shape[1], num_heads, -1))
# Shape of output `X`:
# (`batch_size`, `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
X = tf.transpose(X, perm=(0, 2, 1, 3))
# Shape of `output`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3]))
#@save
def transpose_output(X, num_heads):
"""Reverse the operation of `transpose_qkv`."""
X = tf.reshape(X, shape=(-1, num_heads, X.shape[1], X.shape[2]))
X = tf.transpose(X, perm=(0, 2, 1, 3))
return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1))
.. raw:: html
.. raw:: html
Hãy để chúng tôi test our implemented ``MultiHeadAttention`` class sử
dụng một ví dụ đồ chơi trong đó các phím và giá trị giống nhau. Kết quả
là, hình dạng của đầu ra chú ý nhiều đầu là (``batch_size``,
``num_queries``, ``num_hiddens``).
.. raw:: html
.. raw:: html
.. code:: python
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
.. parsed-literal::
:class: output
(2, 4, 100)
.. raw:: html
.. raw:: html
.. code:: python
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
.. parsed-literal::
:class: output
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
.. code:: python
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
.. parsed-literal::
:class: output
torch.Size([2, 4, 100])
.. raw:: html
.. raw:: html
.. code:: python
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens, training=False).shape
.. parsed-literal::
:class: output
TensorShape([2, 4, 100])
.. raw:: html
.. raw:: html
Tóm tắt
-------
- Sự chú ý nhiều đầu kết hợp kiến thức về cùng một tập hợp sự chú ý
thông qua các không gian con đại diện khác nhau của các truy vấn,
khóa và giá trị.
- Để tính toán nhiều đầu của sự chú ý nhiều đầu song song, cần phải
thao tác tensor thích hợp.
Bài tập
-------
1. Hình dung trọng lượng chú ý của nhiều đầu trong thí nghiệm này.
2. Giả sử rằng chúng ta có một mô hình được đào tạo dựa trên sự chú ý
nhiều đầu và chúng ta muốn cắt tỉa những đầu chú ý ít quan trọng nhất
để tăng tốc độ dự đoán. Làm thế nào chúng ta có thể thiết kế các thí
nghiệm để đo lường tầm quan trọng của một đầu chú ý?
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html