.. _sec_seq2seq_attention: Bahdanau Chú ý ============== Chúng tôi đã nghiên cứu vấn đề dịch máy trong :numref:`sec_seq2seq`, nơi chúng tôi thiết kế một kiến trúc bộ mã hóa-giải mã dựa trên hai RNNs cho trình tự để học trình tự. Cụ thể, bộ mã hóa RNN biến một chuỗi có độ dài biến đổi thành một biến ngữ cảnh hình dạng cố định, sau đó bộ giải mã RNN tạo ra mã thông báo chuỗi đầu ra (mục tiêu) theo mã thông báo dựa trên các token được tạo ra và biến ngữ cảnh. Tuy nhiên, mặc dù không phải tất cả các mã thông báo đầu vào (nguồn) đều hữu ích cho việc giải mã một mã thông báo nhất định, biến ngữ cảnh *same* mã hóa toàn bộ chuỗi đầu vào vẫn được sử dụng ở mỗi bước giải mã. Trong một thách thức riêng biệt nhưng liên quan về thế hệ chữ viết tay cho một chuỗi văn bản nhất định, Graves đã thiết kế một mô hình chú ý khác biệt để căn chỉnh các ký tự văn bản với dấu vết bút dài hơn nhiều, trong đó căn chỉnh chỉ di chuyển theo một hướng :cite:`Graves.2013`. Lấy cảm hứng từ ý tưởng học cách căn chỉnh, Bahdanau et al. đề xuất một mô hình chú ý khác biệt mà không có giới hạn liên kết một chiều nghiêm trọng :cite:`Bahdanau.Cho.Bengio.2014`. Khi dự đoán một mã thông báo, nếu không phải tất cả các token đầu vào đều có liên quan, mô hình sẽ căn chỉnh (hoặc tham dự) chỉ với các phần của chuỗi đầu vào có liên quan đến dự đoán hiện tại. Điều này đạt được bằng cách coi biến ngữ cảnh như một đầu ra của sự chú ý pooling. Mô hình ------- Khi mô tả sự chú ý của Bahdanau cho bộ giải mã RNN bên dưới, chúng tôi sẽ làm theo cùng một ký hiệu trong :numref:`sec_seq2seq`. Mô hình dựa trên sự chú ý mới giống như trong :numref:`sec_seq2seq` ngoại trừ biến ngữ cảnh :math:`\mathbf{c}` trong :eq:`eq_seq2seq_s_t` được thay thế bằng :math:`\mathbf{c}_{t'}` tại bất kỳ bước thời gian giải mã nào :math:`t'`. Giả sử có :math:`T` token trong chuỗi đầu vào, biến ngữ cảnh tại bước thời gian giải mã :math:`t'` là đầu ra của sự chú ý pooling: .. math:: \mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t, trong đó bộ giải mã ẩn trạng thái :math:`\mathbf{s}_{t' - 1}` tại thời điểm bước :math:`t' - 1` là truy vấn và các trạng thái ẩn mã hóa :math:`\mathbf{h}_t` là cả các phím và giá trị, và trọng lượng chú ý :math:`\alpha` được tính như trong :eq:`eq_attn-scoring-alpha` bằng cách sử dụng chức năng chấm điểm chú ý phụ gia được xác định bởi :eq:`eq_additive-attn`. Hơi khác so với kiến trúc mã hóa giải mã RNN vani trong :numref:`fig_seq2seq_details`, kiến trúc tương tự với sự chú ý Bahdanau được mô tả trong :numref:`fig_s2s_attention_details`. .. _fig_s2s_attention_details: .. figure:: ../img/seq2seq-attention-details.svg Layers in an RNN encoder-decoder model with Bahdanau attention. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python from mxnet import np, npx from mxnet.gluon import nn, rnn from d2l import mxnet as d2l npx.set_np() .. raw:: html
.. raw:: html
.. code:: python import torch from torch import nn from d2l import torch as d2l .. raw:: html
.. raw:: html
.. code:: python import tensorflow as tf from d2l import tensorflow as d2l .. raw:: html
.. raw:: html
Xác định bộ giải mã với sự chú ý -------------------------------- Để thực hiện bộ giải mã mã RNN với sự chú ý của Bahdanau, chúng ta chỉ cần xác định lại bộ giải mã. Để hình dung các trọng lượng chú ý đã học được thuận tiện hơn, lớp ``AttentionDecoder`` sau định nghĩa giao diện cơ bản cho bộ giải mã với cơ chế chú ý. .. code:: python #@save class AttentionDecoder(d2l.Decoder): """The base attention-based decoder interface.""" def __init__(self, **kwargs): super(AttentionDecoder, self).__init__(**kwargs) @property def attention_weights(self): raise NotImplementedError Bây giờ chúng ta hãy triển khai bộ giải mã RNN với sự chú ý Bahdanau trong lớp ``Seq2SeqAttentionDecoder`` sau. Trạng thái của bộ giải mã được khởi tạo với (i) các trạng thái ẩn lớp cuối của bộ mã hóa ở tất cả các bước thời gian (như các phím và giá trị của sự chú ý); (ii) bộ mã hóa tất cả các lớp ẩn trạng thái ở bước thời gian cuối cùng (để khởi tạo trạng thái ẩn của bộ giải mã); và (iii) bộ mã hóa độ dài hợp lệ (để loại trừ các thẻ đệm trong tập hợp chú ý). Tại mỗi bước thời gian giải mã, trạng thái ẩn lớp cuối của bộ giải mã ở bước thời gian trước đó được sử dụng làm truy vấn của sự chú ý. Kết quả là, cả đầu ra chú ý và nhúng đầu vào được nối làm đầu vào của bộ giải mã RNN. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python class Seq2SeqAttentionDecoder(AttentionDecoder): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs): super(Seq2SeqAttentionDecoder, self).__init__(**kwargs) self.attention = d2l.AdditiveAttention(num_hiddens, dropout) self.embedding = nn.Embedding(vocab_size, embed_size) self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout) self.dense = nn.Dense(vocab_size, flatten=False) def init_state(self, enc_outputs, enc_valid_lens, *args): # Shape of `outputs`: (`num_steps`, `batch_size`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, # `num_hiddens`) outputs, hidden_state = enc_outputs return (outputs.swapaxes(0, 1), hidden_state, enc_valid_lens) def forward(self, X, state): # Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, # `num_hiddens`) enc_outputs, hidden_state, enc_valid_lens = state # Shape of the output `X`: (`num_steps`, `batch_size`, `embed_size`) X = self.embedding(X).swapaxes(0, 1) outputs, self._attention_weights = [], [] for x in X: # Shape of `query`: (`batch_size`, 1, `num_hiddens`) query = np.expand_dims(hidden_state[0][-1], axis=1) # Shape of `context`: (`batch_size`, 1, `num_hiddens`) context = self.attention( query, enc_outputs, enc_outputs, enc_valid_lens) # Concatenate on the feature dimension x = np.concatenate((context, np.expand_dims(x, axis=1)), axis=-1) # Reshape `x` as (1, `batch_size`, `embed_size` + `num_hiddens`) out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state) outputs.append(out) self._attention_weights.append(self.attention.attention_weights) # After fully-connected layer transformation, shape of `outputs`: # (`num_steps`, `batch_size`, `vocab_size`) outputs = self.dense(np.concatenate(outputs, axis=0)) return outputs.swapaxes(0, 1), [enc_outputs, hidden_state, enc_valid_lens] @property def attention_weights(self): return self._attention_weights .. raw:: html
.. raw:: html
.. code:: python class Seq2SeqAttentionDecoder(AttentionDecoder): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs): super(Seq2SeqAttentionDecoder, self).__init__(**kwargs) self.attention = d2l.AdditiveAttention( num_hiddens, num_hiddens, num_hiddens, dropout) self.embedding = nn.Embedding(vocab_size, embed_size) self.rnn = nn.GRU( embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout) self.dense = nn.Linear(num_hiddens, vocab_size) def init_state(self, enc_outputs, enc_valid_lens, *args): # Shape of `outputs`: (`num_steps`, `batch_size`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, # `num_hiddens`) outputs, hidden_state = enc_outputs return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens) def forward(self, X, state): # Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, # `num_hiddens`) enc_outputs, hidden_state, enc_valid_lens = state # Shape of the output `X`: (`num_steps`, `batch_size`, `embed_size`) X = self.embedding(X).permute(1, 0, 2) outputs, self._attention_weights = [], [] for x in X: # Shape of `query`: (`batch_size`, 1, `num_hiddens`) query = torch.unsqueeze(hidden_state[-1], dim=1) # Shape of `context`: (`batch_size`, 1, `num_hiddens`) context = self.attention( query, enc_outputs, enc_outputs, enc_valid_lens) # Concatenate on the feature dimension x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1) # Reshape `x` as (1, `batch_size`, `embed_size` + `num_hiddens`) out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state) outputs.append(out) self._attention_weights.append(self.attention.attention_weights) # After fully-connected layer transformation, shape of `outputs`: # (`num_steps`, `batch_size`, `vocab_size`) outputs = self.dense(torch.cat(outputs, dim=0)) return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens] @property def attention_weights(self): return self._attention_weights .. raw:: html
.. raw:: html
.. code:: python class Seq2SeqAttentionDecoder(AttentionDecoder): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs): super().__init__(**kwargs) self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout) self.embedding = tf.keras.layers.Embedding(vocab_size, embed_size) self.rnn = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells( [tf.keras.layers.GRUCell(num_hiddens, dropout=dropout) for _ in range(num_layers)]), return_sequences=True, return_state=True) self.dense = tf.keras.layers.Dense(vocab_size) def init_state(self, enc_outputs, enc_valid_lens, *args): # Shape of `outputs`: (`batch_size`, `num_steps`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, `num_hiddens`) outputs, hidden_state = enc_outputs return (outputs, hidden_state, enc_valid_lens) def call(self, X, state, **kwargs): # Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, `num_hiddens`) enc_outputs, hidden_state, enc_valid_lens = state # Shape of the output `X`: (`num_steps`, `batch_size`, `embed_size`) X = self.embedding(X) # Input `X` has shape: (`batch_size`, `num_steps`) X = tf.transpose(X, perm=(1, 0, 2)) outputs, self._attention_weights = [], [] for x in X: # Shape of `query`: (`batch_size`, 1, `num_hiddens`) query = tf.expand_dims(hidden_state[-1], axis=1) # Shape of `context`: (`batch_size, 1, `num_hiddens`) context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens, **kwargs) # Concatenate on the feature dimension x = tf.concat((context, tf.expand_dims(x, axis=1)), axis=-1) out = self.rnn(x, hidden_state, **kwargs) hidden_state = out[1:] outputs.append(out[0]) self._attention_weights.append(self.attention.attention_weights) # After fully-connected layer transformation, shape of `outputs`: # (`batch_size`, `num_steps`, `vocab_size`) outputs = self.dense(tf.concat(outputs, axis=1)) return outputs, [enc_outputs, hidden_state, enc_valid_lens] @property def attention_weights(self): return self._attention_weights .. raw:: html
.. raw:: html
Sau đây, chúng tôi test the implemented decoder với Bahdanau chú ý sử dụng một minibatch gồm 4 chuỗi đầu vào của 7 bước thời gian. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) encoder.initialize() decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) decoder.initialize() X = np.zeros((4, 7)) # (`batch_size`, `num_steps`) state = decoder.init_state(encoder(X), None) output, state = decoder(X, state) output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape .. parsed-literal:: :class: output ((4, 7, 10), 3, (4, 7, 16), 1, (2, 4, 16)) .. raw:: html
.. raw:: html
.. code:: python encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) encoder.eval() decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) decoder.eval() X = torch.zeros((4, 7), dtype=torch.long) # (`batch_size`, `num_steps`) state = decoder.init_state(encoder(X), None) output, state = decoder(X, state) output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape .. parsed-literal:: :class: output (torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16])) .. raw:: html
.. raw:: html
.. code:: python encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) X = tf.zeros((4, 7)) state = decoder.init_state(encoder(X, training=False), None) output, state = decoder(X, state, training=False) output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape .. parsed-literal:: :class: output (TensorShape([4, 7, 10]), 3, TensorShape([4, 7, 16]), 2, TensorShape([4, 16])) .. raw:: html
.. raw:: html
Đào tạo ------- Tương tự như :numref:`sec_seq2seq_training`, ở đây chúng tôi chỉ định hyperparemeters, khởi tạo bộ mã hóa và bộ giải mã với sự chú ý của Bahdanau và đào tạo mô hình này để dịch máy. Do cơ chế chú ý mới được thêm vào, việc đào tạo này chậm hơn nhiều so với năm :numref:`sec_seq2seq_training` mà không có cơ chế chú ý. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1 batch_size, num_steps = 64, 10 lr, num_epochs, device = 0.005, 250, d2l.try_gpu() train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps) encoder = d2l.Seq2SeqEncoder( len(src_vocab), embed_size, num_hiddens, num_layers, dropout) decoder = Seq2SeqAttentionDecoder( len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout) net = d2l.EncoderDecoder(encoder, decoder) d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device) .. parsed-literal:: :class: output loss 0.025, 2888.8 tokens/sec on gpu(0) .. figure:: output_bahdanau-attention_7f08d9_41_1.svg .. raw:: html
.. raw:: html
.. code:: python embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1 batch_size, num_steps = 64, 10 lr, num_epochs, device = 0.005, 250, d2l.try_gpu() train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps) encoder = d2l.Seq2SeqEncoder( len(src_vocab), embed_size, num_hiddens, num_layers, dropout) decoder = Seq2SeqAttentionDecoder( len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout) net = d2l.EncoderDecoder(encoder, decoder) d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device) .. parsed-literal:: :class: output loss 0.020, 5580.3 tokens/sec on cuda:0 .. figure:: output_bahdanau-attention_7f08d9_44_1.svg .. raw:: html
.. raw:: html
.. code:: python embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1 batch_size, num_steps = 64, 10 lr, num_epochs, device = 0.005, 250, d2l.try_gpu() train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps) encoder = d2l.Seq2SeqEncoder( len(src_vocab), embed_size, num_hiddens, num_layers, dropout) decoder = Seq2SeqAttentionDecoder( len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout) net = d2l.EncoderDecoder(encoder, decoder) d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device) .. parsed-literal:: :class: output loss 0.028, 510.3 tokens/sec on .. figure:: output_bahdanau-attention_7f08d9_47_1.svg .. raw:: html
.. raw:: html
Sau khi mô hình được đào tạo, chúng tôi sử dụng nó để dịch một vài câu tiếng Anh sang tiếng Pháp và tính điểm BLEU của họ. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .'] fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .'] for eng, fra in zip(engs, fras): translation, dec_attention_weight_seq = d2l.predict_seq2seq( net, eng, src_vocab, tgt_vocab, num_steps, device, True) print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}') .. parsed-literal:: :class: output go . => va !, bleu 1.000 i lost . => j'ai perdu ., bleu 1.000 he's calm . => il est riche ., bleu 0.658 i'm home . => je suis chez moi ., bleu 1.000 .. code:: python attention_weights = np.concatenate([step[0][0][0] for step in dec_attention_weight_seq], 0 ).reshape((1, 1, -1, num_steps)) .. raw:: html
.. raw:: html
.. code:: python engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .'] fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .'] for eng, fra in zip(engs, fras): translation, dec_attention_weight_seq = d2l.predict_seq2seq( net, eng, src_vocab, tgt_vocab, num_steps, device, True) print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}') .. parsed-literal:: :class: output go . => va !, bleu 1.000 i lost . => j'ai perdu ., bleu 1.000 he's calm . => il est riche ., bleu 0.658 i'm home . => je suis chez moi ., bleu 1.000 .. code:: python attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape(( 1, 1, -1, num_steps)) .. raw:: html
.. raw:: html
.. code:: python engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .'] fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .'] for eng, fra in zip(engs, fras): translation, dec_attention_weight_seq = d2l.predict_seq2seq( net, eng, src_vocab, tgt_vocab, num_steps, True) print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}') .. parsed-literal:: :class: output go . => va !, bleu 1.000 i lost . => j'ai perdu ., bleu 1.000 he's calm . => il est ., bleu 0.658 i'm home . => je suis calme ., bleu 0.512 .. code:: python attention_weights = tf.reshape( tf.concat([step[0][0][0] for step in dec_attention_weight_seq], 0), (1, 1, -1, num_steps)) .. raw:: html
.. raw:: html
Bằng cách visualizing the attention weights khi dịch câu tiếng Anh cuối cùng, chúng ta có thể thấy rằng mỗi truy vấn gán trọng lượng không thống nhất trên các cặp key-value. Nó cho thấy ở mỗi bước giải mã, các phần khác nhau của chuỗi đầu vào được tổng hợp một cách chọn lọc trong tập hợp sự chú ý. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python # Plus one to include the end-of-sequence token d2l.show_heatmaps( attention_weights[:, :, :, :len(engs[-1].split()) + 1], xlabel='Key positions', ylabel='Query positions') .. figure:: output_bahdanau-attention_7f08d9_68_0.svg .. raw:: html
.. raw:: html
.. code:: python # Plus one to include the end-of-sequence token d2l.show_heatmaps( attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(), xlabel='Key positions', ylabel='Query positions') .. figure:: output_bahdanau-attention_7f08d9_71_0.svg .. raw:: html
.. raw:: html
.. code:: python # Plus one to include the end-of-sequence token d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1], xlabel='Key posistions', ylabel='Query posistions') .. figure:: output_bahdanau-attention_7f08d9_74_0.svg .. raw:: html
.. raw:: html
Tóm tắt ------- - Khi dự đoán một mã thông báo, nếu không phải tất cả các mã thông báo đầu vào đều có liên quan, bộ giải mã RNN với sự chú ý Bahdanau tập hợp chọn lọc các phần khác nhau của chuỗi đầu vào. Điều này đạt được bằng cách coi biến ngữ cảnh như một đầu ra của sự chú ý phụ gia tập hợp. - Trong bộ giải mã RNN, sự chú ý Bahdanau xử lý trạng thái ẩn bộ giải mã ở bước thời gian trước như truy vấn và các trạng thái ẩn mã hóa ở tất cả các bước thời gian như cả các phím và giá trị. Bài tập ------- 1. Thay thế GRU bằng LSTM trong thí nghiệm. 2. Sửa đổi thí nghiệm để thay thế chức năng chấm điểm sự chú ý phụ gia bằng sản phẩm điểm thu nhỏ. Làm thế nào để nó ảnh hưởng đến hiệu quả đào tạo? .. raw:: html
mxnetpytorch
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html