Chuyển kiểu thần kinh ===================== Nếu bạn là một người đam mê nhiếp ảnh, bạn có thể đã quen thuộc với bộ lọc. Nó có thể thay đổi kiểu màu của ảnh để ảnh phong cảnh trở nên sắc nét hơn hoặc ảnh chân dung đã làm trắng da. Tuy nhiên, một bộ lọc thường chỉ thay đổi một khía cạnh của ảnh. Để áp dụng một phong cách lý tưởng cho ảnh, có lẽ bạn cần thử nhiều kết hợp bộ lọc khác nhau. Quá trình này phức tạp như điều chỉnh các siêu tham số của một mô hình. Trong phần này, chúng ta sẽ tận dụng các biểu diễn theo lớp của CNN để tự động áp dụng kiểu của một hình ảnh cho một hình ảnh khác, tức là, \* phong cách transfer\* :cite:`Gatys.Ecker.Bethge.2016`. Nhiệm vụ này cần hai hình ảnh đầu vào: một là hình ảnh\* nội dung hình ảnh\* và cái còn lại là hình ảnh\* phong cách\*. Chúng tôi sẽ sử dụng mạng nơ-ron để sửa đổi hình ảnh nội dung để làm cho nó gần với hình ảnh phong cách theo phong cách. Ví dụ, hình ảnh nội dung trong :numref:`fig_style_transfer` là một bức ảnh phong cảnh do chúng tôi chụp trong Vườn quốc gia Núi Rainier ở ngoại ô Seattle, trong khi hình ảnh phong cách là một bức tranh sơn dầu với chủ đề cây sồi mùa thu. Trong hình ảnh tổng hợp đầu ra, các nét cọ dầu của hình ảnh phong cách được áp dụng, dẫn đến màu sắc sống động hơn, đồng thời vẫn giữ được hình dạng chính của các đối tượng trong hình ảnh nội dung. .. _fig_style_transfer: .. figure:: ../img/style-transfer.svg Given content and style images, style transfer outputs a synthesized image. Phương pháp ----------- :numref:`fig_style_transfer_model` minh họa phương thức chuyển kiểu dựa trên CNN với một ví dụ đơn giản hóa. Đầu tiên, chúng tôi khởi tạo hình ảnh tổng hợp, ví dụ, vào hình ảnh nội dung. Hình ảnh tổng hợp này là biến duy nhất cần được cập nhật trong quá trình chuyển kiểu, tức là các tham số mô hình được cập nhật trong quá trình đào tạo. Sau đó, chúng tôi chọn CNN được đào tạo trước để trích xuất các tính năng hình ảnh và đóng băng các thông số mô hình của nó trong quá trình đào tạo. CNN sâu này sử dụng nhiều lớp để trích xuất các tính năng phân cấp cho hình ảnh. Chúng ta có thể chọn đầu ra của một số lớp này làm tính năng nội dung hoặc tính năng phong cách. Lấy :numref:`fig_style_transfer_model` làm ví dụ. Mạng nơ-ron được đào tạo trước ở đây có 3 lớp phức tạp, trong đó lớp thứ hai xuất ra các tính năng nội dung, và các lớp thứ nhất và thứ ba xuất ra các tính năng kiểu. .. _fig_style_transfer_model: .. figure:: ../img/neural-style.svg CNN-based style transfer process. Solid lines show the direction of forward propagation and dotted lines show backward propagation. Tiếp theo, chúng ta tính toán hàm mất của chuyển kiểu thông qua lan truyền về phía trước (hướng của mũi tên rắn), và cập nhật các tham số mô hình (hình ảnh tổng hợp cho đầu ra) thông qua truyền ngược (hướng của các mũi tên đứt nét). Chức năng mất thường được sử dụng trong chuyển phong cách bao gồm ba phần: (i) \* mất nội dung\* làm cho hình ảnh tổng hợp và hình ảnh nội dung gần gũi trong các tính năng nội dung; (ii) \* phong cách mất đi\* làm cho hình ảnh tổng hợp và phong cách gần gũi trong các tính năng phong cách; và (iii) \* mất biến thể\* giúp giảm noise tiếng ồn in the synthesized tổng hợp image hình ảnh. Cuối cùng, khi đào tạo mô hình kết thúc, chúng tôi xuất các thông số mô hình của chuyển kiểu để tạo ra hình ảnh tổng hợp cuối cùng. Sau đây, chúng tôi sẽ giải thích các chi tiết kỹ thuật của chuyển phong cách thông qua một thí nghiệm cụ thể. Đọc nội dung và phong cách hình ảnh ----------------------------------- Đầu tiên, chúng tôi đọc nội dung và phong cách hình ảnh. Từ các trục tọa độ in của chúng, chúng ta có thể nói rằng những hình ảnh này có kích thước khác nhau. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python %matplotlib inline from mxnet import autograd, gluon, image, init, np, npx from mxnet.gluon import nn from d2l import mxnet as d2l npx.set_np() d2l.set_figsize() content_img = image.imread('../img/rainier.jpg') d2l.plt.imshow(content_img.asnumpy()); .. figure:: output_neural-style_5de8ca_3_0.svg .. code:: python style_img = image.imread('../img/autumn-oak.jpg') d2l.plt.imshow(style_img.asnumpy()); .. figure:: output_neural-style_5de8ca_4_0.svg .. raw:: html
.. raw:: html
.. code:: python %matplotlib inline import torch import torchvision from torch import nn from d2l import torch as d2l d2l.set_figsize() content_img = d2l.Image.open('../img/rainier.jpg') d2l.plt.imshow(content_img); .. figure:: output_neural-style_5de8ca_7_0.svg .. code:: python style_img = d2l.Image.open('../img/autumn-oak.jpg') d2l.plt.imshow(style_img); .. figure:: output_neural-style_5de8ca_8_0.svg .. raw:: html
.. raw:: html
Tiền xử lý và xử lý sau ----------------------- Dưới đây, chúng tôi xác định hai chức năng cho hình ảnh tiền xử lý và xử lý hậu kỳ. Chức năng ``preprocess`` chuẩn hóa mỗi kênh trong ba kênh RGB của hình ảnh đầu vào và biến đổi kết quả thành định dạng đầu vào CNN. Hàm ``postprocess`` khôi phục các giá trị điểm ảnh trong hình ảnh đầu ra về giá trị ban đầu của chúng trước khi tiêu chuẩn hóa. Vì chức năng in ảnh yêu cầu mỗi pixel có giá trị điểm nổi từ 0 đến 1, chúng ta thay thế bất kỳ giá trị nào nhỏ hơn 0 hoặc lớn hơn 1 với 0 hoặc 1, tương ứng. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python rgb_mean = np.array([0.485, 0.456, 0.406]) rgb_std = np.array([0.229, 0.224, 0.225]) def preprocess(img, image_shape): img = image.imresize(img, *image_shape) img = (img.astype('float32') / 255 - rgb_mean) / rgb_std return np.expand_dims(img.transpose(2, 0, 1), axis=0) def postprocess(img): img = img[0].as_in_ctx(rgb_std.ctx) return (img.transpose(1, 2, 0) * rgb_std + rgb_mean).clip(0, 1) .. raw:: html
.. raw:: html
.. code:: python rgb_mean = torch.tensor([0.485, 0.456, 0.406]) rgb_std = torch.tensor([0.229, 0.224, 0.225]) def preprocess(img, image_shape): transforms = torchvision.transforms.Compose([ torchvision.transforms.Resize(image_shape), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)]) return transforms(img).unsqueeze(0) def postprocess(img): img = img[0].to(rgb_std.device) img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1)) .. raw:: html
.. raw:: html
tính năng chiết xuất -------------------- Chúng tôi sử dụng mô hình VGG-19 được đào tạo trước trên bộ dữ liệu ImageNet để trích xuất các tính năng hình ảnh :cite:`Gatys.Ecker.Bethge.2016`. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python pretrained_net = gluon.model_zoo.vision.vgg19(pretrained=True) .. raw:: html
.. raw:: html
.. code:: python pretrained_net = torchvision.models.vgg19(pretrained=True) .. raw:: html
.. raw:: html
Để trích xuất các tính năng nội dung và tính năng kiểu của hình ảnh, chúng ta có thể chọn đầu ra của các lớp nhất định trong mạng VGG. Nói chung, càng gần lớp đầu vào, dễ dàng hơn để trích xuất chi tiết của hình ảnh, và ngược lại, dễ dàng hơn để trích xuất thông tin toàn cầu của hình ảnh. Để tránh giữ lại quá mức các chi tiết của hình ảnh nội dung trong hình ảnh tổng hợp, chúng tôi chọn một lớp VGG gần đầu ra hơn làm lớp nội dung *để xuất các tính năng nội dung của hình ảnh. Chúng tôi cũng chọn đầu ra của các lớp VGG khác nhau để trích xuất các tính năng kiểu địa phương và toàn cầu. Các lớp này còn được gọi là lớp kiểu *. Như đã đề cập trong :numref:`sec_vgg`, mạng VGG sử dụng 5 khối phức tạp. Trong thí nghiệm, chúng ta chọn lớp phức tạp cuối cùng của khối phức tạp thứ tư làm lớp nội dung, và lớp phức tạp đầu tiên của mỗi khối phức tạp làm lớp phong cách. Các chỉ số của các lớp này có thể thu được bằng cách in phiên bản ``pretrained_net``. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python style_layers, content_layers = [0, 5, 10, 19, 28], [25] .. raw:: html
.. raw:: html
.. code:: python style_layers, content_layers = [0, 5, 10, 19, 28], [25] .. raw:: html
.. raw:: html
Khi trích xuất các đối tượng bằng cách sử dụng các lớp VGG, chúng ta chỉ cần sử dụng tất cả các đối tượng từ lớp input đến lớp nội dung hoặc layer style gần nhất với lớp đầu ra nhất. Chúng ta hãy xây dựng một phiên bản mạng mới ``net``, chỉ giữ lại tất cả các lớp VGG được sử dụng để trích xuất tính năng. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python net = nn.Sequential() for i in range(max(content_layers + style_layers) + 1): net.add(pretrained_net.features[i]) .. raw:: html
.. raw:: html
.. code:: python net = nn.Sequential(*[pretrained_net.features[i] for i in range(max(content_layers + style_layers) + 1)]) .. raw:: html
.. raw:: html
Với đầu vào ``X``, nếu chúng ta chỉ đơn giản gọi tuyên truyền chuyển tiếp ``net(X)``, chúng ta chỉ có thể nhận được đầu ra của lớp cuối cùng. Vì chúng ta cũng cần các đầu ra của các lớp trung gian, chúng ta cần thực hiện tính toán từng lớp và giữ cho các đầu ra lớp nội dung và kiểu dáng. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python def extract_features(X, content_layers, style_layers): contents = [] styles = [] for i in range(len(net)): X = net[i](X) if i in style_layers: styles.append(X) if i in content_layers: contents.append(X) return contents, styles .. raw:: html
.. raw:: html
.. code:: python def extract_features(X, content_layers, style_layers): contents = [] styles = [] for i in range(len(net)): X = net[i](X) if i in style_layers: styles.append(X) if i in content_layers: contents.append(X) return contents, styles .. raw:: html
.. raw:: html
Hai chức năng được định nghĩa dưới đây: hàm ``get_contents`` trích xuất các tính năng nội dung từ hình ảnh nội dung và hàm ``get_styles`` trích xuất các tính năng kiểu từ hình ảnh phong cách. Vì không cần cập nhật các thông số mô hình của VGG được đào tạo trước trong quá trình đào tạo, chúng tôi có thể trích xuất nội dung và các tính năng phong cách ngay cả trước khi bắt đầu đào tạo. Vì hình ảnh tổng hợp là một tập hợp các tham số mô hình được cập nhật để chuyển kiểu, chúng ta chỉ có thể trích xuất nội dung và tính năng kiểu dáng của hình ảnh tổng hợp bằng cách gọi hàm ``extract_features`` trong quá trình đào tạo. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python def get_contents(image_shape, device): content_X = preprocess(content_img, image_shape).copyto(device) contents_Y, _ = extract_features(content_X, content_layers, style_layers) return content_X, contents_Y def get_styles(image_shape, device): style_X = preprocess(style_img, image_shape).copyto(device) _, styles_Y = extract_features(style_X, content_layers, style_layers) return style_X, styles_Y .. raw:: html
.. raw:: html
.. code:: python def get_contents(image_shape, device): content_X = preprocess(content_img, image_shape).to(device) contents_Y, _ = extract_features(content_X, content_layers, style_layers) return content_X, contents_Y def get_styles(image_shape, device): style_X = preprocess(style_img, image_shape).to(device) _, styles_Y = extract_features(style_X, content_layers, style_layers) return style_X, styles_Y .. raw:: html
.. raw:: html
Defining the Loss Function -------------------------- Bây giờ chúng ta sẽ mô tả chức năng mất để chuyển phong cách. Chức năng mất bao gồm mất nội dung, mất phong cách và mất hoàn toàn biến thể. Mất nội dung ~~~~~~~~~~~~ Tương tự như chức năng mất trong hồi quy tuyến tính, mất nội dung đo lường sự khác biệt về các tính năng nội dung giữa hình ảnh tổng hợp và hình ảnh nội dung thông qua chức năng mất bình phương. Hai đầu vào của hàm mất bình phương là cả hai đầu ra của lớp nội dung được tính toán bởi hàm ``extract_features``. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python def content_loss(Y_hat, Y): return np.square(Y_hat - Y).mean() .. raw:: html
.. raw:: html
.. code:: python def content_loss(Y_hat, Y): # We detach the target content from the tree used to dynamically compute # the gradient: this is a stated value, not a variable. Otherwise the loss # will throw an error. return torch.square(Y_hat - Y.detach()).mean() .. raw:: html
.. raw:: html
Phong cách mất ~~~~~~~~~~~~~~ Mất phong cách, tương tự như mất nội dung, cũng sử dụng chức năng mất bình phương để đo lường sự khác biệt về phong cách giữa hình ảnh tổng hợp và hình ảnh phong cách. Để thể hiện đầu ra kiểu của bất kỳ layer style nào, trước tiên chúng ta sử dụng hàm ``extract_features`` để tính toán đầu ra layer style. Giả sử rằng đầu ra có 1 ví dụ, :math:`c` kênh, chiều cao :math:`h` và chiều rộng :math:`w`, chúng ta có thể chuyển đổi đầu ra này thành ma trận :math:`\mathbf{X}` với :math:`c` hàng và :math:`hw` cột. Ma trận này có thể được coi là sự nối của :math:`c` vectơ :math:`\mathbf{x}_1, \ldots, \mathbf{x}_c`, mỗi vectơ có chiều dài :math:`hw`. Ở đây, vector :math:`\mathbf{x}_i` đại diện cho tính năng phong cách của kênh :math:`i`. Trong ma trận *Gram* của các vectơ :math:`\mathbf{X}\mathbf{X}^\top \in \mathbb{R}^{c \times c}`, phần tử :math:`x_{ij}` trong hàng :math:`i` và cột :math:`j` là tích chấm của vectơ :math:`\mathbf{x}_i` và :math:`\mathbf{x}_j`. Nó đại diện cho mối tương quan của các tính năng phong cách của các kênh :math:`i` và :math:`j`. Chúng tôi sử dụng ma trận Gram này để đại diện cho đầu ra kiểu của bất kỳ layer style nào. Lưu ý rằng khi giá trị của :math:`hw` lớn hơn, nó có thể dẫn đến các giá trị lớn hơn trong ma trận Gram. Cũng lưu ý rằng chiều cao và chiều rộng của ma trận Gram đều là số kênh :math:`c`. Để cho phép mất kiểu không bị ảnh hưởng bởi các giá trị này, hàm ``gram`` bên dưới chia ma trận Gram cho số phần tử của nó, tức là :math:`chw`. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python def gram(X): num_channels, n = X.shape[1], d2l.size(X) // X.shape[1] X = X.reshape((num_channels, n)) return np.dot(X, X.T) / (num_channels * n) .. raw:: html
.. raw:: html
.. code:: python def gram(X): num_channels, n = X.shape[1], X.numel() // X.shape[1] X = X.reshape((num_channels, n)) return torch.matmul(X, X.T) / (num_channels * n) .. raw:: html
.. raw:: html
Rõ ràng, hai đầu vào ma trận Gram của hàm mất bình phương để mất kiểu được dựa trên đầu ra lớp kiểu cho hình ảnh tổng hợp và hình ảnh phong cách. Người ta cho rằng ma trận Gram ``gram_Y`` dựa trên hình ảnh phong cách đã được tính toán trước. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python def style_loss(Y_hat, gram_Y): return np.square(gram(Y_hat) - gram_Y).mean() .. raw:: html
.. raw:: html
.. code:: python def style_loss(Y_hat, gram_Y): return torch.square(gram(Y_hat) - gram_Y.detach()).mean() .. raw:: html
.. raw:: html
Tổng Biến Thể Mất ~~~~~~~~~~~~~~~~~ Đôi khi, hình ảnh tổng hợp đã học có rất nhiều nhiễu tần số cao, tức là, đặc biệt là các pixel sáng hoặc tối. Một phương pháp giảm tiếng ồn phổ biến là *tổng biến thể biểu tượng*. Biểu thị bằng :math:`x_{i, j}` giá trị điểm ảnh ở tọa độ :math:`(i, j)`. Giảm tổng tổn thất biến thể .. math:: \sum_{i, j} \left|x_{i, j} - x_{i+1, j}\right| + \left|x_{i, j} - x_{i, j+1}\right| làm cho các giá trị của các pixel lân cận trên hình ảnh tổng hợp gần hơn. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python def tv_loss(Y_hat): return 0.5 * (np.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() + np.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean()) .. raw:: html
.. raw:: html
.. code:: python def tv_loss(Y_hat): return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() + torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean()) .. raw:: html
.. raw:: html
Chức năng mất ~~~~~~~~~~~~~ Chức năng mất của chuyển phong cách là tổng trọng số của mất nội dung, mất phong cách, và mất biến thể tổng số. Bằng cách điều chỉnh các siêu tham số trọng lượng này, chúng ta có thể cân bằng giữa việc lưu giữ nội dung, chuyển kiểu và giảm nhiễu trên hình ảnh tổng hợp. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python content_weight, style_weight, tv_weight = 1, 1e3, 10 def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram): # Calculate the content, style, and total variance losses respectively contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip( contents_Y_hat, contents_Y)] styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip( styles_Y_hat, styles_Y_gram)] tv_l = tv_loss(X) * tv_weight # Add up all the losses l = sum(10 * styles_l + contents_l + [tv_l]) return contents_l, styles_l, tv_l, l .. raw:: html
.. raw:: html
.. code:: python content_weight, style_weight, tv_weight = 1, 1e3, 10 def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram): # Calculate the content, style, and total variance losses respectively contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip( contents_Y_hat, contents_Y)] styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip( styles_Y_hat, styles_Y_gram)] tv_l = tv_loss(X) * tv_weight # Add up all the losses l = sum(10 * styles_l + contents_l + [tv_l]) return contents_l, styles_l, tv_l, l .. raw:: html
.. raw:: html
Initializing the Synthesized Image ---------------------------------- Trong chuyển phong cách, hình ảnh tổng hợp là biến duy nhất cần được cập nhật trong quá trình đào tạo. Do đó, chúng ta có thể xác định một mô hình đơn giản, ``SynthesizedImage`` và coi hình ảnh tổng hợp như các tham số mô hình. Trong mô hình này, tuyên truyền chuyển tiếp chỉ trả về các tham số mô hình. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python class SynthesizedImage(nn.Block): def __init__(self, img_shape, **kwargs): super(SynthesizedImage, self).__init__(**kwargs) self.weight = self.params.get('weight', shape=img_shape) def forward(self): return self.weight.data() .. raw:: html
.. raw:: html
.. code:: python class SynthesizedImage(nn.Module): def __init__(self, img_shape, **kwargs): super(SynthesizedImage, self).__init__(**kwargs) self.weight = nn.Parameter(torch.rand(*img_shape)) def forward(self): return self.weight .. raw:: html
.. raw:: html
Tiếp theo, ta định nghĩa hàm ``get_inits``. Hàm này tạo ra một ví dụ mô hình ảnh tổng hợp và khởi tạo nó thành hình ảnh ``X``. Ma trận gram cho hình ảnh phong cách ở các lớp phong cách khác nhau, ``styles_Y_gram``, được tính toán trước khi đào tạo. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python def get_inits(X, device, lr, styles_Y): gen_img = SynthesizedImage(X.shape) gen_img.initialize(init.Constant(X), ctx=device, force_reinit=True) trainer = gluon.Trainer(gen_img.collect_params(), 'adam', {'learning_rate': lr}) styles_Y_gram = [gram(Y) for Y in styles_Y] return gen_img(), styles_Y_gram, trainer .. raw:: html
.. raw:: html
.. code:: python def get_inits(X, device, lr, styles_Y): gen_img = SynthesizedImage(X.shape).to(device) gen_img.weight.data.copy_(X.data) trainer = torch.optim.Adam(gen_img.parameters(), lr=lr) styles_Y_gram = [gram(Y) for Y in styles_Y] return gen_img(), styles_Y_gram, trainer .. raw:: html
.. raw:: html
Đào tạo ------- Khi đào tạo mô hình để chuyển phong cách, chúng tôi liên tục trích xuất các tính năng nội dung và tính năng phong cách của hình ảnh tổng hợp và tính toán chức năng mất mát. Dưới đây xác định vòng lặp đào tạo. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch): X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[10, num_epochs], ylim=[0, 20], legend=['content', 'style', 'TV'], ncols=2, figsize=(7, 2.5)) for epoch in range(num_epochs): with autograd.record(): contents_Y_hat, styles_Y_hat = extract_features( X, content_layers, style_layers) contents_l, styles_l, tv_l, l = compute_loss( X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) l.backward() trainer.step(1) if (epoch + 1) % lr_decay_epoch == 0: trainer.set_learning_rate(trainer.learning_rate * 0.8) if (epoch + 1) % 10 == 0: animator.axes[1].imshow(postprocess(X).asnumpy()) animator.add(epoch + 1, [float(sum(contents_l)), float(sum(styles_l)), float(tv_l)]) return X .. raw:: html
.. raw:: html
.. code:: python def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch): X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y) scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[10, num_epochs], legend=['content', 'style', 'TV'], ncols=2, figsize=(7, 2.5)) for epoch in range(num_epochs): trainer.zero_grad() contents_Y_hat, styles_Y_hat = extract_features( X, content_layers, style_layers) contents_l, styles_l, tv_l, l = compute_loss( X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) l.backward() trainer.step() scheduler.step() if (epoch + 1) % 10 == 0: animator.axes[1].imshow(postprocess(X)) animator.add(epoch + 1, [float(sum(contents_l)), float(sum(styles_l)), float(tv_l)]) return X .. raw:: html
.. raw:: html
Bây giờ chúng ta bắt đầu đào tạo mô hình. Chúng tôi giải thích chiều cao và chiều rộng của nội dung và phong cách hình ảnh lên 300 x 450 pixel. Chúng tôi sử dụng hình ảnh nội dung để khởi tạo hình ảnh tổng hợp. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python device, image_shape = d2l.try_gpu(), (450, 300) net.collect_params().reset_ctx(device) content_X, contents_Y = get_contents(image_shape, device) _, styles_Y = get_styles(image_shape, device) output = train(content_X, contents_Y, styles_Y, device, 0.9, 500, 50) .. figure:: output_neural-style_5de8ca_140_0.svg .. raw:: html
.. raw:: html
.. code:: python device, image_shape = d2l.try_gpu(), (300, 450) # PIL Image (h, w) net = net.to(device) content_X, contents_Y = get_contents(image_shape, device) _, styles_Y = get_styles(image_shape, device) output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50) .. figure:: output_neural-style_5de8ca_143_0.svg .. raw:: html
.. raw:: html
Chúng ta có thể thấy rằng hình ảnh tổng hợp giữ lại cảnh quan và đối tượng của hình ảnh nội dung và chuyển màu sắc của hình ảnh phong cách cùng một lúc. Ví dụ, hình ảnh tổng hợp có các khối màu giống như trong hình ảnh phong cách. Một số khối này thậm chí còn có kết cấu tinh tế của nét cọ. Tóm tắt ------- - Chức năng mất mát thường được sử dụng trong chuyển phong cách bao gồm ba phần: (i) mất nội dung làm cho hình ảnh tổng hợp và hình ảnh nội dung gần với các tính năng nội dung; (ii) mất phong cách làm cho hình ảnh tổng hợp và hình ảnh phong cách gần gũi trong các tính năng phong cách; và (iii) mất biến đổi tổng thể giúp giảm nhiễu trong the synthesized tổng hợp image hình ảnh. - Chúng ta có thể sử dụng CNN được đào tạo trước để trích xuất các tính năng hình ảnh và giảm thiểu chức năng mất mát để liên tục cập nhật hình ảnh tổng hợp dưới dạng tham số mô hình trong quá trình đào tạo. - Chúng ta sử dụng ma trận Gram để biểu diễn các đầu ra kiểu dáng từ các layer style. Bài tập ------- 1. Làm thế nào để đầu ra thay đổi khi bạn chọn các lớp nội dung và kiểu khác nhau? 2. Điều chỉnh các siêu tham số trọng lượng trong chức năng mất mát. Đầu ra có giữ được nhiều nội dung hơn hoặc ít tiếng ồn hơn? 3. Sử dụng nội dung khác nhau và hình ảnh phong cách. Bạn có thể tạo ra những hình ảnh tổng hợp thú vị hơn không? 4. Chúng ta có thể áp dụng chuyển kiểu cho văn bản không? Hint: you may refer to the survey paper by Hu et al. :cite:`Hu.Lee.Aggarwal.ea.2020`. .. raw:: html
mxnetpytorch
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html