10.5. Chú ý nhiều đầu¶
Trong thực tế, với cùng một tập hợp các truy vấn, khóa và giá trị, chúng tôi có thể muốn mô hình của chúng tôi kết hợp kiến thức từ các hành vi khác nhau của cùng một cơ chế chú ý, chẳng hạn như nắm bắt các phụ thuộc của các phạm vi khác nhau (ví dụ: phạm vi ngắn hơn so với phạm vi dài hơn) trong một chuỗi. Do đó, có thể có lợi khi cho phép cơ chế chú ý của chúng ta cùng sử dụng các không gian con đại diện khác nhau của các truy vấn, khóa và giá trị.
Để kết thúc này, thay vì thực hiện một tập hợp chú ý duy nhất, các truy vấn, khóa và giá trị có thể được chuyển đổi với \(h\) các phép chiếu tuyến tính học độc lập. Sau đó, các truy vấn, khóa và giá trị dự kiến \(h\) này được đưa vào sự chú ý chung song song. Cuối cùng, \(h\) đầu ra tập hợp chú ý được nối và chuyển đổi với một phép chiếu tuyến tính khác đã học để tạo ra đầu ra cuối cùng. Thiết kế này được gọi là * sự chú ý đa đầu*, trong đó mỗi đầu ra tập hợp chú ý \(h\) là đầu * [Vaswani et al., 2017]. Sử dụng các lớp được kết nối hoàn toàn để thực hiện các biến đổi tuyến tính có thể học được, Fig. 10.5.1 mô tả sự chú ý nhiều đầu.
Fig. 10.5.1 Multi-head attention, where multiple heads are concatenated then linearly transformed.¶
10.5.1. Mô hình¶
Trước khi cung cấp việc thực hiện sự chú ý nhiều đầu, chúng ta hãy chính thức hóa mô hình này một cách toán học. Đưa ra một truy vấn \(\mathbf{q} \in \mathbb{R}^{d_q}\), một khóa \(\mathbf{k} \in \mathbb{R}^{d_k}\), và một giá trị \(\mathbf{v} \in \mathbb{R}^{d_v}\), mỗi đầu chú ý \(\mathbf{h}_i\) (\(i = 1, \ldots, h\)) được tính là
trong đó các thông số có thể học được \(\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}\), \(\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}\) và \(\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}\), và \(f\) là sự chú ý tập hợp, chẳng hạn như sự chú ý của phụ gia và sự chú ý của sản phẩm điểm thu nhỏ trong Section 10.3. Đầu ra chú ý nhiều đầu là một chuyển đổi tuyến tính khác thông qua các tham số có thể học được \(\mathbf W_o\in\mathbb R^{p_o\times h p_v}\) của nối \(h\) đầu:
Dựa trên thiết kế này, mỗi đầu có thể tham dự các phần khác nhau của đầu vào. Các chức năng phức tạp hơn mức trung bình có trọng số đơn giản có thể được thể hiện.
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
import math
import torch
from torch import nn
from d2l import torch as d2l
import tensorflow as tf
from d2l import tensorflow as d2l
10.5.2. Thực hiện¶
Trong triển khai của chúng tôi, chúng tôi chọn sự chú ý của sản phẩm
điểm thu nhỏ cho mỗi đầu của sự chú ý nhiều đầu. Để tránh tăng trưởng
đáng kể chi phí tính toán và chi phí tham số hóa, chúng tôi đặt
\(p_q = p_k = p_v = p_o / h\). Lưu ý rằng các đầu \(h\) có thể
được tính toán song song nếu chúng ta đặt số lượng đầu ra của các biến
đổi tuyến tính cho truy vấn, khóa và giá trị thành
\(p_q h = p_k h = p_v h = p_o\). Trong việc triển khai sau đây,
\(p_o\) được chỉ định thông qua đối số num_hiddens
.
#@save
class MultiHeadAttention(nn.Block):
"""Multi-head attention."""
def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
**kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
def forward(self, queries, keys, values, valid_lens):
# Shape of `queries`, `keys`, or `values`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
# Shape of `valid_lens`:
# (`batch_size`,) or (`batch_size`, no. of queries)
# After transposing, shape of output `queries`, `keys`, or `values`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# On axis 0, copy the first item (scalar or vector) for
# `num_heads` times, then copy the next item, and so on
valid_lens = valid_lens.repeat(self.num_heads, axis=0)
# Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# Shape of `output_concat`:
# (`batch_size`, no. of queries, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
#@save
class MultiHeadAttention(nn.Module):
"""Multi-head attention."""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# Shape of `queries`, `keys`, or `values`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
# Shape of `valid_lens`:
# (`batch_size`,) or (`batch_size`, no. of queries)
# After transposing, shape of output `queries`, `keys`, or `values`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# On axis 0, copy the first item (scalar or vector) for
# `num_heads` times, then copy the next item, and so on
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# Shape of `output_concat`:
# (`batch_size`, no. of queries, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
#@save
class MultiHeadAttention(tf.keras.layers.Layer):
"""Multi-head attention."""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
def call(self, queries, keys, values, valid_lens, **kwargs):
# Shape of `queries`, `keys`, or `values`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
# Shape of `valid_lens`:
# (`batch_size`,) or (`batch_size`, no. of queries)
# After transposing, shape of output `queries`, `keys`, or `values`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# On axis 0, copy the first item (scalar or vector) for
# `num_heads` times, then copy the next item, and so on
valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)
# Shape of `output`: (`batch_size` * `num_heads`, no. of queries, `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens, **kwargs)
# Shape of `output_concat`: (`batch_size`, no. of queries, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
Để cho phép tính toán song song của nhiều đầu, lớp
MultiHeadAttention
trên sử dụng hai hàm chuyển vị như định nghĩa bên
dưới. Cụ thể, chức năng transpose_output
đảo ngược hoạt động của hàm
transpose_qkv
.
#@save
def transpose_qkv(X, num_heads):
"""Transposition for parallel computation of multiple attention heads."""
# Shape of input `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
# Shape of output `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# Shape of output `X`:
# (`batch_size`, `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
X = X.transpose(0, 2, 1, 3)
# Shape of `output`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""Reverse the operation of `transpose_qkv`."""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.transpose(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
#@save
def transpose_qkv(X, num_heads):
"""Transposition for parallel computation of multiple attention heads."""
# Shape of input `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
# Shape of output `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# Shape of output `X`:
# (`batch_size`, `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
X = X.permute(0, 2, 1, 3)
# Shape of `output`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""Reverse the operation of `transpose_qkv`."""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
#@save
def transpose_qkv(X, num_heads):
"""Transposition for parallel computation of multiple attention heads."""
# Shape of input `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
# Shape of output `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_heads`,
# `num_hiddens` / `num_heads`)
X = tf.reshape(X, shape=(X.shape[0], X.shape[1], num_heads, -1))
# Shape of output `X`:
# (`batch_size`, `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
X = tf.transpose(X, perm=(0, 2, 1, 3))
# Shape of `output`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3]))
#@save
def transpose_output(X, num_heads):
"""Reverse the operation of `transpose_qkv`."""
X = tf.reshape(X, shape=(-1, num_heads, X.shape[1], X.shape[2]))
X = tf.transpose(X, perm=(0, 2, 1, 3))
return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1))
Hãy để chúng tôi test our implemented MultiHeadAttention
class sử
dụng một ví dụ đồ chơi trong đó các phím và giá trị giống nhau. Kết quả
là, hình dạng của đầu ra chú ý nhiều đầu là (batch_size
,
num_queries
, num_hiddens
).
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
(2, 4, 100)
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens, training=False).shape
TensorShape([2, 4, 100])
10.5.3. Tóm tắt¶
Sự chú ý nhiều đầu kết hợp kiến thức về cùng một tập hợp sự chú ý thông qua các không gian con đại diện khác nhau của các truy vấn, khóa và giá trị.
Để tính toán nhiều đầu của sự chú ý nhiều đầu song song, cần phải thao tác tensor thích hợp.
10.5.4. Bài tập¶
Hình dung trọng lượng chú ý của nhiều đầu trong thí nghiệm này.
Giả sử rằng chúng ta có một mô hình được đào tạo dựa trên sự chú ý nhiều đầu và chúng ta muốn cắt tỉa những đầu chú ý ít quan trọng nhất để tăng tốc độ dự đoán. Làm thế nào chúng ta có thể thiết kế các thí nghiệm để đo lường tầm quan trọng của một đầu chú ý?