Style Transfer using PyTorch (Part 2)

Libraries

In [76]:
import pandas as pd

# Torch & Tensorflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import tensorflow as tf

# Visualization
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

Loss Functions

Content loss

The content loss is the MSE (mean squared error) between the content image and the output image:

$$\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left(Y_{content,i} - Y_{output,i}\right)^2$$
In [7]:
class ContentLoss(nn.Module):
    """Pass-through layer that records the content loss of the tensor flowing through it.

    Args:
        target: Reference content tensor. A detached copy is stored, so no
            gradient is propagated back into the original ``target``.

    Attributes:
        target: The detached reference tensor.
        loss: Set by ``forward``; the mean squared error between the most
            recent input and ``target``.
    """

    def __init__(self, target):
        super().__init__()
        # Freeze the reference: it is a constant, not something to optimize.
        frozen_target = target.detach()
        self.target = frozen_target

    def forward(self, input):
        """Store MSE(input, target) in ``self.loss`` and return ``input`` unchanged."""
        squared_error = F.mse_loss(input, self.target)
        self.loss = squared_error
        # Returning the input as-is keeps this layer transparent to the data.
        return input
In [15]:
# Target 0.5 vs. input 0.0: the expected loss is (0.5 - 0.0)^2 = 0.25.
content_loss = ContentLoss(target=torch.tensor([0.5]))
content_loss(torch.tensor([0.]))
content_loss.loss
Out[15]:
tensor(0.2500)

Style loss

To compute the style loss, we first define a function that computes the Gram matrix of a style feature map.

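$$ \mathbf{G} = \mathbf{F} \mathbf{F}^\top $$

Here each row of $\mathbf{F}$ is one flattened feature map, so $G_{ij}$ is the inner product of feature maps $i$ and $j$. The `gram_matrix` function below additionally divides $\mathbf{G}$ by the total number of elements in the feature tensor.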

After converting the features of the style image and of the output image into Gram matrices, the style loss is the MSE between these two Gram matrices.

In [71]:
def gram_matrix(input, verbose=True):
    """Compute the normalized Gram matrix of a 4-D feature tensor.

    Args:
        input: Tensor of size ``(a, b, c, d)``, where a is the batch size, b the
            number of feature maps, and c, d the dimensions of one feature map.
            Feature maps of all batch items are stacked into one matrix, so with
            ``a > 1`` the result mixes items across the batch.
        verbose: If True (the default), display the reshaped ``(a * b, c * d)``
            feature matrix as a pandas DataFrame via ``display``. Pass False to
            suppress this debugging output.

    Returns:
        Tensor of size ``(a * b, a * b)`` equal to ``F @ F.T / (a * b * c * d)``,
        where each row of ``F`` is one flattened feature map.

    Raises:
        ValueError: if ``input`` is not 4-D (raised when unpacking its size).
    """
    a, b, c, d = input.size()

    # Flatten each feature map into one row. Unlike ``view``, ``reshape`` also
    # accepts non-contiguous tensors (e.g. the result of a transpose/permute).
    features = input.reshape(a * b, c * d)

    if verbose:
        # ``.numpy()`` raises on tensors that require grad (such as the image
        # being optimized) and on GPU tensors, so detach and copy to CPU for
        # display only; the computation below still uses ``features``.
        display(pd.DataFrame(features.detach().cpu().numpy()))

    # Multiplication: inner products between every pair of flattened feature maps.
    G = torch.mm(features, features.t())

    # Normalize by the total number of elements in the feature tensor.
    return G.div(a * b * c * d)
In [72]:
class StyleLoss(nn.Module):
    """Pass-through layer that records the style loss of the tensor flowing through it.

    The Gram matrix of ``target_feature`` is computed once at construction and
    detached, so it acts as a constant target. Each ``forward`` call computes the
    Gram matrix of ``input``, prints/displays the intermediate matrices for
    inspection, stores ``F.mse_loss(G, self.target)`` in ``self.loss``, and
    returns ``input`` unchanged.

    Args:
        target_feature: Feature tensor of the style image, passed to ``gram_matrix``.

    Attributes:
        target: Detached Gram matrix of ``target_feature``.
        loss: Set by ``forward``; MSE between the input's Gram matrix and ``target``.
    """

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        print('Reshaped Feature for Target feature:')
        # Detach so the target Gram matrix is not part of the autograd graph.
        self.target = gram_matrix(target_feature).detach()

    @staticmethod
    def _as_frame(tensor):
        """Return ``tensor`` as a pandas DataFrame for display, without touching autograd."""
        # ``.numpy()`` raises on tensors that require grad or live on a GPU,
        # so detach and copy to host memory first.
        return pd.DataFrame(tensor.detach().cpu().numpy())

    def forward(self, input):
        """Store the style loss of ``input`` in ``self.loss`` and return ``input``."""
        print('Reshaped Feature for Input feature:')
        G = gram_matrix(input)
        print('Gram matrix for input feature:')
        display(self._as_frame(G))
        print('Gram matrix for target feature:')
        display(self._as_frame(self.target))
        self.loss = F.mse_loss(G, self.target)
        return input
In [73]:
# A batch of one with three channels, each channel a 2x2 identity feature map.
sample_input = torch.eye(2).repeat(1, 3, 1, 1)
sample_input.size()
Out[73]:
torch.Size([1, 3, 2, 2])
In [74]:
# Every target Gram entry is 2 / 12 = 1/6 and every input entry is 0,
# so the expected loss is (1/6)^2 ~= 0.0278.
style_loss = StyleLoss(target_feature=sample_input)
style_loss(torch.zeros((1, 3, 2, 2)))
style_loss.loss
Reshaped Feature for Target feature:
0 1 2 3
0 1.0 0.0 0.0 1.0
1 1.0 0.0 0.0 1.0
2 1.0 0.0 0.0 1.0
Reshaped Feature for Input feature:
0 1 2 3
0 0.0 0.0 0.0 0.0
1 0.0 0.0 0.0 0.0
2 0.0 0.0 0.0 0.0
Gram matrix for input feature:
0 1 2
0 0.0 0.0 0.0
1 0.0 0.0 0.0
2 0.0 0.0 0.0
Gram matrix for target feature:
0 1 2
0 0.166667 0.166667 0.166667
1 0.166667 0.166667 0.166667
2 0.166667 0.166667 0.166667
Out[74]:
tensor(0.0278)

Comments

Comments powered by Disqus