.. _sec_multihead-attention: Chú ý nhiều đầu =============== Trong thực tế, với cùng một tập hợp các truy vấn, khóa và giá trị, chúng tôi có thể muốn mô hình của chúng tôi kết hợp kiến thức từ các hành vi khác nhau của cùng một cơ chế chú ý, chẳng hạn như nắm bắt các phụ thuộc của các phạm vi khác nhau (ví dụ: phạm vi ngắn hơn so với phạm vi dài hơn) trong một chuỗi. Do đó, có thể có lợi khi cho phép cơ chế chú ý của chúng ta cùng sử dụng các không gian con đại diện khác nhau của các truy vấn, khóa và giá trị. Để kết thúc này, thay vì thực hiện một tập hợp chú ý duy nhất, các truy vấn, khóa và giá trị có thể được chuyển đổi với :math:`h` các phép chiếu tuyến tính học độc lập. Sau đó, các truy vấn, khóa và giá trị dự kiến :math:`h` này được đưa vào sự chú ý chung song song. Cuối cùng, :math:`h` đầu ra tập hợp chú ý được nối và chuyển đổi với một phép chiếu tuyến tính khác đã học để tạo ra đầu ra cuối cùng. Thiết kế này được gọi là \* sự chú ý đa đầu\ *, trong đó mỗi đầu ra tập hợp chú ý :math:`h` là đầu * :cite:`Vaswani.Shazeer.Parmar.ea.2017`. Sử dụng các lớp được kết nối hoàn toàn để thực hiện các biến đổi tuyến tính có thể học được, :numref:`fig_multi-head-attention` mô tả sự chú ý nhiều đầu. .. _fig_multi-head-attention: .. figure:: ../img/multi-head-attention.svg Multi-head attention, where multiple heads are concatenated then linearly transformed. Mô hình ------- Trước khi cung cấp việc thực hiện sự chú ý nhiều đầu, chúng ta hãy chính thức hóa mô hình này một cách toán học. Đưa ra một truy vấn :math:`\mathbf{q} \in \mathbb{R}^{d_q}`, một khóa :math:`\mathbf{k} \in \mathbb{R}^{d_k}`, và một giá trị :math:`\mathbf{v} \in \mathbb{R}^{d_v}`, mỗi đầu chú ý :math:`\mathbf{h}_i` (:math:`i = 1, \ldots, h`) được tính là .. math:: \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v}, trong đó các thông số có thể học được :math:`\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}`, :math:`\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}` và :math:`\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}`, và :math:`f` là sự chú ý tập hợp, chẳng hạn như sự chú ý của phụ gia và sự chú ý của sản phẩm điểm thu nhỏ trong :numref:`sec_attention-scoring-functions`. Đầu ra chú ý nhiều đầu là một chuyển đổi tuyến tính khác thông qua các tham số có thể học được :math:`\mathbf W_o\in\mathbb R^{p_o\times h p_v}` của nối :math:`h` đầu: .. math:: \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}. Dựa trên thiết kế này, mỗi đầu có thể tham dự các phần khác nhau của đầu vào. Các chức năng phức tạp hơn mức trung bình có trọng số đơn giản có thể được thể hiện. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python import math from mxnet import autograd, np, npx from mxnet.gluon import nn from d2l import mxnet as d2l npx.set_np() .. raw:: html
.. raw:: html
.. code:: python import math import torch from torch import nn from d2l import torch as d2l .. raw:: html
.. raw:: html
.. code:: python import tensorflow as tf from d2l import tensorflow as d2l .. raw:: html
.. raw:: html
Thực hiện --------- Trong triển khai của chúng tôi, chúng tôi chọn sự chú ý của sản phẩm điểm thu nhỏ cho mỗi đầu của sự chú ý nhiều đầu. Để tránh tăng trưởng đáng kể chi phí tính toán và chi phí tham số hóa, chúng tôi đặt :math:`p_q = p_k = p_v = p_o / h`. Lưu ý rằng các đầu :math:`h` có thể được tính toán song song nếu chúng ta đặt số lượng đầu ra của các biến đổi tuyến tính cho truy vấn, khóa và giá trị thành :math:`p_q h = p_k h = p_v h = p_o`. Trong việc triển khai sau đây, :math:`p_o` được chỉ định thông qua đối số ``num_hiddens``. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python #@save class MultiHeadAttention(nn.Block): """Multi-head attention.""" def __init__(self, num_hiddens, num_heads, dropout, use_bias=False, **kwargs): super(MultiHeadAttention, self).__init__(**kwargs) self.num_heads = num_heads self.attention = d2l.DotProductAttention(dropout) self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False) self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False) self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False) self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False) def forward(self, queries, keys, values, valid_lens): # Shape of `queries`, `keys`, or `values`: # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`) # Shape of `valid_lens`: # (`batch_size`,) or (`batch_size`, no. of queries) # After transposing, shape of output `queries`, `keys`, or `values`: # (`batch_size` * `num_heads`, no. of queries or key-value pairs, # `num_hiddens` / `num_heads`) queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads) if valid_lens is not None: # On axis 0, copy the first item (scalar or vector) for # `num_heads` times, then copy the next item, and so on valid_lens = valid_lens.repeat(self.num_heads, axis=0) # Shape of `output`: (`batch_size` * `num_heads`, no. of queries, # `num_hiddens` / `num_heads`) output = self.attention(queries, keys, values, valid_lens) # Shape of `output_concat`: # (`batch_size`, no. of queries, `num_hiddens`) output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat) .. raw:: html
.. raw:: html
.. code:: python #@save class MultiHeadAttention(nn.Module): """Multi-head attention.""" def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs): super(MultiHeadAttention, self).__init__(**kwargs) self.num_heads = num_heads self.attention = d2l.DotProductAttention(dropout) self.W_q = nn.Linear(query_size, num_hiddens, bias=bias) self.W_k = nn.Linear(key_size, num_hiddens, bias=bias) self.W_v = nn.Linear(value_size, num_hiddens, bias=bias) self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias) def forward(self, queries, keys, values, valid_lens): # Shape of `queries`, `keys`, or `values`: # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`) # Shape of `valid_lens`: # (`batch_size`,) or (`batch_size`, no. of queries) # After transposing, shape of output `queries`, `keys`, or `values`: # (`batch_size` * `num_heads`, no. of queries or key-value pairs, # `num_hiddens` / `num_heads`) queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads) if valid_lens is not None: # On axis 0, copy the first item (scalar or vector) for # `num_heads` times, then copy the next item, and so on valid_lens = torch.repeat_interleave( valid_lens, repeats=self.num_heads, dim=0) # Shape of `output`: (`batch_size` * `num_heads`, no. of queries, # `num_hiddens` / `num_heads`) output = self.attention(queries, keys, values, valid_lens) # Shape of `output_concat`: # (`batch_size`, no. of queries, `num_hiddens`) output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat) .. raw:: html
.. raw:: html
.. code:: python #@save class MultiHeadAttention(tf.keras.layers.Layer): """Multi-head attention.""" def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs): super().__init__(**kwargs) self.num_heads = num_heads self.attention = d2l.DotProductAttention(dropout) self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias) self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias) self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias) self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias) def call(self, queries, keys, values, valid_lens, **kwargs): # Shape of `queries`, `keys`, or `values`: # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`) # Shape of `valid_lens`: # (`batch_size`,) or (`batch_size`, no. of queries) # After transposing, shape of output `queries`, `keys`, or `values`: # (`batch_size` * `num_heads`, no. of queries or key-value pairs, # `num_hiddens` / `num_heads`) queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads) if valid_lens is not None: # On axis 0, copy the first item (scalar or vector) for # `num_heads` times, then copy the next item, and so on valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0) # Shape of `output`: (`batch_size` * `num_heads`, no. of queries, `num_hiddens` / `num_heads`) output = self.attention(queries, keys, values, valid_lens, **kwargs) # Shape of `output_concat`: (`batch_size`, no. of queries, `num_hiddens`) output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat) .. raw:: html
.. raw:: html
Để cho phép tính toán song song của nhiều đầu, lớp ``MultiHeadAttention`` trên sử dụng hai hàm chuyển vị như định nghĩa bên dưới. Cụ thể, chức năng ``transpose_output`` đảo ngược hoạt động của hàm ``transpose_qkv``. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python #@save def transpose_qkv(X, num_heads): """Transposition for parallel computation of multiple attention heads.""" # Shape of input `X`: # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`). # Shape of output `X`: # (`batch_size`, no. of queries or key-value pairs, `num_heads`, # `num_hiddens` / `num_heads`) X = X.reshape(X.shape[0], X.shape[1], num_heads, -1) # Shape of output `X`: # (`batch_size`, `num_heads`, no. of queries or key-value pairs, # `num_hiddens` / `num_heads`) X = X.transpose(0, 2, 1, 3) # Shape of `output`: # (`batch_size` * `num_heads`, no. of queries or key-value pairs, # `num_hiddens` / `num_heads`) return X.reshape(-1, X.shape[2], X.shape[3]) #@save def transpose_output(X, num_heads): """Reverse the operation of `transpose_qkv`.""" X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) X = X.transpose(0, 2, 1, 3) return X.reshape(X.shape[0], X.shape[1], -1) .. raw:: html
.. raw:: html
.. code:: python #@save def transpose_qkv(X, num_heads): """Transposition for parallel computation of multiple attention heads.""" # Shape of input `X`: # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`). # Shape of output `X`: # (`batch_size`, no. of queries or key-value pairs, `num_heads`, # `num_hiddens` / `num_heads`) X = X.reshape(X.shape[0], X.shape[1], num_heads, -1) # Shape of output `X`: # (`batch_size`, `num_heads`, no. of queries or key-value pairs, # `num_hiddens` / `num_heads`) X = X.permute(0, 2, 1, 3) # Shape of `output`: # (`batch_size` * `num_heads`, no. of queries or key-value pairs, # `num_hiddens` / `num_heads`) return X.reshape(-1, X.shape[2], X.shape[3]) #@save def transpose_output(X, num_heads): """Reverse the operation of `transpose_qkv`.""" X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) X = X.permute(0, 2, 1, 3) return X.reshape(X.shape[0], X.shape[1], -1) .. raw:: html
.. raw:: html
.. code:: python #@save def transpose_qkv(X, num_heads): """Transposition for parallel computation of multiple attention heads.""" # Shape of input `X`: # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`). # Shape of output `X`: # (`batch_size`, no. of queries or key-value pairs, `num_heads`, # `num_hiddens` / `num_heads`) X = tf.reshape(X, shape=(X.shape[0], X.shape[1], num_heads, -1)) # Shape of output `X`: # (`batch_size`, `num_heads`, no. of queries or key-value pairs, # `num_hiddens` / `num_heads`) X = tf.transpose(X, perm=(0, 2, 1, 3)) # Shape of `output`: # (`batch_size` * `num_heads`, no. of queries or key-value pairs, # `num_hiddens` / `num_heads`) return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3])) #@save def transpose_output(X, num_heads): """Reverse the operation of `transpose_qkv`.""" X = tf.reshape(X, shape=(-1, num_heads, X.shape[1], X.shape[2])) X = tf.transpose(X, perm=(0, 2, 1, 3)) return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1)) .. raw:: html
.. raw:: html
Hãy để chúng tôi test our implemented ``MultiHeadAttention`` class sử dụng một ví dụ đồ chơi trong đó các phím và giá trị giống nhau. Kết quả là, hình dạng của đầu ra chú ý nhiều đầu là (``batch_size``, ``num_queries``, ``num_hiddens``). .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_heads, 0.5) attention.initialize() batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, np.array([3, 2]) X = np.ones((batch_size, num_queries, num_hiddens)) Y = np.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens).shape .. parsed-literal:: :class: output (2, 4, 100) .. raw:: html
.. raw:: html
.. code:: python num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5) attention.eval() .. parsed-literal:: :class: output MultiHeadAttention( (attention): DotProductAttention( (dropout): Dropout(p=0.5, inplace=False) ) (W_q): Linear(in_features=100, out_features=100, bias=False) (W_k): Linear(in_features=100, out_features=100, bias=False) (W_v): Linear(in_features=100, out_features=100, bias=False) (W_o): Linear(in_features=100, out_features=100, bias=False) ) .. code:: python batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2]) X = torch.ones((batch_size, num_queries, num_hiddens)) Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens).shape .. parsed-literal:: :class: output torch.Size([2, 4, 100]) .. raw:: html
.. raw:: html
.. code:: python num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5) batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, tf.constant([3, 2]) X = tf.ones((batch_size, num_queries, num_hiddens)) Y = tf.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens, training=False).shape .. parsed-literal:: :class: output TensorShape([2, 4, 100]) .. raw:: html
.. raw:: html
Tóm tắt ------- - Sự chú ý nhiều đầu kết hợp kiến thức về cùng một tập hợp sự chú ý thông qua các không gian con đại diện khác nhau của các truy vấn, khóa và giá trị. - Để tính toán nhiều đầu của sự chú ý nhiều đầu song song, cần phải thao tác tensor thích hợp. Bài tập ------- 1. Hình dung trọng lượng chú ý của nhiều đầu trong thí nghiệm này. 2. Giả sử rằng chúng ta có một mô hình được đào tạo dựa trên sự chú ý nhiều đầu và chúng ta muốn cắt tỉa những đầu chú ý ít quan trọng nhất để tăng tốc độ dự đoán. Làm thế nào chúng ta có thể thiết kế các thí nghiệm để đo lường tầm quan trọng của một đầu chú ý? .. raw:: html
mxnetpytorch
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html