Style Transfer using PyTorch (Part 2)
Goal
This post explains the concept of style transfer step by step. Part 2 covers the loss functions: content loss and style loss.
Reference
Libraries
import pandas as pd
from IPython.display import display  # display() is only a builtin inside Jupyter
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
# Visualization
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
Loss Functions
Content loss
The content loss is the MSE (mean squared error) between the content image and the output image, computed on their feature maps:
$$\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left(Y_{content,i} - Y_{output,i}\right)^2$$
where n is the number of elements being compared.
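As a small check of the formula, the sketch below computes the mean of squared differences by hand and compares it with `F.mse_loss`, which averages over all elements by default (`reduction='mean'`). The tensors `y_content` and `y_output` are made-up values, not features from a real network.

```python
import torch
import torch.nn.functional as F

# Hypothetical feature values for the content image and the output image
y_content = torch.tensor([1.0, 2.0, 3.0, 4.0])
y_output = torch.tensor([1.5, 2.0, 2.0, 4.0])

# MSE written out from the formula: (1/n) * sum_i (Y_content_i - Y_output_i)^2
n = y_content.numel()
manual_mse = ((y_content - y_output) ** 2).sum() / n

# F.mse_loss averages over all elements by default (reduction='mean')
library_mse = F.mse_loss(y_output, y_content)

print(manual_mse.item(), library_mse.item())  # 0.3125 0.3125
```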
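class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # Detach the target so it is treated as a fixed value, not part of the graph
        self.target = target.detach()

    def forward(self, input):
        # Store the MSE so it can be read after the forward pass
        self.loss = F.mse_loss(input, self.target)
        # Return the input unchanged so the module can sit transparently inside a model
        return input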
content_loss = ContentLoss(torch.Tensor([0.5]))
content_loss.forward(torch.Tensor([0.]))
content_loss.loss  # (0 - 0.5)**2 = 0.25
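Because `forward` returns its input unchanged and only records `self.loss`, a `ContentLoss` module can sit after a layer of a network without changing its activations, while the stored loss still sends gradients back to whatever produced the input. The sketch below re-declares the class so it runs on its own, uses arbitrary example tensors, and compares the gradient with the analytic derivative of the MSE, `2 * (output - target) / n`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ContentLoss(nn.Module):
    """Same module as above: stores the MSE to a fixed target, passes the input through."""

    def __init__(self, target):
        super().__init__()
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input


target = torch.tensor([0.5, 1.0, -1.0])
# Stand-in for the output image's features; requires_grad so gradients can be inspected
output = torch.tensor([0.0, 1.0, 1.0], requires_grad=True)

content_loss = ContentLoss(target)
passed_through = content_loss(output)  # calling the module runs forward()
content_loss.loss.backward()

# Analytic gradient of the MSE: 2 * (output - target) / n
expected_grad = 2 * (output.detach() - target) / output.numel()
print(passed_through is output)  # True: the input is returned unchanged
print(output.grad, expected_grad)
```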
Style loss
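To compute the style loss, we first need a function that computes the Gram matrix of a feature map. If the feature tensor is reshaped into a matrix F with one flattened feature map per row, the Gram matrix is
$$ \mathbf{G} = \mathbf{F} \mathbf{F}^\top $$
so entry (i, j) is the inner product of feature maps i and j. The gram_matrix function below also divides G by the total number of elements in the feature tensor.
After converting the features of the style image and the output image into Gram matrices, the style loss is the MSE between those two matrices.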
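To see what the Gram matrix measures, it helps to build it two ways. Assuming the layout used by `gram_matrix` (one flattened feature map per row), entry (i, j) of the matrix product equals the dot product of feature maps i and j, so it records how strongly those two maps activate together. The random feature tensor here is purely illustrative.

```python
import torch

torch.manual_seed(0)
# Hypothetical feature tensor: batch 1, 3 feature maps, each 2 x 2
features = torch.rand(1, 3, 2, 2)

# Flatten: one row per feature map -> shape (3, 4)
F_mat = features.view(3, 4)

# Gram matrix as a single matrix product
G = F_mat @ F_mat.t()

# The same matrix built entry by entry: G[i, j] = <map i, map j>
G_loop = torch.empty(3, 3)
for i in range(3):
    for j in range(3):
        G_loop[i, j] = torch.dot(F_mat[i], F_mat[j])

print(G.shape)                    # torch.Size([3, 3]): one entry per pair of maps
print(torch.allclose(G, G_loop))  # True
```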
def gram_matrix(input):
    # Get the size of the tensor
    # a: batch size
    # b: number of feature maps
    # c, d: dimensions of a feature map
    a, b, c, d = input.size()
    # Reshape so that each row is one flattened feature map
    features = input.view(a * b, c * d)
    # Show the reshaped features (for illustration only)
    display(pd.DataFrame(features.detach().numpy()))
    # Multiply the feature matrix by its transpose
    G = torch.mm(features, features.t())
    # Normalize by the total number of elements
    G_norm = G.div(a * b * c * d)
    return G_norm
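Because every entry of the Gram matrix is a sum over spatial positions, it ignores where in the image a pattern occurs, which is one way to see why it describes style rather than content. The sketch below shuffles the pixel positions of a random feature tensor, applying the same permutation to every map, and checks that the Gram matrix is unchanged. `gram_matrix_quiet` is a hypothetical copy of `gram_matrix` without the `display` call.

```python
import torch


def gram_matrix_quiet(input):
    # Same computation as gram_matrix, minus the DataFrame display
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)


torch.manual_seed(0)
# Hypothetical feature tensor: batch 1, 4 feature maps, each 3 x 3
feature_map = torch.rand(1, 4, 3, 3)

# Shuffle the 9 spatial positions, using the same permutation for every map
perm = torch.randperm(9)
shuffled = feature_map.reshape(1, 4, 9)[:, :, perm].reshape(1, 4, 3, 3).contiguous()

G_original = gram_matrix_quiet(feature_map)
G_shuffled = gram_matrix_quiet(shuffled)

# The pixel layout changed, but the Gram matrix did not
print(torch.allclose(G_original, G_shuffled))  # True
```

In matrix terms, shuffling multiplies F by a permutation matrix P, and (FP)(FP)^T = F P P^T F^T = F F^T.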
class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        print('Reshaped feature for target feature:')
        # Precompute the target Gram matrix once; detach it so it stays fixed
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        print('Reshaped feature for input feature:')
        G = gram_matrix(input)
        print('Gram matrix for input feature:')
        display(pd.DataFrame(G.detach().numpy()))
        print('Gram matrix for target feature:')
        display(pd.DataFrame(self.target.numpy()))
        # Style loss: MSE between the two Gram matrices
        self.loss = F.mse_loss(G, self.target)
        # Return the input unchanged so the module can sit inside a model
        return input
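`StyleLoss.forward` stores the MSE between Gram matrices, and like the content loss this value is differentiable with respect to the input. As a hedged illustration of how it is meant to be used, the sketch below runs plain gradient descent on a small random "image" so that its Gram matrix moves toward a target built from three identical 2 x 2 identity maps. The helper `gram_matrix_quiet`, the learning rate, and the number of steps are assumptions for this sketch; a full style-transfer loop would optimize a real image through a network, possibly with a different optimizer. The image does not start at zeros because the gradient of F F^T vanishes at F = 0.

```python
import torch
import torch.nn.functional as F


def gram_matrix_quiet(input):
    # Same computation as gram_matrix, without the DataFrame display
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    return torch.mm(features, features.t()).div(a * b * c * d)


# Target style: three identical 2 x 2 maps with ones on the diagonal
target_gram = gram_matrix_quiet(torch.eye(2).repeat(1, 3, 1, 1)).detach()

torch.manual_seed(0)
# Start from a random image: at exactly zero, the gradient of F F^T vanishes
image = torch.rand(1, 3, 2, 2, requires_grad=True)

learning_rate = 5.0  # arbitrary value chosen for this sketch
losses = []
for step in range(200):
    # Same quantity StyleLoss.forward stores in self.loss
    loss = F.mse_loss(gram_matrix_quiet(image), target_gram)
    losses.append(loss.item())
    image.grad = None  # clear the gradient from the previous step
    loss.backward()
    with torch.no_grad():
        image -= learning_rate * image.grad  # plain gradient-descent update

print(f"style loss: {losses[0]:.5f} -> {losses[-1]:.5f}")
```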
# One image with three identical 2 x 2 feature maps
sample_input = torch.tensor([[[[1., 0.],
                               [0., 1.]],
                              [[1., 0.],
                               [0., 1.]],
                              [[1., 0.],
                               [0., 1.]]]])
sample_input.size()  # torch.Size([1, 3, 2, 2])
style_loss = StyleLoss(sample_input)
style_loss.forward(torch.zeros(1, 3, 2, 2))
style_loss.loss
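The value of `style_loss.loss` above can be worked out by hand. Each reshaped target row is [1, 0, 0, 1], so every entry of the feature matrix times its transpose is 2, and dividing by a * b * c * d = 1 * 3 * 2 * 2 = 12 gives 1/6. The all-zero input has an all-zero Gram matrix, so the MSE over the nine entries is (1/6)^2 = 1/36, about 0.0278. The sketch below reproduces this with `gram_matrix_quiet`, a hypothetical display-free copy of `gram_matrix`.

```python
import torch
import torch.nn.functional as F


def gram_matrix_quiet(input):
    # Same computation as gram_matrix, without the DataFrame display
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    return torch.mm(features, features.t()).div(a * b * c * d)


# Three identical 2 x 2 maps with ones on the diagonal, shape (1, 3, 2, 2)
sample_input = torch.eye(2).repeat(1, 3, 1, 1)

target_gram = gram_matrix_quiet(sample_input)            # every entry 2 / 12 = 1/6
input_gram = gram_matrix_quiet(torch.zeros(1, 3, 2, 2))  # every entry 0
loss = F.mse_loss(input_gram, target_gram)

print(target_gram)
print(loss.item(), 1 / 36)
```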