Style Transfer using PyTorch (Part 1)


This post follows the tutorial NEURAL TRANSFER USING PYTORCH step by step. Part 1 covers image loading: the following content and style images are loaded as PyTorch tensors.




In [40]:
# PyTorch & TensorFlow (TensorFlow is used here only for tf.keras.utils.get_file)
import torch
import tensorflow as tf

# Visualization
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
%matplotlib inline


In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load an image

In [32]:
d_path = {}
d_path['content'] = tf.keras.utils.get_file('turtle.jpg','')
d_path['style'] = tf.keras.utils.get_file('kandinsky.jpg','')
In [33]:
# Show the content and style images side by side
fig, axes = plt.subplots(1, 2, figsize=(16, 8))
for i, key in enumerate(d_path.keys()):
    img = plt.imread(d_path[key])
    axes[i].imshow(img)
    axes[i].set_title(f'Image for {key}')

Create an image loader for PyTorch

PyTorch models operate on tensors, so each image has to be converted into a tensor by an image loader before it can be used.

In [43]:
# desired size of the output image
imsize = 512 if torch.cuda.is_available() else 128  # use small size if no gpu

loader = transforms.Compose([
    transforms.Resize(imsize),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor

def image_loader(image_name):
    # open the file as a PIL image
    image = Image.open(image_name)

    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)
In [41]:
style_img = image_loader(d_path['style'])
content_img = image_loader(d_path['content'])

Plot tensor images as PIL image

In [80]:
unloader = transforms.ToPILImage()  # reconvert into PIL image

def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)      # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) # pause a bit so that plots are updated

Check transform process

In [87]:
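totensor = transforms.ToTensor()     # PIL image -> tensor, no resizing
scale = transforms.Resize(128)       # PIL image -> smaller PIL image
composed = transforms.Compose([transforms.Resize(128),
                               transforms.ToTensor()])  # resize, then convert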

# Apply each of the above transforms on sample.
fig = plt.figure(figsize=(16, 6))
img = Image.fromarray(plt.imread(d_path['style']))  
print(f'Original image size: {img.size}')
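for i, tsfrm in enumerate([totensor, scale, composed]):
    transformed_sample = tsfrm(img)
    if i != 1:
        # totensor and composed return tensors; convert them back to PIL images
        transformed_sample = transformed_sample.cpu().clone().squeeze(0)
        transformed_sample = unloader(transformed_sample)

    ax = plt.subplot(1, 3, i + 1)
    ax.set_title(type(tsfrm).__name__)
    ax.imshow(transformed_sample)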
Original image size: (1000, 657)
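
When `Resize` is given a single integer, it scales the shorter edge to that value and keeps the aspect ratio. The arithmetic can be sketched in plain Python. `resized_size` is a hypothetical helper, not part of torchvision, and it assumes the longer edge is truncated to an integer; the exact rounding may differ between torchvision versions.

```python
def resized_size(width, height, size):
    """Approximate (width, height) after Resize(size) with an integer size.

    Assumption: the shorter edge becomes `size` and the longer edge is
    scaled by the same factor, then truncated to an integer.
    """
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


# The style image above is 1000 x 657 (width x height).
print(resized_size(1000, 657, 128))
```

Under this rule the 1000 x 657 style image shrinks to about 194 x 128, so the height becomes 128 and the width follows from the aspect ratio.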