.. _sec_natural-language-inference-bert:
Suy luận ngôn ngữ tự nhiên: Tinh chỉnh BERT
===========================================
Trong các phần trước của chương này, chúng tôi đã thiết kế một kiến trúc
dựa trên sự chú ý (trong
:numref:`sec_natural-language-inference-attention`) cho nhiệm vụ suy
luận ngôn ngữ tự nhiên trên tập dữ liệu SNLI (như được mô tả trong
:numref:`sec_natural-language-inference-and-dataset`). Bây giờ chúng
tôi xem lại nhiệm vụ này bằng cách tinh chỉnh BERT. Như đã thảo luận
trong :numref:`sec_finetuning-bert`, suy luận ngôn ngữ tự nhiên là một
bài toán phân loại cặp văn bản cấp trình tự, và tinh chỉnh BERT chỉ đòi
hỏi một kiến trúc dựa trên MLP bổ sung, như minh họa trong
:numref:`fig_nlp-map-nli-bert`.
.. _fig_nlp-map-nli-bert:
.. figure:: ../img/nlp-map-nli-bert.svg
This section feeds pretrained BERT to an MLP-based architecture for
natural language inference.
Trong phần này, chúng tôi sẽ tải xuống một phiên bản nhỏ được đào tạo
trước của BERT, sau đó tinh chỉnh nó để suy luận ngôn ngữ tự nhiên trên
bộ dữ liệu SNLI.
.. raw:: html
.. raw:: html
.. code:: python
import json
import multiprocessing
import os
from mxnet import gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
.. raw:: html
.. raw:: html
.. code:: python
import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: html
.. raw:: html
Đang tải BERT Pretrained
------------------------
Chúng tôi đã giải thích cách chuẩn bị BERT trên bộ dữ liệu WikiText-2
trong :numref:`sec_bert-dataset` và :numref:`sec_bert-pretraining`
(lưu ý rằng mô hình BERT ban đầu được đào tạo trước trên corpora lớn hơn
nhiều). Như đã thảo luận trong :numref:`sec_bert-pretraining`, mô hình
BERT ban đầu có hàng trăm triệu thông số. Trong phần sau, chúng tôi cung
cấp hai phiên bản BERT được đào tạo trước: “bert.base” lớn bằng mô hình
cơ sở BERT ban đầu đòi hỏi rất nhiều tài nguyên tính toán để tinh chỉnh,
trong khi “bert.small” là một phiên bản nhỏ để tạo điều kiện cho trình
diễn.
.. raw:: html
.. raw:: html
.. code:: python
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
'7b3820b35da691042e5d34c0971ac3edbd80d3f4')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
'a4e718a47137ccd1809c9107ab4f5edd317bae2c')
.. raw:: html
.. raw:: html
.. code:: python
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
'225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
'c72329e68a732bef0452e4b96a1c341c8910f81f')
.. raw:: html
.. raw:: html
Hoặc mô hình BERT được đào tạo trước chứa một tập tin “vocab.json” xác
định tập từ vựng và một tập tin “pretrained.params” của các tham số được
đào tạo trước. Chúng tôi thực hiện chức năng ``load_pretrained_model``
sau đây để tải các thông số BERT được đào tạo trước.
.. raw:: html
.. raw:: html
.. code:: python
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
num_heads, num_layers, dropout, max_len, devices):
data_dir = d2l.download_extract(pretrained_model)
# Define an empty vocabulary to load the predefined vocabulary
vocab = d2l.Vocab()
vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
vocab.token_to_idx = {token: idx for idx, token in enumerate(
vocab.idx_to_token)}
bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,
num_layers, dropout, max_len)
# Load pretrained BERT parameters
bert.load_parameters(os.path.join(data_dir, 'pretrained.params'),
ctx=devices)
return bert, vocab
.. raw:: html
.. raw:: html
.. code:: python
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
num_heads, num_layers, dropout, max_len, devices):
data_dir = d2l.download_extract(pretrained_model)
# Define an empty vocabulary to load the predefined vocabulary
vocab = d2l.Vocab()
vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
vocab.token_to_idx = {token: idx for idx, token in enumerate(
vocab.idx_to_token)}
bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],
ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
num_heads=4, num_layers=2, dropout=0.2,
max_len=max_len, key_size=256, query_size=256,
value_size=256, hid_in_features=256,
mlm_in_features=256, nsp_in_features=256)
# Load pretrained BERT parameters
bert.load_state_dict(torch.load(os.path.join(data_dir,
'pretrained.params')))
return bert, vocab
.. raw:: html
.. raw:: html
Để tạo điều kiện cho trình diễn trên hầu hết các máy móc, chúng tôi sẽ
tải và tinh chỉnh phiên bản nhỏ (“bert.small”) của BERT được đào tạo
trước trong phần này. Trong bài tập, chúng tôi sẽ chỉ ra cách tinh chỉnh
“bert.base” lớn hơn nhiều để cải thiện đáng kể độ chính xác của thử
nghiệm.
.. raw:: html
.. raw:: html
.. code:: python
devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
num_layers=2, dropout=0.1, max_len=512, devices=devices)
.. parsed-literal::
:class: output
Downloading ../data/bert.small.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.zip...
.. raw:: html
.. raw:: html
.. code:: python
devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
num_layers=2, dropout=0.1, max_len=512, devices=devices)
.. parsed-literal::
:class: output
Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...
.. raw:: html
.. raw:: html
Tập dữ liệu cho tinh chỉnh BERT
-------------------------------
Đối với nhiệm vụ hạ lưu suy luận ngôn ngữ tự nhiên trên tập dữ liệu
SNLI, chúng tôi xác định một lớp tập dữ liệu tùy chỉnh
``SNLIBERTDataset``. Trong mỗi ví dụ, tiền đề và giả thuyết tạo thành
một cặp chuỗi văn bản và được đóng gói thành một chuỗi đầu vào BERT như
mô tả trong :numref:`fig_bert-two-seqs`. Nhớ lại
:numref:`subsec_bert_input_rep` rằng ID phân đoạn được sử dụng để phân
biệt tiền đề và giả thuyết trong một chuỗi đầu vào BERT. Với độ dài tối
đa được xác định trước của chuỗi đầu vào BERT (``max_len``), token cuối
cùng của cặp văn bản đầu vào dài hơn sẽ bị xóa cho đến khi ``max_len``
được đáp ứng. Để tăng tốc tạo bộ dữ liệu SNLI để tinh chỉnh BERT, chúng
tôi sử dụng 4 quy trình công nhân để tạo ra các ví dụ đào tạo hoặc thử
nghiệm song song.
.. raw:: html
.. raw:: html
.. code:: python
class SNLIBERTDataset(gluon.data.Dataset):
def __init__(self, dataset, max_len, vocab=None):
all_premise_hypothesis_tokens = [[
p_tokens, h_tokens] for p_tokens, h_tokens in zip(
*[d2l.tokenize([s.lower() for s in sentences])
for sentences in dataset[:2]])]
self.labels = np.array(dataset[2])
self.vocab = vocab
self.max_len = max_len
(self.all_token_ids, self.all_segments,
self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
print('read ' + str(len(self.all_token_ids)) + ' examples')
def _preprocess(self, all_premise_hypothesis_tokens):
pool = multiprocessing.Pool(4) # Use 4 worker processes
out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
all_token_ids = [
token_ids for token_ids, segments, valid_len in out]
all_segments = [segments for token_ids, segments, valid_len in out]
valid_lens = [valid_len for token_ids, segments, valid_len in out]
return (np.array(all_token_ids, dtype='int32'),
np.array(all_segments, dtype='int32'),
np.array(valid_lens))
def _mp_worker(self, premise_hypothesis_tokens):
p_tokens, h_tokens = premise_hypothesis_tokens
self._truncate_pair_of_tokens(p_tokens, h_tokens)
tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
token_ids = self.vocab[tokens] + [self.vocab['
']] \
* (self.max_len - len(tokens))
segments = segments + [0] * (self.max_len - len(segments))
valid_len = len(tokens)
return token_ids, segments, valid_len
def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
# Reserve slots for '', '', and '' tokens for the BERT
# input
while len(p_tokens) + len(h_tokens) > self.max_len - 3:
if len(p_tokens) > len(h_tokens):
p_tokens.pop()
else:
h_tokens.pop()
def __getitem__(self, idx):
return (self.all_token_ids[idx], self.all_segments[idx],
self.valid_lens[idx]), self.labels[idx]
def __len__(self):
return len(self.all_token_ids)
.. raw:: html
.. raw:: html
.. code:: python
class SNLIBERTDataset(torch.utils.data.Dataset):
def __init__(self, dataset, max_len, vocab=None):
all_premise_hypothesis_tokens = [[
p_tokens, h_tokens] for p_tokens, h_tokens in zip(
*[d2l.tokenize([s.lower() for s in sentences])
for sentences in dataset[:2]])]
self.labels = torch.tensor(dataset[2])
self.vocab = vocab
self.max_len = max_len
(self.all_token_ids, self.all_segments,
self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
print('read ' + str(len(self.all_token_ids)) + ' examples')
def _preprocess(self, all_premise_hypothesis_tokens):
pool = multiprocessing.Pool(4) # Use 4 worker processes
out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
all_token_ids = [
token_ids for token_ids, segments, valid_len in out]
all_segments = [segments for token_ids, segments, valid_len in out]
valid_lens = [valid_len for token_ids, segments, valid_len in out]
return (torch.tensor(all_token_ids, dtype=torch.long),
torch.tensor(all_segments, dtype=torch.long),
torch.tensor(valid_lens))
def _mp_worker(self, premise_hypothesis_tokens):
p_tokens, h_tokens = premise_hypothesis_tokens
self._truncate_pair_of_tokens(p_tokens, h_tokens)
tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
token_ids = self.vocab[tokens] + [self.vocab['
']] \
* (self.max_len - len(tokens))
segments = segments + [0] * (self.max_len - len(segments))
valid_len = len(tokens)
return token_ids, segments, valid_len
def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
# Reserve slots for '', '', and '' tokens for the BERT
# input
while len(p_tokens) + len(h_tokens) > self.max_len - 3:
if len(p_tokens) > len(h_tokens):
p_tokens.pop()
else:
h_tokens.pop()
def __getitem__(self, idx):
return (self.all_token_ids[idx], self.all_segments[idx],
self.valid_lens[idx]), self.labels[idx]
def __len__(self):
return len(self.all_token_ids)
.. raw:: html
.. raw:: html
Sau khi tải xuống tập dữ liệu SNLI, chúng tôi tạo ra các ví dụ đào tạo
và thử nghiệm bằng cách khởi tạo lớp ``SNLIBERTDataset``. Những ví dụ
như vậy sẽ được đọc trong minibatches trong quá trình đào tạo và thử
nghiệm suy luận ngôn ngữ tự nhiên.
.. raw:: html
.. raw:: html
.. code:: python
# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gluon.data.DataLoader(test_set, batch_size,
num_workers=num_workers)
.. parsed-literal::
:class: output
read 549367 examples
read 9824 examples
.. raw:: html
.. raw:: html
.. code:: python
# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
num_workers=num_workers)
.. parsed-literal::
:class: output
read 549367 examples
read 9824 examples
.. raw:: html
.. raw:: html
Tinh chỉnh BERT
---------------
Như :numref:`fig_bert-two-seqs` chỉ ra, tinh chỉnh BERT cho suy luận
ngôn ngữ tự nhiên chỉ yêu cầu thêm MLP bao gồm hai lớp được kết nối hoàn
toàn (xem ``self.hidden`` và ``self.output`` trong lớp
``BERTClassifier`` sau). MLP này biến đổi đại diện BERT của mã thông báo
“” đặc biệt, mã hóa thông tin của cả tiền đề và giả thuyết, thành ba đầu
ra của suy luận ngôn ngữ tự nhiên: entailment, mâu thuẫn, và trung lập.
.. raw:: html
.. raw:: html
.. code:: python
class BERTClassifier(nn.Block):
def __init__(self, bert):
super(BERTClassifier, self).__init__()
self.encoder = bert.encoder
self.hidden = bert.hidden
self.output = nn.Dense(3)
def forward(self, inputs):
tokens_X, segments_X, valid_lens_x = inputs
encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
return self.output(self.hidden(encoded_X[:, 0, :]))
.. raw:: html
.. raw:: html
.. code:: python
class BERTClassifier(nn.Module):
def __init__(self, bert):
super(BERTClassifier, self).__init__()
self.encoder = bert.encoder
self.hidden = bert.hidden
self.output = nn.Linear(256, 3)
def forward(self, inputs):
tokens_X, segments_X, valid_lens_x = inputs
encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
return self.output(self.hidden(encoded_X[:, 0, :]))
.. raw:: html
.. raw:: html
Sau đây, mô hình BERT được đào tạo trước ``bert`` được đưa vào phiên bản
``BERTClassifier`` ``net`` cho ứng dụng hạ lưu. Trong các triển khai phổ
biến của tinh chỉnh BERT, chỉ các tham số của lớp đầu ra của MLP bổ sung
(``net.output``) sẽ được học từ đầu. Tất cả các thông số của bộ mã hóa
BERT được đào tạo trước (``net.encoder``) và lớp ẩn của MLP bổ sung
(``net.hidden``) sẽ được tinh chỉnh.
.. raw:: html
.. raw:: html
.. code:: python
net = BERTClassifier(bert)
net.output.initialize(ctx=devices)
.. raw:: html
.. raw:: html
.. code:: python
net = BERTClassifier(bert)
.. raw:: html
.. raw:: html
Nhớ lại rằng trong :numref:`sec_bert` cả lớp ``MaskLM`` và lớp
``NextSentencePred`` đều có các thông số trong MLP được sử dụng của họ.
Các thông số này là một phần của những thông số trong mô hình BERT được
đào tạo trước ``bert``, và do đó là một phần của các thông số trong
``net``. Tuy nhiên, các thông số như vậy chỉ để tính toán mất mô hình
hóa ngôn ngữ đeo mặt nạ và mất dự đoán câu tiếp theo trong quá trình đào
tạo trước. Hai chức năng mất mát này không liên quan đến việc tinh chỉnh
các ứng dụng hạ lưu, do đó các thông số của MLP được sử dụng trong
``MaskLM`` và ``NextSentencePred`` không được cập nhật (staled) khi BERT
được tinh chỉnh.
Để cho phép các tham số với gradient cũ, cờ ``ignore_stale_grad=True``
được đặt trong hàm ``step`` của ``d2l.train_batch_ch13``. Chúng tôi sử
dụng chức năng này để đào tạo và đánh giá mô hình ``net`` bằng cách sử
dụng bộ đào tạo (``train_iter``) và bộ thử nghiệm (``test_iter``) của
SNLI. Do các nguồn lực tính toán hạn chế, độ chính xác đào tạo và thử
nghiệm có thể được cải thiện hơn nữa: chúng tôi để lại các cuộc thảo
luận của nó trong các bài tập.
.. raw:: html
.. raw:: html
.. code:: python
lr, num_epochs = 1e-4, 5
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
d2l.split_batch_multi_inputs)
.. parsed-literal::
:class: output
loss 0.480, train acc 0.810, test acc 0.785
6981.9 examples/sec on [gpu(0), gpu(1)]
.. figure:: output_natural-language-inference-bert_1857e6_75_1.svg
.. raw:: html
.. raw:: html
.. code:: python
lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
.. parsed-literal::
:class: output
loss 0.518, train acc 0.791, test acc 0.779
10236.7 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
.. figure:: output_natural-language-inference-bert_1857e6_78_1.svg
.. raw:: html
.. raw:: html
Tóm tắt
-------
- Chúng ta có thể tinh chỉnh mô hình BERT được đào tạo trước cho các
ứng dụng hạ nguồn, chẳng hạn như suy luận ngôn ngữ tự nhiên trên bộ
dữ liệu SNLI.
- Trong quá trình tinh chỉnh, mô hình BERT trở thành một phần của mô
hình cho ứng dụng hạ lưu. Các thông số chỉ liên quan đến mất sơ bộ sẽ
không được cập nhật trong quá trình tinh chỉnh.
Bài tập
-------
1. Tinh chỉnh một mô hình BERT được đào tạo trước lớn hơn nhiều như mô
hình cơ sở BERT ban đầu nếu tài nguyên tính toán của bạn cho phép.
Đặt các đối số trong hàm ``load_pretrained_model`` là: thay thế
'bert.small' bằng 'bert.base', tăng giá trị lần lượt là
``num_hiddens=256``, ``ffn_num_hiddens=512``, ``num_heads=4`` và
``num_layers=2`` lên 768, 3072, 12 và 12. Bằng cách tăng kỷ nguyên
tinh chỉnh (và có thể điều chỉnh các siêu tham số khác), bạn có thể
nhận được độ chính xác thử nghiệm cao hơn 0,86 không?
2. Làm thế nào để cắt ngắn một cặp chuỗi theo tỷ lệ chiều dài của chúng?
So sánh phương pháp cắt ngắn cặp này và phương pháp được sử dụng
trong lớp ``SNLIBERTDataset``. Ưu và nhược điểm của họ là gì?
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html