.. _sec_word2vec_data: Các Dataset cho Pretraining Word Embeddings =========================================== Bây giờ chúng ta đã biết các chi tiết kỹ thuật của các mô hình word2vec và các phương pháp đào tạo gần đúng, chúng ta hãy đi qua các triển khai của họ. Cụ thể, chúng tôi sẽ lấy mô hình skip-gram trong :numref:`sec_word2vec` và lấy mẫu âm trong :numref:`sec_approx_train` làm ví dụ. Trong phần này, chúng ta bắt đầu với tập dữ liệu để đào tạo trước mô hình nhúng từ: định dạng ban đầu của dữ liệu sẽ được chuyển thành các minibatches có thể được lặp lại trong quá trình đào tạo. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python import math import os import random from mxnet import gluon, np from d2l import mxnet as d2l .. raw:: html
.. raw:: html
.. code:: python import math import os import random import torch from d2l import torch as d2l .. raw:: html
.. raw:: html
Đọc tập dữ liệu --------------- Tập dữ liệu mà chúng tôi sử dụng ở đây là `Penn Tree Bank (PTB) `__. Cơ sở này được lấy mẫu từ các bài báo của Wall Street Journal, được chia thành các bộ đào tạo, xác nhận và kiểm tra. Ở định dạng ban đầu, mỗi dòng của tệp văn bản đại diện cho một câu của các từ được phân tách bằng dấu cách. Ở đây chúng ta coi từng từ như một mã thông báo. .. raw:: html
.. raw:: html
.. code:: python #@save d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip', '319d85e578af0cdc590547f26231e4e31cdf1e42') #@save def read_ptb(): """Load the PTB dataset into a list of text lines.""" data_dir = d2l.download_extract('ptb') # Read the training set. with open(os.path.join(data_dir, 'ptb.train.txt')) as f: raw_text = f.read() return [line.split() for line in raw_text.split('\n')] sentences = read_ptb() f'# sentences: {len(sentences)}' .. parsed-literal:: :class: output Downloading ../data/ptb.zip from http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip... .. parsed-literal:: :class: output '# sentences: 42069' .. raw:: html
.. raw:: html
.. code:: python #@save d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip', '319d85e578af0cdc590547f26231e4e31cdf1e42') #@save def read_ptb(): """Load the PTB dataset into a list of text lines.""" data_dir = d2l.download_extract('ptb') # Read the training set. with open(os.path.join(data_dir, 'ptb.train.txt')) as f: raw_text = f.read() return [line.split() for line in raw_text.split('\n')] sentences = read_ptb() f'# sentences: {len(sentences)}' .. parsed-literal:: :class: output Downloading ../data/ptb.zip from http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip... .. parsed-literal:: :class: output '# sentences: 42069' .. raw:: html
.. raw:: html
Sau khi đọc bộ đào tạo, chúng tôi xây dựng một từ vựng cho corpus, trong đó bất kỳ từ nào xuất hiện dưới 10 lần được thay thế bằng mã thông báo "". Lưu ý rằng tập dữ liệu gốc cũng chứa "" token đại diện cho các từ hiếm (không xác định). .. raw:: html
.. raw:: html
.. code:: python vocab = d2l.Vocab(sentences, min_freq=10) f'vocab size: {len(vocab)}' .. parsed-literal:: :class: output 'vocab size: 6719' .. raw:: html
.. raw:: html
.. code:: python vocab = d2l.Vocab(sentences, min_freq=10) f'vocab size: {len(vocab)}' .. parsed-literal:: :class: output 'vocab size: 6719' .. raw:: html
.. raw:: html
Lấy mẫu phụ ----------- Dữ liệu văn bản thường có các từ tần số cao như “the”, “a”, và “in”: chúng thậm chí có thể xảy ra hàng tỷ lần trong thể rất lớn. Tuy nhiên, những từ này thường đồng xuất hiện với nhiều từ khác nhau trong các cửa sổ ngữ cảnh, cung cấp ít tín hiệu hữu ích. Ví dụ, hãy xem xét từ “chip” trong một cửa sổ ngữ cảnh: trực giác sự xuất hiện của nó với một từ tần số thấp “intel” hữu ích hơn trong đào tạo hơn là đồng xuất hiện với một từ tần số cao “a”. Hơn nữa, đào tạo với một lượng lớn các từ (tần số cao) là chậm. Do đó, khi đào tạo từ nhúng mô hình, các từ tần số cao có thể là \* mẫu con\* :cite:`Mikolov.Sutskever.Chen.ea.2013`. Cụ thể, mỗi từ được lập chỉ mục :math:`w_i` trong tập dữ liệu sẽ bị loại bỏ với xác suất .. math:: P(w_i) = \max\left(1 - \sqrt{\frac{t}{f(w_i)}}, 0\right), trong đó :math:`f(w_i)` là tỷ lệ của số từ :math:`w_i` với tổng số từ trong tập dữ liệu và hằng số :math:`t` là một siêu tham số (:math:`10^{-4}` trong thí nghiệm). Chúng ta có thể thấy rằng chỉ khi tần số tương đối :math:`f(w_i) > t`, từ :math:`w_i` mới có thể bị loại bỏ và tần số tương đối của từ càng cao thì xác suất bị loại bỏ càng lớn. .. raw:: html
.. raw:: html
.. code:: python #@save def subsample(sentences, vocab): """Subsample high-frequency words.""" # Exclude unknown tokens '' sentences = [[token for token in line if vocab[token] != vocab.unk] for line in sentences] counter = d2l.count_corpus(sentences) num_tokens = sum(counter.values()) # Return True if `token` is kept during subsampling def keep(token): return(random.uniform(0, 1) < math.sqrt(1e-4 / counter[token] * num_tokens)) return ([[token for token in line if keep(token)] for line in sentences], counter) subsampled, counter = subsample(sentences, vocab) .. raw:: html
.. raw:: html
.. code:: python #@save def subsample(sentences, vocab): """Subsample high-frequency words.""" # Exclude unknown tokens '' sentences = [[token for token in line if vocab[token] != vocab.unk] for line in sentences] counter = d2l.count_corpus(sentences) num_tokens = sum(counter.values()) # Return True if `token` is kept during subsampling def keep(token): return(random.uniform(0, 1) < math.sqrt(1e-4 / counter[token] * num_tokens)) return ([[token for token in line if keep(token)] for line in sentences], counter) subsampled, counter = subsample(sentences, vocab) .. raw:: html
.. raw:: html
Đoạn mã sau vẽ biểu đồ của số lượng mã thông báo trên mỗi câu trước và sau khi lấy mẫu. Đúng như dự đoán, subsampling rút ngắn đáng kể các câu bằng cách thả các từ tần số cao, điều này sẽ dẫn đến tăng tốc đào tạo. .. raw:: html
.. raw:: html
.. code:: python d2l.show_list_len_pair_hist(['origin', 'subsampled'], '# tokens per sentence', 'count', sentences, subsampled); .. figure:: output_word-embedding-dataset_f77071_39_0.svg .. raw:: html
.. raw:: html
.. code:: python d2l.show_list_len_pair_hist(['origin', 'subsampled'], '# tokens per sentence', 'count', sentences, subsampled); .. figure:: output_word-embedding-dataset_f77071_42_0.svg .. raw:: html
.. raw:: html
Đối với mã thông báo riêng lẻ, tỷ lệ lấy mẫu của từ tần số cao “the” nhỏ hơn 1/20. .. raw:: html
.. raw:: html
.. code:: python def compare_counts(token): return (f'# of "{token}": ' f'before={sum([l.count(token) for l in sentences])}, ' f'after={sum([l.count(token) for l in subsampled])}') compare_counts('the') .. parsed-literal:: :class: output '# of "the": before=50770, after=2081' .. raw:: html
.. raw:: html
.. code:: python def compare_counts(token): return (f'# of "{token}": ' f'before={sum([l.count(token) for l in sentences])}, ' f'after={sum([l.count(token) for l in subsampled])}') compare_counts('the') .. parsed-literal:: :class: output '# of "the": before=50770, after=2033' .. raw:: html
.. raw:: html
Ngược lại, các từ tần số thấp “tham gia” được giữ hoàn toàn. .. raw:: html
.. raw:: html
.. code:: python compare_counts('join') .. parsed-literal:: :class: output '# of "join": before=45, after=45' .. raw:: html
.. raw:: html
.. code:: python compare_counts('join') .. parsed-literal:: :class: output '# of "join": before=45, after=45' .. raw:: html
.. raw:: html
Sau khi lấy mẫu đăng ký, chúng tôi ánh xạ mã thông báo đến các chỉ số của họ cho corpus. .. raw:: html
.. raw:: html
.. code:: python corpus = [vocab[line] for line in subsampled] corpus[:3] .. parsed-literal:: :class: output [[], [2115, 1], [22, 5277, 3054, 1580, 95]] .. raw:: html
.. raw:: html
.. code:: python corpus = [vocab[line] for line in subsampled] corpus[:3] .. parsed-literal:: :class: output [[], [392, 2115, 145, 274, 406], [5277, 3054, 1580]] .. raw:: html
.. raw:: html
Trích xuất từ trung tâm và từ ngữ cảnh -------------------------------------- Hàm ``get_centers_and_contexts`` sau trích xuất tất cả các từ trung tâm và từ ngữ cảnh của chúng từ ``corpus``. Nó đồng đều mẫu một số nguyên giữa 1 và ``max_window_size`` một cách ngẫu nhiên như kích thước cửa sổ ngữ cảnh. Đối với bất kỳ từ trung tâm nào, những từ có khoảng cách từ nó không vượt quá kích thước cửa sổ ngữ cảnh được lấy mẫu là các từ ngữ cảnh của nó. .. raw:: html
.. raw:: html
.. code:: python #@save def get_centers_and_contexts(corpus, max_window_size): """Return center words and context words in skip-gram.""" centers, contexts = [], [] for line in corpus: # To form a "center word--context word" pair, each sentence needs to # have at least 2 words if len(line) < 2: continue centers += line for i in range(len(line)): # Context window centered at `i` window_size = random.randint(1, max_window_size) indices = list(range(max(0, i - window_size), min(len(line), i + 1 + window_size))) # Exclude the center word from the context words indices.remove(i) contexts.append([line[idx] for idx in indices]) return centers, contexts .. raw:: html
.. raw:: html
.. code:: python #@save def get_centers_and_contexts(corpus, max_window_size): """Return center words and context words in skip-gram.""" centers, contexts = [], [] for line in corpus: # To form a "center word--context word" pair, each sentence needs to # have at least 2 words if len(line) < 2: continue centers += line for i in range(len(line)): # Context window centered at `i` window_size = random.randint(1, max_window_size) indices = list(range(max(0, i - window_size), min(len(line), i + 1 + window_size))) # Exclude the center word from the context words indices.remove(i) contexts.append([line[idx] for idx in indices]) return centers, contexts .. raw:: html
.. raw:: html
Tiếp theo, chúng ta tạo ra một tập dữ liệu nhân tạo có chứa hai câu 7 và 3 từ, tương ứng. Hãy để kích thước cửa sổ ngữ cảnh tối đa là 2 và in tất cả các từ trung tâm và các từ ngữ cảnh của chúng. .. raw:: html
.. raw:: html
.. code:: python tiny_dataset = [list(range(7)), list(range(7, 10))] print('dataset', tiny_dataset) for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)): print('center', center, 'has contexts', context) .. parsed-literal:: :class: output dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]] center 0 has contexts [1, 2] center 1 has contexts [0, 2, 3] center 2 has contexts [1, 3] center 3 has contexts [1, 2, 4, 5] center 4 has contexts [3, 5] center 5 has contexts [4, 6] center 6 has contexts [4, 5] center 7 has contexts [8, 9] center 8 has contexts [7, 9] center 9 has contexts [8] .. raw:: html
.. raw:: html
.. code:: python tiny_dataset = [list(range(7)), list(range(7, 10))] print('dataset', tiny_dataset) for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)): print('center', center, 'has contexts', context) .. parsed-literal:: :class: output dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]] center 0 has contexts [1, 2] center 1 has contexts [0, 2, 3] center 2 has contexts [0, 1, 3, 4] center 3 has contexts [1, 2, 4, 5] center 4 has contexts [2, 3, 5, 6] center 5 has contexts [3, 4, 6] center 6 has contexts [4, 5] center 7 has contexts [8, 9] center 8 has contexts [7, 9] center 9 has contexts [8] .. raw:: html
.. raw:: html
Khi đào tạo trên tập dữ liệu PTB, chúng tôi đặt kích thước cửa sổ ngữ cảnh tối đa là 5. Sau đây trích xuất tất cả các từ trung tâm và các từ ngữ cảnh của chúng trong tập dữ liệu. .. raw:: html
.. raw:: html
.. code:: python all_centers, all_contexts = get_centers_and_contexts(corpus, 5) f'# center-context pairs: {sum([len(contexts) for contexts in all_contexts])}' .. parsed-literal:: :class: output '# center-context pairs: 1500272' .. raw:: html
.. raw:: html
.. code:: python all_centers, all_contexts = get_centers_and_contexts(corpus, 5) f'# center-context pairs: {sum([len(contexts) for contexts in all_contexts])}' .. parsed-literal:: :class: output '# center-context pairs: 1497165' .. raw:: html
.. raw:: html
Lấy mẫu âm ---------- Chúng tôi sử dụng lấy mẫu tiêu cực cho đào tạo gần đúng. Để lấy mẫu các từ nhiễu theo một phân phối được xác định trước, chúng ta xác định lớp ``RandomGenerator`` sau, trong đó phân phối lấy mẫu (có thể không chuẩn hóa) được truyền qua đối số ``sampling_weights``. .. raw:: html
.. raw:: html
.. code:: python #@save class RandomGenerator: """Randomly draw among {1, ..., n} according to n sampling weights.""" def __init__(self, sampling_weights): # Exclude self.population = list(range(1, len(sampling_weights) + 1)) self.sampling_weights = sampling_weights self.candidates = [] self.i = 0 def draw(self): if self.i == len(self.candidates): # Cache `k` random sampling results self.candidates = random.choices( self.population, self.sampling_weights, k=10000) self.i = 0 self.i += 1 return self.candidates[self.i - 1] .. raw:: html
.. raw:: html
.. code:: python #@save class RandomGenerator: """Randomly draw among {1, ..., n} according to n sampling weights.""" def __init__(self, sampling_weights): # Exclude self.population = list(range(1, len(sampling_weights) + 1)) self.sampling_weights = sampling_weights self.candidates = [] self.i = 0 def draw(self): if self.i == len(self.candidates): # Cache `k` random sampling results self.candidates = random.choices( self.population, self.sampling_weights, k=10000) self.i = 0 self.i += 1 return self.candidates[self.i - 1] .. raw:: html
.. raw:: html
Ví dụ: chúng ta có thể vẽ 10 biến ngẫu nhiên :math:`X` trong số các chỉ số 1, 2 và 3 với xác suất lấy mẫu :math:`P(X=1)=2/9, P(X=2)=3/9` và :math:`P(X=3)=4/9` như sau. .. raw:: html
.. raw:: html
.. code:: python generator = RandomGenerator([2, 3, 4]) [generator.draw() for _ in range(10)] .. parsed-literal:: :class: output [1, 1, 3, 2, 3, 3, 1, 2, 1, 2] .. raw:: html
.. raw:: html
Đối với một cặp từ trung tâm và từ ngữ cảnh, chúng tôi lấy mẫu ngẫu nhiên ``K`` (5 trong thí nghiệm) các từ tiếng ồn. Theo các đề xuất trong bài báo word2vec, xác suất lấy mẫu :math:`P(w)` của một từ tiếng ồn :math:`w` được đặt thành tần số tương đối của nó trong từ điển nâng lên công suất 0,75 :cite:`Mikolov.Sutskever.Chen.ea.2013`. .. raw:: html
.. raw:: html
.. code:: python #@save def get_negatives(all_contexts, vocab, counter, K): """Return noise words in negative sampling.""" # Sampling weights for words with indices 1, 2, ... (index 0 is the # excluded unknown token) in the vocabulary sampling_weights = [counter[vocab.to_tokens(i)]**0.75 for i in range(1, len(vocab))] all_negatives, generator = [], RandomGenerator(sampling_weights) for contexts in all_contexts: negatives = [] while len(negatives) < len(contexts) * K: neg = generator.draw() # Noise words cannot be context words if neg not in contexts: negatives.append(neg) all_negatives.append(negatives) return all_negatives all_negatives = get_negatives(all_contexts, vocab, counter, 5) .. raw:: html
.. raw:: html
.. code:: python #@save def get_negatives(all_contexts, vocab, counter, K): """Return noise words in negative sampling.""" # Sampling weights for words with indices 1, 2, ... (index 0 is the # excluded unknown token) in the vocabulary sampling_weights = [counter[vocab.to_tokens(i)]**0.75 for i in range(1, len(vocab))] all_negatives, generator = [], RandomGenerator(sampling_weights) for contexts in all_contexts: negatives = [] while len(negatives) < len(contexts) * K: neg = generator.draw() # Noise words cannot be context words if neg not in contexts: negatives.append(neg) all_negatives.append(negatives) return all_negatives all_negatives = get_negatives(all_contexts, vocab, counter, 5) .. raw:: html
.. raw:: html
.. _subsec_word2vec-minibatch-loading: Tải ví dụ đào tạo trong Minibatches ----------------------------------- Sau khi tất cả các từ trung tâm cùng với các từ ngữ cảnh và các từ tiếng ồn được lấy mẫu được trích xuất, chúng sẽ được chuyển thành các ví dụ nhỏ có thể được tải lặp lại trong quá trình đào tạo. Trong một minibatch, ví dụ :math:`i^\mathrm{th}` bao gồm một từ trung tâm và :math:`n_i` từ ngữ cảnh của nó và :math:`m_i` từ nhiễu. Do kích thước cửa sổ ngữ cảnh khác nhau, :math:`n_i+m_i` thay đổi cho :math:`i` khác nhau. Do đó, đối với mỗi ví dụ, chúng tôi nối các từ ngữ cảnh và các từ nhiễu của nó trong biến ``contexts_negatives`` và pad số không cho đến khi độ dài nối đạt :math:`\max_i n_i+m_i` (``max_len``). Để loại trừ các miếng đệm trong tính toán tổn thất, chúng tôi xác định một biến mặt nạ ``masks``. Có một sự tương ứng một-một giữa các phần tử trong ``masks`` và các phần tử trong ``contexts_negatives``, trong đó số không (nếu không) trong ``masks`` tương ứng với các miếng đệm trong ``contexts_negatives``. Để phân biệt giữa các ví dụ tích cực và tiêu cực, chúng ta tách các từ ngữ cảnh khỏi các từ nhiễu trong ``contexts_negatives`` thông qua một biến ``labels``. Tương tự như ``masks``, cũng có sự tương ứng một-một giữa các phần tử trong ``labels`` và các phần tử trong ``contexts_negatives``, trong đó các phần tử (nếu không số không) trong ``labels`` tương ứng với các từ ngữ cảnh (ví dụ tích cực) trong ``contexts_negatives``. Ý tưởng trên được thực hiện trong hàm ``batchify`` sau. Đầu vào của nó ``data`` là một danh sách có độ dài bằng với kích thước lô, trong đó mỗi phần tử là một ví dụ bao gồm từ trung tâm ``center``, các từ ngữ cảnh của nó ``context``, và các từ nhiễu của nó ``negative``. Hàm này trả về một minibatch có thể được nạp để tính toán trong quá trình đào tạo, chẳng hạn như bao gồm biến mask. .. raw:: html
.. raw:: html
.. code:: python #@save def batchify(data): """Return a minibatch of examples for skip-gram with negative sampling.""" max_len = max(len(c) + len(n) for _, c, n in data) centers, contexts_negatives, masks, labels = [], [], [], [] for center, context, negative in data: cur_len = len(context) + len(negative) centers += [center] contexts_negatives += [context + negative + [0] * (max_len - cur_len)] masks += [[1] * cur_len + [0] * (max_len - cur_len)] labels += [[1] * len(context) + [0] * (max_len - len(context))] return (np.array(centers).reshape((-1, 1)), np.array( contexts_negatives), np.array(masks), np.array(labels)) .. raw:: html
.. raw:: html
.. code:: python #@save def batchify(data): """Return a minibatch of examples for skip-gram with negative sampling.""" max_len = max(len(c) + len(n) for _, c, n in data) centers, contexts_negatives, masks, labels = [], [], [], [] for center, context, negative in data: cur_len = len(context) + len(negative) centers += [center] contexts_negatives += [context + negative + [0] * (max_len - cur_len)] masks += [[1] * cur_len + [0] * (max_len - cur_len)] labels += [[1] * len(context) + [0] * (max_len - len(context))] return (torch.tensor(centers).reshape((-1, 1)), torch.tensor( contexts_negatives), torch.tensor(masks), torch.tensor(labels)) .. raw:: html
.. raw:: html
Hãy để chúng tôi kiểm tra chức năng này bằng cách sử dụng một minibatch gồm hai ví dụ. .. raw:: html
.. raw:: html
.. code:: python x_1 = (1, [2, 2], [3, 3, 3, 3]) x_2 = (1, [2, 2, 2], [3, 3]) batch = batchify((x_1, x_2)) names = ['centers', 'contexts_negatives', 'masks', 'labels'] for name, data in zip(names, batch): print(name, '=', data) .. parsed-literal:: :class: output centers = [[1.] [1.]] contexts_negatives = [[2. 2. 3. 3. 3. 3.] [2. 2. 2. 3. 3. 0.]] masks = [[1. 1. 1. 1. 1. 1.] [1. 1. 1. 1. 1. 0.]] labels = [[1. 1. 0. 0. 0. 0.] [1. 1. 1. 0. 0. 0.]] .. raw:: html
.. raw:: html
.. code:: python x_1 = (1, [2, 2], [3, 3, 3, 3]) x_2 = (1, [2, 2, 2], [3, 3]) batch = batchify((x_1, x_2)) names = ['centers', 'contexts_negatives', 'masks', 'labels'] for name, data in zip(names, batch): print(name, '=', data) .. parsed-literal:: :class: output centers = tensor([[1], [1]]) contexts_negatives = tensor([[2, 2, 3, 3, 3, 3], [2, 2, 2, 3, 3, 0]]) masks = tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0]]) labels = tensor([[1, 1, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0]]) .. raw:: html
.. raw:: html
Đặt tất cả mọi thứ lại với nhau ------------------------------- Cuối cùng, chúng ta định nghĩa hàm ``load_data_ptb`` đọc tập dữ liệu PTB và trả về bộ lặp dữ liệu và từ vựng. .. raw:: html
.. raw:: html
.. code:: python #@save def load_data_ptb(batch_size, max_window_size, num_noise_words): """Download the PTB dataset and then load it into memory.""" sentences = read_ptb() vocab = d2l.Vocab(sentences, min_freq=10) subsampled, counter = subsample(sentences, vocab) corpus = [vocab[line] for line in subsampled] all_centers, all_contexts = get_centers_and_contexts( corpus, max_window_size) all_negatives = get_negatives( all_contexts, vocab, counter, num_noise_words) dataset = gluon.data.ArrayDataset( all_centers, all_contexts, all_negatives) data_iter = gluon.data.DataLoader( dataset, batch_size, shuffle=True,batchify_fn=batchify, num_workers=d2l.get_dataloader_workers()) return data_iter, vocab .. raw:: html
.. raw:: html
.. code:: python #@save def load_data_ptb(batch_size, max_window_size, num_noise_words): """Download the PTB dataset and then load it into memory.""" num_workers = d2l.get_dataloader_workers() sentences = read_ptb() vocab = d2l.Vocab(sentences, min_freq=10) subsampled, counter = subsample(sentences, vocab) corpus = [vocab[line] for line in subsampled] all_centers, all_contexts = get_centers_and_contexts( corpus, max_window_size) all_negatives = get_negatives( all_contexts, vocab, counter, num_noise_words) class PTBDataset(torch.utils.data.Dataset): def __init__(self, centers, contexts, negatives): assert len(centers) == len(contexts) == len(negatives) self.centers = centers self.contexts = contexts self.negatives = negatives def __getitem__(self, index): return (self.centers[index], self.contexts[index], self.negatives[index]) def __len__(self): return len(self.centers) dataset = PTBDataset(all_centers, all_contexts, all_negatives) data_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, collate_fn=batchify, num_workers=num_workers) return data_iter, vocab .. raw:: html
.. raw:: html
Hãy để chúng tôi in minibatch đầu tiên của bộ lặp dữ liệu. .. raw:: html
.. raw:: html
.. code:: python data_iter, vocab = load_data_ptb(512, 5, 5) for batch in data_iter: for name, data in zip(names, batch): print(name, 'shape:', data.shape) break .. parsed-literal:: :class: output centers shape: (512, 1) contexts_negatives shape: (512, 60) masks shape: (512, 60) labels shape: (512, 60) .. raw:: html
.. raw:: html
.. code:: python data_iter, vocab = load_data_ptb(512, 5, 5) for batch in data_iter: for name, data in zip(names, batch): print(name, 'shape:', data.shape) break .. parsed-literal:: :class: output centers shape: torch.Size([512, 1]) contexts_negatives shape: torch.Size([512, 60]) masks shape: torch.Size([512, 60]) labels shape: torch.Size([512, 60]) .. raw:: html
.. raw:: html
Tóm tắt ------- - Các từ tần số cao có thể không hữu ích trong đào tạo. Chúng tôi có thể subsample chúng để tăng tốc trong đào tạo. - Đối với hiệu quả tính toán, chúng tôi tải các ví dụ trong minibatches. Chúng ta có thể xác định các biến khác để phân biệt các miếng đệm từ các miếng đệm không và các ví dụ tích cực với các biến tiêu cực. Bài tập ------- 1. Làm thế nào để thời gian chạy của mã trong phần này thay đổi nếu không sử dụng subsampling? 2. Các ``RandomGenerator`` lớp lưu trữ ``k`` kết quả lấy mẫu ngẫu nhiên. Đặt ``k`` thành các giá trị khác và xem nó ảnh hưởng đến tốc độ tải dữ liệu như thế nào. 3. Những siêu tham số khác trong mã của phần này có thể ảnh hưởng đến tốc độ tải dữ liệu? .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html