# Style Transfer using Pytorch (Part 2)

## Goal

This post explains the concept of style transfer step by step. Part 2 covers the loss functions.

Reference

## Libraries

```python
import pandas as pd

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F

# Visualization
from IPython.display import display
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
```


## Loss Functions

### Content loss

Content loss is the MSE (Mean Squared Error) between the content image and the output image, computed on their feature representations (in practice, the activations of a CNN layer):

$$MSE = \frac{1}{n} \sum_{i=1}^{n} \left(Y_{content,i} - Y_{output,i}\right)^2$$
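As a quick sanity check of the formula, the sketch below computes the MSE by hand on a few made-up values and compares it with `F.mse_loss`, which by default averages over all elements (`reduction='mean'`). The tensors `y_content` and `y_output` are illustrative placeholders, not real image features.

```python
import torch
import torch.nn.functional as F

# Hypothetical feature values standing in for Y_content and Y_output
y_content = torch.tensor([0.5, 1.0, -2.0, 3.0])
y_output = torch.tensor([0.0, 1.5, -1.0, 3.0])

# Direct translation of the formula: average of the squared differences
n = y_content.numel()
mse_manual = ((y_content - y_output) ** 2).sum() / n

# PyTorch's built-in (default reduction='mean') computes the same quantity
mse_builtin = F.mse_loss(y_output, y_content)

print(mse_manual, mse_builtin)  # tensor(0.3750) tensor(0.3750)
```

The squared differences are 0.25, 0.25, 1.0 and 0, so their mean is 1.5 / 4 = 0.375.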
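```python
class ContentLoss(nn.Module):

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # Detach the target: it is a fixed reference, not something to optimize
        self.target = target.detach()

    def forward(self, input):
        # Store the MSE between the input features and the target features
        self.loss = F.mse_loss(input, self.target)
        # Return the input unchanged so the module is transparent in a network
        return input
```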

```python
content_loss = ContentLoss(torch.tensor([0.5]))
content_loss(torch.tensor([0.]))  # calling the module runs forward()
content_loss.loss
```

Output:

```
tensor(0.2500)
```
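Since the output image is what style transfer optimizes, it helps to look at the gradient this loss produces. A minimal sketch, assuming the same one-element example as above but with `requires_grad=True` on the input; the class is repeated so the snippet runs on its own. For $(x - 0.5)^2$ at $x = 0$ the gradient is $2(0 - 0.5) = -1$, which autograd should reproduce.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContentLoss(nn.Module):
    # Same definition as above, repeated so this snippet is self-contained
    def __init__(self, target):
        super().__init__()
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

# The output-image features are what gets optimized, so they need gradients
x = torch.tensor([0.0], requires_grad=True)
content_loss = ContentLoss(torch.tensor([0.5]))

out = content_loss(x)
content_loss.loss.backward()

print(out is x)  # True: the module passes its input through unchanged
print(x.grad)    # tensor([-1.]): d/dx (x - 0.5)^2 at x = 0
```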

### Style loss

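To compute the style loss, we first need a function that computes the Gram matrix of a feature tensor. Each feature map is flattened into one row of a matrix $\mathbf{F}$, so each entry $G_{ij}$ is the inner product of feature maps $i$ and $j$: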

$$\mathbf{G} = \mathbf{F}\mathbf{F}^\top, \qquad G_{ij} = \sum_{k} F_{ik} F_{jk}$$

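After computing the Gram matrices of the style image's features and of the output image's features, the style loss is the MSE between the two matrices. The `gram_matrix` function below also normalizes $\mathbf{G}$ by dividing it by the total number of elements in the feature tensor.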
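To make the formula concrete, here is a small sketch on a made-up feature tensor with two 2×2 feature maps. It checks that `F_mat @ F_mat.t()` matches the explicit sum $G_{ij} = \sum_k F_{ik} F_{jk}$ written with `torch.einsum`. The values and the name `F_mat` are illustrative, and unlike `gram_matrix` no normalization is applied here.

```python
import torch

# Hypothetical feature tensor: batch 1, 2 feature maps, each 2 x 2
feat = torch.tensor([[[[1., 2.],
                       [3., 4.]],
                      [[0., 1.],
                       [1., 0.]]]])
a, b, c, d = feat.size()

# Flatten: one row per feature map, shape (a * b, c * d)
F_mat = feat.view(a * b, c * d)

# G[i, j] is the dot product of feature map i with feature map j
G = F_mat @ F_mat.t()

# Same result written as an explicit sum over spatial positions k
G_einsum = torch.einsum('ik,jk->ij', F_mat, F_mat)

print(G)
```

The result is symmetric: the diagonal holds each map's squared norm (30 and 2), and both off-diagonal entries are 5, the overlap between the two maps.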

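```python
def gram_matrix(input):
    # a: batch size
    # b: number of feature maps (channels)
    # c, d: height and width of each feature map
    a, b, c, d = input.size()

    # Flatten each feature map into one row: shape (a * b, c * d)
    features = input.view(a * b, c * d)

    # Show the reshaped features (for illustration only)
    display(pd.DataFrame(features.detach().numpy()))

    # Gram matrix: inner products between every pair of feature maps
    G = torch.mm(features, features.t())

    # Normalize by the total number of elements in the feature tensor
    G_norm = G.div(a * b * c * d)
    return G_norm
```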
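The `gram_matrix` above folds the batch dimension into the rows (`a * b`). That is fine for the single-image case used in this post, but with a batch larger than one it would also multiply feature maps of different images together. A minimal sketch of a per-image variant using `torch.bmm`; the name `batched_gram_matrix` and the per-image normalization by `b * c * d` are assumptions (for `a = 1` this matches the division by `a * b * c * d` above).

```python
import torch

def batched_gram_matrix(input):
    # Hypothetical helper: one Gram matrix per image in the batch
    a, b, c, d = input.size()
    # Keep the batch dimension separate: shape (a, b, c * d)
    features = input.view(a, b, c * d)
    # Batched matrix product: (a, b, c*d) @ (a, c*d, b) -> (a, b, b)
    G = torch.bmm(features, features.transpose(1, 2))
    # Normalize each image's Gram matrix by the elements of one feature tensor
    return G.div(b * c * d)

x = torch.rand(2, 3, 4, 4)
G = batched_gram_matrix(x)
print(G.shape)  # torch.Size([2, 3, 3])
```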

```python
class StyleLoss(nn.Module):

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        print('Reshaped Feature for Target feature:')
        # Precompute the target Gram matrix once; detach it so it stays fixed
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        print('Reshaped Feature for Input feature:')
        G = gram_matrix(input)
        print('Gram matrix for input feature:')
        display(pd.DataFrame(G.detach().numpy()))
        print('Gram matrix for target feature:')
        display(pd.DataFrame(self.target.numpy()))
        # Style loss: MSE between the two Gram matrices
        self.loss = F.mse_loss(G, self.target)
        # Return the input unchanged so the module is transparent in a network
        return input
```
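Both loss modules return `input` unchanged, so they can be dropped into a network as pass-through layers whose only side effect is recording `self.loss`. A minimal sketch, assuming a single randomly initialised `nn.Conv2d` as a stand-in for the pretrained CNN a real style transfer model typically uses; the toy layer, image sizes, and seed are illustrative only. It repeats quiet versions of `gram_matrix` and `StyleLoss` (no `display`/`print`) so it runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gram_matrix(input):
    # Quiet copy of gram_matrix (no display call)
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    return torch.mm(features, features.t()).div(a * b * c * d)

class StyleLoss(nn.Module):
    # Quiet copy of StyleLoss (no printing)
    def __init__(self, target_feature):
        super().__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        self.loss = F.mse_loss(gram_matrix(input), self.target)
        return input

torch.manual_seed(0)
# Hypothetical toy feature extractor standing in for a pretrained CNN
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

style_img = torch.rand(1, 3, 8, 8)
style_loss = StyleLoss(conv(style_img))

# Place the loss module right after the layer whose features it compares
model = nn.Sequential(conv, style_loss, nn.ReLU())

output_img = torch.rand(1, 3, 8, 8, requires_grad=True)
y = model(output_img)
style_loss.loss.backward()

# y has the conv output shape; output_img now has a gradient to follow
print(y.shape, style_loss.loss.item(), output_img.grad.shape)
```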

```python
sample_input = torch.tensor([[[[1., 0.],
                               [0., 1.]],
                              [[1., 0.],
                               [0., 1.]],
                              [[1., 0.],
                               [0., 1.]]]])
sample_input.size()
```

Output:

```
torch.Size([1, 3, 2, 2])
```
```python
style_loss = StyleLoss(sample_input)
style_loss(torch.zeros(1, 3, 2, 2))  # calling the module runs forward()
style_loss.loss
```

Output:

```
Reshaped Feature for Target feature:
     0    1    2    3
0  1.0  0.0  0.0  1.0
1  1.0  0.0  0.0  1.0
2  1.0  0.0  0.0  1.0
Reshaped Feature for Input feature:
     0    1    2    3
0  0.0  0.0  0.0  0.0
1  0.0  0.0  0.0  0.0
2  0.0  0.0  0.0  0.0
Gram matrix for input feature:
     0    1    2
0  0.0  0.0  0.0
1  0.0  0.0  0.0
2  0.0  0.0  0.0
Gram matrix for target feature:
          0         1         2
0  0.166667  0.166667  0.166667
1  0.166667  0.166667  0.166667
2  0.166667  0.166667  0.166667

tensor(0.0278)
```
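The value 0.0278 can be checked by hand. Each reshaped target row is [1, 0, 0, 1], so every Gram entry is $1 + 1 = 2$, and dividing by $a \cdot b \cdot c \cdot d = 1 \cdot 3 \cdot 2 \cdot 2 = 12$ gives $1/6$. The input Gram matrix is all zeros, so each of the 9 squared differences is $(1/6)^2 = 1/36 \approx 0.0278$, and so is their mean. The sketch below repeats that arithmetic in PyTorch; `F_t` is just a name for the reshaped target features.

```python
import torch

# Target features: 3 feature maps, each flattened to [1, 0, 0, 1]
F_t = torch.tensor([[1., 0., 0., 1.]] * 3)
a, b, c, d = 1, 3, 2, 2

# Each Gram entry is the dot product of two identical rows: 1 + 0 + 0 + 1 = 2
G_target = F_t @ F_t.t() / (a * b * c * d)  # 2 / 12 = 1/6 everywhere

# The all-zero input gives an all-zero Gram matrix
G_input = torch.zeros(3, 3)

# MSE over the 9 entries: each squared difference is (1/6)^2 = 1/36
loss = ((G_input - G_target) ** 2).mean()
print(round(loss.item(), 4))  # 0.0278
```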