.. _sec_nadaraya-watson: Chú ý Pooling: Hồi quy hạt nhân Nadaraya-Watson =============================================== Bây giờ bạn đã biết các thành phần chính của các cơ chế chú ý theo khuôn khổ trong :numref:`fig_qkv`. Để tái lập lại, các tương tác giữa các truy vấn (tín hiệu ý chí) và các phím (tín hiệu không có ý định) dẫn đến *chú ý cùng*. Sự chú ý tập hợp chọn lọc các giá trị (đầu vào cảm giác) để tạo ra đầu ra. Trong phần này, chúng tôi sẽ mô tả chi tiết hơn để cung cấp cho bạn cái nhìn cấp cao về cách các cơ chế chú ý hoạt động trong thực tế. Cụ thể, mô hình hồi quy hạt nhân Nadaraya-Watson đề xuất năm 1964 là một ví dụ đơn giản nhưng đầy đủ để chứng minh máy học với các cơ chế chú ý. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python from mxnet import autograd, gluon, np, npx from mxnet.gluon import nn from d2l import mxnet as d2l npx.set_np() .. raw:: html
.. raw:: html
.. code:: python import torch from torch import nn from d2l import torch as d2l .. raw:: html
.. raw:: html
.. code:: python import tensorflow as tf from d2l import tensorflow as d2l tf.random.set_seed(seed=1322) .. raw:: html
.. raw:: html
Tạo ra bộ dữ liệu ----------------- Để giữ cho mọi thứ đơn giản, chúng ta hãy xem xét vấn đề hồi quy sau: đưa ra một tập dữ liệu của cặp đầu vào-đầu ra :math:`\{(x_1, y_1), \ldots, (x_n, y_n)\}`, làm thế nào để tìm hiểu :math:`f` để dự đoán đầu ra :math:`\hat{y} = f(x)` cho bất kỳ đầu vào mới :math:`x`? Ở đây chúng ta tạo ra một bộ dữ liệu nhân tạo theo hàm phi tuyến sau với thuật ngữ nhiễu :math:`\epsilon`: .. math:: y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon, trong đó :math:`\epsilon` tuân theo một phân phối bình thường với 0 trung bình và độ lệch chuẩn 0,5. Cả 50 ví dụ đào tạo và 50 ví dụ thử nghiệm được tạo ra. Để hình dung rõ hơn mô hình chú ý sau này, các đầu vào đào tạo được sắp xếp. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python n_train = 50 # No. of training examples x_train = np.sort(np.random.rand(n_train) * 5) # Training inputs def f(x): return 2 * np.sin(x) + x**0.8 y_train = f(x_train) + np.random.normal(0.0, 0.5, (n_train,)) # Training outputs x_test = np.arange(0, 5, 0.1) # Testing examples y_truth = f(x_test) # Ground-truth outputs for the testing examples n_test = len(x_test) # No. of testing examples n_test .. parsed-literal:: :class: output 50 .. raw:: html
.. raw:: html
.. code:: python n_train = 50 # No. of training examples x_train, _ = torch.sort(torch.rand(n_train) * 5) # Training inputs def f(x): return 2 * torch.sin(x) + x**0.8 y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # Training outputs x_test = torch.arange(0, 5, 0.1) # Testing examples y_truth = f(x_test) # Ground-truth outputs for the testing examples n_test = len(x_test) # No. of testing examples n_test .. parsed-literal:: :class: output 50 .. raw:: html
.. raw:: html
.. code:: python n_train = 50 x_train = tf.sort(tf.random.uniform(shape=(n_train,), maxval=5)) def f(x): return 2 * tf.sin(x) + x**0.8 y_train = f(x_train) + tf.random.normal((n_train,), 0.0, 0.5) # Training outputs x_test = tf.range(0, 5, 0.1) # Testing examples y_truth = f(x_test) # Ground-truth outputs for the testing examples n_test = len(x_test) # No. of testing examples n_test .. parsed-literal:: :class: output 50 .. raw:: html
.. raw:: html
Hàm sau vẽ tất cả các ví dụ đào tạo (được biểu diễn bằng các vòng tròn), hàm tạo dữ liệu đất-chân lý ``f`` không có thuật ngữ nhiễu (được dán nhãn bởi “Truth”), và hàm dự đoán đã học (được dán nhãn bởi “Pred”). .. code:: python def plot_kernel_reg(y_hat): d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'], xlim=[0, 5], ylim=[-1, 5]) d2l.plt.plot(x_train, y_train, 'o', alpha=0.5); Pooling trung bình ------------------ Chúng ta bắt đầu với sự ước tính “ngu ngốc nhất” của thế giới cho vấn đề hồi quy này: sử dụng tổng hợp trung bình đến trung bình trên tất cả các đầu ra đào tạo: .. math:: f(x) = \frac{1}{n}\sum_{i=1}^n y_i, :label: eq_avg-pooling được vẽ dưới đây. Như chúng ta có thể thấy, ước tính này thực sự không quá thông minh. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python y_hat = y_train.mean().repeat(n_test) plot_kernel_reg(y_hat) .. figure:: output_nadaraya-watson_61a333_29_0.svg .. raw:: html
.. raw:: html
.. code:: python y_hat = torch.repeat_interleave(y_train.mean(), n_test) plot_kernel_reg(y_hat) .. figure:: output_nadaraya-watson_61a333_32_0.svg .. raw:: html
.. raw:: html
.. code:: python y_hat = tf.repeat(tf.reduce_mean(y_train), repeats=n_test) plot_kernel_reg(y_hat) .. figure:: output_nadaraya-watson_61a333_35_0.svg .. raw:: html
.. raw:: html
Không tham số Chú ý Pooling --------------------------- Rõ ràng, tổng hợp trung bình bỏ qua các đầu vào :math:`x_i`. Một ý tưởng tốt hơn đã được Nadaraya :cite:`Nadaraya.1964` và Watson :cite:`Watson.1964` đề xuất để cân nhắc các đầu ra :math:`y_i` theo vị trí đầu vào của chúng: .. math:: f(x) = \sum_{i=1}^n \frac{K(x - x_i)}{\sum_{j=1}^n K(x - x_j)} y_i, :label: eq_nadaraya-watson trong đó :math:`K` là một \* kernel\ *. Các ước tính trong :eq:`eq_nadaraya-watson` được gọi là * Nadaraya-Watson kernel regression\ *. Ở đây chúng tôi sẽ không đi sâu vào chi tiết của hạt nhân. Nhớ lại khuôn khổ của các cơ chế chú ý trong :numref:`fig_qkv`. Từ quan điểm của sự chú ý, chúng ta có thể viết lại :eq:`eq_nadaraya-watson` dưới dạng tổng quát hơn* chú ý cùng\*: .. math:: f(x) = \sum_{i=1}^n \alpha(x, x_i) y_i, :label: eq_attn-pooling trong đó :math:`x` là truy vấn và :math:`(x_i, y_i)` là cặp giá trị khóa-giá trị. So sánh :eq:`eq_attn-pooling` và :eq:`eq_avg-pooling`, sự chú ý tập hợp ở đây là trung bình có trọng số của các giá trị :math:`y_i`. Trọng lượng chú ý\*\ :math:`\alpha(x, x_i)` trong :eq:`eq_attn-pooling` được gán cho giá trị tương ứng :math:`y_i` dựa trên sự tương tác giữa truy vấn :math:`x` và khóa :math:`x_i` được mô hình hóa bởi :math:`\alpha`. Đối với bất kỳ truy vấn nào, trọng lượng chú ý của nó đối với tất cả các cặp khóa-giá trị là một phân phối xác suất hợp lệ: chúng không âm và tổng hợp lên đến một. Để đạt được trực giác của sự chú ý, chỉ cần xem xét một \* Gaussian kernel\* được định nghĩa là .. math:: K(u) = \frac{1}{\sqrt{2\pi}} \exp(-\frac{u^2}{2}). Cắm hạt nhân Gaussian vào :eq:`eq_attn-pooling` và :eq:`eq_nadaraya-watson` cho .. math:: \begin{aligned} f(x) &=\sum_{i=1}^n \alpha(x, x_i) y_i\\ &= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}(x - x_i)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}(x - x_j)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i. \end{aligned} :label: eq_nadaraya-watson-gaussian Trong :eq:`eq_nadaraya-watson-gaussian`, một khóa :math:`x_i` gần với truy vấn đã cho :math:`x` sẽ nhận được *chú ý hơn* thông qua trọng lượng chú ý \* lớn hơn\* được gán cho giá trị tương ứng của khóa :math:`y_i`. Đáng chú ý, hồi quy hạt nhân Nadaraya-Watson là một mô hình không tham số; do đó :eq:`eq_nadaraya-watson-gaussian` là một ví dụ về *không tham số attention pooling*. Sau đây, chúng tôi vẽ dự đoán dựa trên mô hình chú ý không tham số này. Đường dự đoán là trơn tru và gần với sự thật mặt đất hơn so với sự thật được tạo ra bởi tổng hợp trung bình. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python # Shape of `X_repeat`: (`n_test`, `n_train`), where each row contains the # same testing inputs (i.e., same queries) X_repeat = x_test.repeat(n_train).reshape((-1, n_train)) # Note that `x_train` contains the keys. Shape of `attention_weights`: # (`n_test`, `n_train`), where each row contains attention weights to be # assigned among the values (`y_train`) given each query attention_weights = npx.softmax(-(X_repeat - x_train)**2 / 2) # Each element of `y_hat` is weighted average of values, where weights are # attention weights y_hat = np.dot(attention_weights, y_train) plot_kernel_reg(y_hat) .. figure:: output_nadaraya-watson_61a333_41_0.svg .. raw:: html
.. raw:: html
.. code:: python # Shape of `X_repeat`: (`n_test`, `n_train`), where each row contains the # same testing inputs (i.e., same queries) X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train)) # Note that `x_train` contains the keys. Shape of `attention_weights`: # (`n_test`, `n_train`), where each row contains attention weights to be # assigned among the values (`y_train`) given each query attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1) # Each element of `y_hat` is weighted average of values, where weights are # attention weights y_hat = torch.matmul(attention_weights, y_train) plot_kernel_reg(y_hat) .. figure:: output_nadaraya-watson_61a333_44_0.svg .. raw:: html
.. raw:: html
.. code:: python # Shape of `X_repeat`: (`n_test`, `n_train`), where each row contains the # same testing inputs (i.e., same queries) X_repeat = tf.repeat(tf.expand_dims(x_train, axis=0), repeats=n_train, axis=0) # Note that `x_train` contains the keys. Shape of `attention_weights`: # (`n_test`, `n_train`), where each row contains attention weights to be # assigned among the values (`y_train`) given each query attention_weights = tf.nn.softmax(-(X_repeat - tf.expand_dims(x_train, axis=1))**2/2, axis=1) # Each element of `y_hat` is weighted average of values, where weights are attention weights y_hat = tf.matmul(attention_weights, tf.expand_dims(y_train, axis=1)) plot_kernel_reg(y_hat) .. figure:: output_nadaraya-watson_61a333_47_0.svg .. raw:: html
.. raw:: html
Bây giờ chúng ta hãy nhìn vào trọng lượng chú ý. Ở đây thử nghiệm đầu vào là các truy vấn trong khi đầu vào đào tạo là chìa khóa. Vì cả hai đầu vào được sắp xếp, chúng ta có thể thấy rằng cặp khóa truy vấn càng gần, trọng lượng chú ý cao hơn nằm trong sự chú ý. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python d2l.show_heatmaps(np.expand_dims(np.expand_dims(attention_weights, 0), 0), xlabel='Sorted training inputs', ylabel='Sorted testing inputs') .. figure:: output_nadaraya-watson_61a333_53_0.svg .. raw:: html
.. raw:: html
.. code:: python d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0), xlabel='Sorted training inputs', ylabel='Sorted testing inputs') .. figure:: output_nadaraya-watson_61a333_56_0.svg .. raw:: html
.. raw:: html
.. code:: python d2l.show_heatmaps(tf.expand_dims(tf.expand_dims(attention_weights, axis=0), axis=0), xlabel='Sorted training inputs', ylabel='Sorted testing inputs') .. figure:: output_nadaraya-watson_61a333_59_0.svg .. raw:: html
.. raw:: html
\*\* Tham số chú ý Pooling\*\* ------------------------------ Hồi quy hạt nhân Nadaraya-Watson không tham số được hưởng lợi ích *nhất quán *: cung cấp đủ dữ liệu mô hình này hội tụ với giải pháp tối ưu. Tuy nhiên, chúng ta có thể dễ dàng tích hợp các thông số có thể học được vào tập hợp sự chú ý. Ví dụ, hơi khác so với :eq:`eq_nadaraya-watson-gaussian`, trong khoảng cách sau giữa truy vấn :math:`x` và khóa :math:`x_i` được nhân với một tham số có thể học được :math:`w`: .. math:: \begin{aligned}f(x) &= \sum_{i=1}^n \alpha(x, x_i) y_i \\&= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}((x - x_i)w)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}((x - x_j)w)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}((x - x_i)w)^2\right) y_i.\end{aligned} :label: eq_nadaraya-watson-gaussian-para Trong phần còn lại của phần này, chúng tôi sẽ đào tạo mô hình này bằng cách học tham số của sự chú ý trong :eq:`eq_nadaraya-watson-gaussian-para`. .. _subsec_batch_dot: Phép nhân ma trận hàng loạt ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Để tính toán sự chú ý hiệu quả hơn cho các minibatches, chúng ta có thể tận dụng các tiện ích nhân ma trận hàng loạt được cung cấp bởi các framework deep learning. Giả sử rằng minibatch đầu tiên chứa :math:`n` ma trận :math:`\mathbf{X}_1, \ldots, \mathbf{X}_n` của hình dạng :math:`a\times b`, và minibatch thứ hai chứa :math:`n` ma trận :math:`\mathbf{Y}_1, \ldots, \mathbf{Y}_n` hình dạng :math:`b\times c`. Phép nhân ma trận lô của chúng dẫn đến ma trận :math:`n` :math:`\mathbf{X}_1\mathbf{Y}_1, \ldots, \mathbf{X}_n\mathbf{Y}_n` của hình dạng :math:`a\times c`. Do đó, cho hai hàng chục hình dạng (:math:`n`, :math:`a`, :math:`b`) và (:math:`n`, :math:`b`, :math:`c`), hình dạng của sản lượng nhân ma trận lô của chúng là (:math:`n`, :math:`a`, :math:`c`) . .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python X = np.ones((2, 1, 4)) Y = np.ones((2, 4, 6)) npx.batch_dot(X, Y).shape .. parsed-literal:: :class: output (2, 1, 6) .. raw:: html
.. raw:: html
.. code:: python X = torch.ones((2, 1, 4)) Y = torch.ones((2, 4, 6)) torch.bmm(X, Y).shape .. parsed-literal:: :class: output torch.Size([2, 1, 6]) .. raw:: html
.. raw:: html
.. code:: python X = tf.ones((2, 1, 4)) Y = tf.ones((2, 4, 6)) tf.matmul(X, Y).shape .. parsed-literal:: :class: output TensorShape([2, 1, 6]) .. raw:: html
.. raw:: html
Trong bối cảnh của các cơ chế chú ý, chúng ta có thể sử dụng phép nhân ma trận minibatch để tính toán trung bình có trọng số của các giá trị trong một minibatch. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python weights = np.ones((2, 10)) * 0.1 values = np.arange(20).reshape((2, 10)) npx.batch_dot(np.expand_dims(weights, 1), np.expand_dims(values, -1)) .. parsed-literal:: :class: output array([[[ 4.5]], [[14.5]]]) .. raw:: html
.. raw:: html
.. code:: python weights = torch.ones((2, 10)) * 0.1 values = torch.arange(20.0).reshape((2, 10)) torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1)) .. parsed-literal:: :class: output tensor([[[ 4.5000]], [[14.5000]]]) .. raw:: html
.. raw:: html
.. code:: python weights = tf.ones((2, 10)) * 0.1 values = tf.reshape(tf.range(20.0), shape = (2, 10)) tf.matmul(tf.expand_dims(weights, axis=1), tf.expand_dims(values, axis=-1)).numpy() .. parsed-literal:: :class: output array([[[ 4.5]], [[14.5]]], dtype=float32) .. raw:: html
.. raw:: html
Xác định mô hình ~~~~~~~~~~~~~~~~ Sử dụng phép nhân ma trận minibatch, bên dưới chúng ta xác định phiên bản tham số của hồi quy hạt nhân Nadaraya-Watson dựa trên parametric attention pooling trong :eq:`eq_nadaraya-watson-gaussian-para`. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python class NWKernelRegression(nn.Block): def __init__(self, **kwargs): super().__init__(**kwargs) self.w = self.params.get('w', shape=(1,)) def forward(self, queries, keys, values): # Shape of the output `queries` and `attention_weights`: # (no. of queries, no. of key-value pairs) queries = queries.repeat(keys.shape[1]).reshape((-1, keys.shape[1])) self.attention_weights = npx.softmax( -((queries - keys) * self.w.data())**2 / 2) # Shape of `values`: (no. of queries, no. of key-value pairs) return npx.batch_dot(np.expand_dims(self.attention_weights, 1), np.expand_dims(values, -1)).reshape(-1) .. raw:: html
.. raw:: html
.. code:: python class NWKernelRegression(nn.Module): def __init__(self, **kwargs): super().__init__(**kwargs) self.w = nn.Parameter(torch.rand((1,), requires_grad=True)) def forward(self, queries, keys, values): # Shape of the output `queries` and `attention_weights`: # (no. of queries, no. of key-value pairs) queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1])) self.attention_weights = nn.functional.softmax( -((queries - keys) * self.w)**2 / 2, dim=1) # Shape of `values`: (no. of queries, no. of key-value pairs) return torch.bmm(self.attention_weights.unsqueeze(1), values.unsqueeze(-1)).reshape(-1) .. raw:: html
.. raw:: html
.. code:: python class NWKernelRegression(tf.keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.w = tf.Variable(initial_value=tf.random.uniform(shape=(1,))) def call(self, queries, keys, values, **kwargs): # For training queries are `x_train`. Keys are distance of taining data for each point. Values are `y_train`. # Shape of the output `queries` and `attention_weights`: (no. of queries, no. of key-value pairs) queries = tf.repeat(tf.expand_dims(queries, axis=1), repeats=keys.shape[1], axis=1) self.attention_weights = tf.nn.softmax(-((queries - keys) * self.w)**2 /2, axis =1) # Shape of `values`: (no. of queries, no. of key-value pairs) return tf.squeeze(tf.matmul(tf.expand_dims(self.attention_weights, axis=1), tf.expand_dims(values, axis=-1))) .. raw:: html
.. raw:: html
Đào tạo ~~~~~~~ Sau đây, chúng tôi chuyển đổi tập dữ liệu đào tạo thành các khóa và giá trị để đào tạo mô hình chú ý. Trong tập hợp sự chú ý tham số, bất kỳ đầu vào đào tạo nào lấy các cặp giá trị khóa từ tất cả các ví dụ đào tạo ngoại trừ chính nó để dự đoán đầu ra của nó. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python # Shape of `X_tile`: (`n_train`, `n_train`), where each column contains the # same training inputs X_tile = np.tile(x_train, (n_train, 1)) # Shape of `Y_tile`: (`n_train`, `n_train`), where each column contains the # same training outputs Y_tile = np.tile(y_train, (n_train, 1)) # Shape of `keys`: ('n_train', 'n_train' - 1) keys = X_tile[(1 - np.eye(n_train)).astype('bool')].reshape((n_train, -1)) # Shape of `values`: ('n_train', 'n_train' - 1) values = Y_tile[(1 - np.eye(n_train)).astype('bool')].reshape((n_train, -1)) .. raw:: html
.. raw:: html
.. code:: python # Shape of `X_tile`: (`n_train`, `n_train`), where each column contains the # same training inputs X_tile = x_train.repeat((n_train, 1)) # Shape of `Y_tile`: (`n_train`, `n_train`), where each column contains the # same training outputs Y_tile = y_train.repeat((n_train, 1)) # Shape of `keys`: ('n_train', 'n_train' - 1) keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1)) # Shape of `values`: ('n_train', 'n_train' - 1) values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1)) .. raw:: html
.. raw:: html
.. code:: python # Shape of `X_tile`: (`n_train`, `n_train`), where each column contains the # same training inputs X_tile = tf.repeat(tf.expand_dims(x_train, axis=0), repeats=n_train, axis=0) # Shape of `Y_tile`: (`n_train`, `n_train`), where each column contains the # same training outputs Y_tile = tf.repeat(tf.expand_dims(y_train, axis=0), repeats=n_train, axis=0) # Shape of `keys`: ('n_train', 'n_train' - 1) keys = tf.reshape(X_tile[tf.cast(1 - tf.eye(n_train), dtype=tf.bool)], shape=(n_train, -1)) # Shape of `values`: ('n_train', 'n_train' - 1) values = tf.reshape(Y_tile[tf.cast(1 - tf.eye(n_train), dtype=tf.bool)], shape=(n_train, -1)) .. raw:: html
.. raw:: html
Sử dụng sự mất mát bình phương và gốc gradient ngẫu nhiên, chúng tôi đào tạo mô hình chú ý tham số. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python net = NWKernelRegression() net.initialize() loss = gluon.loss.L2Loss() trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5}) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5]) for epoch in range(5): with autograd.record(): l = loss(net(x_train, keys, values), y_train) l.backward() trainer.step(1) print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}') animator.add(epoch + 1, float(l.sum())) .. figure:: output_nadaraya-watson_61a333_113_0.svg .. raw:: html
.. raw:: html
.. code:: python net = NWKernelRegression() loss = nn.MSELoss(reduction='none') trainer = torch.optim.SGD(net.parameters(), lr=0.5) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5]) for epoch in range(5): trainer.zero_grad() l = loss(net(x_train, keys, values), y_train) l.sum().backward() trainer.step() print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}') animator.add(epoch + 1, float(l.sum())) .. figure:: output_nadaraya-watson_61a333_116_0.svg .. raw:: html
.. raw:: html
.. code:: python net = NWKernelRegression() loss_object = tf.keras.losses.MeanSquaredError() optimizer = tf.keras.optimizers.SGD(learning_rate=0.5) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5]) for epoch in range(5): with tf.GradientTape() as t: loss = loss_object(y_train, net(x_train, keys, values)) * len(y_train) grads = t.gradient(loss, net.trainable_variables) optimizer.apply_gradients(zip(grads, net.trainable_variables)) print(f'epoch {epoch + 1}, loss {float(loss):.6f}') animator.add(epoch + 1, float(loss)) .. figure:: output_nadaraya-watson_61a333_119_0.svg .. raw:: html
.. raw:: html
Sau khi đào tạo mô hình chú ý tham số, chúng ta có thể vẽ giá của nó. Cố gắng để phù hợp với tập dữ liệu đào tạo với tiếng ồn, dòng dự đoán là ít trơn tru hơn so với đối tác không tham số của nó đã được vẽ trước đó. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python # Shape of `keys`: (`n_test`, `n_train`), where each column contains the same # training inputs (i.e., same keys) keys = np.tile(x_train, (n_test, 1)) # Shape of `value`: (`n_test`, `n_train`) values = np.tile(y_train, (n_test, 1)) y_hat = net(x_test, keys, values) plot_kernel_reg(y_hat) .. figure:: output_nadaraya-watson_61a333_125_0.svg .. raw:: html
.. raw:: html
.. code:: python # Shape of `keys`: (`n_test`, `n_train`), where each column contains the same # training inputs (i.e., same keys) keys = x_train.repeat((n_test, 1)) # Shape of `value`: (`n_test`, `n_train`) values = y_train.repeat((n_test, 1)) y_hat = net(x_test, keys, values).unsqueeze(1).detach() plot_kernel_reg(y_hat) .. figure:: output_nadaraya-watson_61a333_128_0.svg .. raw:: html
.. raw:: html
.. code:: python # Shape of `keys`: (`n_test`, `n_train`), where each column contains the same # training inputs (i.e., same keys) keys = tf.repeat(tf.expand_dims(x_train, axis=0), repeats=n_test, axis=0) # Shape of `value`: (`n_test`, `n_train`) values = tf.repeat(tf.expand_dims(y_train, axis=0), repeats=n_test, axis=0) y_hat = net(x_test, keys, values) plot_kernel_reg(y_hat) .. figure:: output_nadaraya-watson_61a333_131_0.svg .. raw:: html
.. raw:: html
So sánh với sự chú ý không tham số, khu vực có trọng lượng chú ý lớn trở nên sắc bén hơn trong cài đặt có thể học được và tham số. .. raw:: html
mxnetpytorchtensorflow
.. raw:: html
.. code:: python d2l.show_heatmaps(np.expand_dims(np.expand_dims(net.attention_weights, 0), 0), xlabel='Sorted training inputs', ylabel='Sorted testing inputs') .. figure:: output_nadaraya-watson_61a333_137_0.svg .. raw:: html
.. raw:: html
.. code:: python d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0), xlabel='Sorted training inputs', ylabel='Sorted testing inputs') .. figure:: output_nadaraya-watson_61a333_140_0.svg .. raw:: html
.. raw:: html
.. code:: python d2l.show_heatmaps(tf.expand_dims(tf.expand_dims(net.attention_weights, axis=0), axis=0), xlabel='Sorted training inputs', ylabel='Sorted testing inputs') .. figure:: output_nadaraya-watson_61a333_143_0.svg .. raw:: html
.. raw:: html
Tóm tắt ------- - Hồi quy hạt nhân Nadaraya-Watson là một ví dụ về máy học với các cơ chế chú ý. - Sự chú ý của hồi quy hạt nhân Nadaraya-Watson là một trung bình trọng số của các đầu ra đào tạo. Từ góc độ chú ý, trọng lượng chú ý được gán cho một giá trị dựa trên một hàm của truy vấn và khóa được ghép nối với giá trị. - Chú ý pooling có thể là một trong hai không tham số hoặc tham số. Bài tập ------- 1. Tăng số lượng các ví dụ đào tạo. Bạn có thể học hồi quy hạt nhân Nadaraya-Watson không tham số tốt hơn? 2. Giá trị của :math:`w` đã học được của chúng tôi trong thí nghiệm tập hợp chú ý tham số là gì? Tại sao nó làm cho vùng có trọng số sắc nét hơn khi hình dung trọng lượng chú ý? 3. Làm thế nào chúng ta có thể thêm các siêu tham số vào hồi quy hạt nhân Nadaraya-Watson không tham số để dự đoán tốt hơn? 4. Thiết kế một tập hợp chú ý tham số khác cho hồi quy hạt nhân của phần này. Đào tạo mô hình mới này và hình dung trọng lượng chú ý của nó. .. raw:: html
mxnetpytorch
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html