Activation Functions in Neural Networks

Goal

This post introduces common activation functions used in neural networks and plots each one using PyTorch.

Libraries

In [47]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

import torch  # needed for torch.tensor below
import torch.nn as nn

Activation Functions

In this post, the following activation functions are introduced.

  • Sigmoid
  • Tanh
  • ReLU
  • Leaky ReLU

Create input data

In [48]:
x = np.linspace(-5, 5, 100)
t_x = torch.tensor(x)
d_activations = {}

Sigmoid

$$ \mathrm{Sigmoid}(x) = \frac{1}{1 + \exp(-x)} $$
In [49]:
d_activations['Sigmoid'] = nn.Sigmoid()
y = d_activations['Sigmoid'](t_x)

plt.plot(t_x.numpy(), y.numpy(), '.');
plt.title('Sigmoid activation function');
plt.xlabel('Input x'); plt.ylabel('Output y')
plt.ylim([-1.5, 1.5]);
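
As a quick sanity check, the formula above can be written out by hand and compared with nn.Sigmoid. This is only a minimal sketch: the names manual, builtin, matches and mid_value are illustrative and not part of the original notebook, and torch.linspace is used here in place of np.linspace to keep the snippet self-contained.

```python
import torch
import torch.nn as nn

# Same grid as the post: 100 points in [-5, 5], float64 like np.linspace.
x = torch.linspace(-5, 5, 100, dtype=torch.float64)

# The sigmoid formula written out term by term.
manual = 1 / (1 + torch.exp(-x))

# PyTorch's built-in module.
builtin = nn.Sigmoid()(x)

# Both versions should agree up to floating-point tolerance.
matches = torch.allclose(manual, builtin)

# Sigmoid is centred at 0.5 for an input of 0.
mid_value = nn.Sigmoid()(torch.tensor(0.0)).item()
print(matches, mid_value)
```

The output always lies strictly between 0 and 1, which is why the curve in the plot flattens out at both ends.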

Tanh

$$ \mathrm{Tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} $$
In [50]:
d_activations['Tanh'] = nn.Tanh()
y = d_activations['Tanh'](t_x)

plt.plot(t_x.numpy(), y.numpy(), '.');
plt.title('Tanh activation function');
plt.xlabel('Input x'); plt.ylabel('Output y')
plt.ylim([-1.5, 1.5]);
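
The Tanh formula can be checked the same way. The sketch below also illustrates a standard identity, tanh(x) = 2 * sigmoid(2x) - 1, which shows that Tanh is a rescaled and shifted Sigmoid. The variable names are illustrative and do not come from the original notebook.

```python
import torch
import torch.nn as nn

# Same grid as the post: 100 points in [-5, 5].
x = torch.linspace(-5, 5, 100, dtype=torch.float64)

# The tanh formula written out term by term.
manual = (torch.exp(x) - torch.exp(-x)) / (torch.exp(x) + torch.exp(-x))

# PyTorch's built-in module.
builtin = nn.Tanh()(x)

# Tanh expressed through the sigmoid function.
rescaled_sigmoid = 2 * torch.sigmoid(2 * x) - 1

formula_matches = torch.allclose(manual, builtin)
identity_matches = torch.allclose(rescaled_sigmoid, builtin)
print(formula_matches, identity_matches)
```

Unlike Sigmoid, the output is centred at zero and ranges between -1 and 1.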

ReLU

ReLU stands for rectified linear unit. It passes positive inputs through unchanged and sets negative inputs to zero.

$$ \mathrm{ReLU}(x) = \max(0, x) $$

In [51]:
act_function = 'ReLU'
d_activations[act_function] = nn.ReLU()
y = d_activations[act_function](t_x)

plt.plot(t_x.numpy(), y.numpy(), '.');
plt.title(f'{act_function} activation function');
plt.xlabel('Input x'); plt.ylabel('Output y')
plt.ylim([-1.5, 1.5]);
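
To make the max(0, x) behaviour concrete, here is a small sketch that applies nn.ReLU to a handful of hand-picked values and compares it with torch.clamp on the full grid. The sample values and variable names are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

relu = nn.ReLU()

# A few hand-picked inputs: negatives become 0, positives are unchanged.
sample = torch.tensor([-2.0, -0.5, 0.0, 1.5])
out = relu(sample).tolist()
print(out)  # [0.0, 0.0, 0.0, 1.5]

# On the post's grid, ReLU is the same as clamping values below at 0.
x = torch.linspace(-5, 5, 100)
same_as_clamp = torch.equal(relu(x), torch.clamp(x, min=0))
print(same_as_clamp)
```

Note that the plot above clips the y-axis at 1.5, so the linear part of the curve continues beyond what is visible.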

Leaky ReLU

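$$ \mathrm{LeakyReLU}(x) = \max(0, x) + \mathrm{negative\_slope} \cdot \min(0, x) $$

negative_slope is the small slope that multiplies the input when $x < 0$, so negative inputs are scaled down rather than set to zero. In nn.LeakyReLU it defaults to 0.01.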

In [52]:
act_function = 'Leaky ReLU'
d_activations[act_function] = nn.LeakyReLU()
y = d_activations[act_function](t_x)

plt.plot(t_x.numpy(), y.numpy(), '.');
plt.title(f'{act_function} activation function');
plt.xlabel('Input x'); plt.ylabel('Output y')
plt.ylim([-1.5, 1.5]);
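
With the default negative_slope of 0.01, the negative side of the plot above is almost indistinguishable from ReLU. The sketch below uses a larger slope of 0.2, a value picked purely for illustration, and checks the built-in module against the formula written out by hand. The variable names are not from the original notebook.

```python
import torch
import torch.nn as nn

# Same grid as the post: 100 points in [-5, 5].
x = torch.linspace(-5, 5, 100, dtype=torch.float64)

# Illustrative slope; larger than the default so the leak is easy to see.
slope = 0.2

# max(0, x) + negative_slope * min(0, x), written with clamp.
manual = torch.clamp(x, min=0) + slope * torch.clamp(x, max=0)

# PyTorch's built-in module with the same slope.
builtin = nn.LeakyReLU(negative_slope=slope)(x)
matches = torch.allclose(manual, builtin)

# With the default slope of 0.01, an input of -5 maps to about -0.05.
default_out = nn.LeakyReLU()(torch.tensor([-5.0, 5.0])).tolist()
print(matches, default_out)
```

Passing negative_slope=slope to nn.LeakyReLU() in the plotting cell would make the slope of the negative branch visible within the chosen y-axis limits.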

Summary

In [63]:
fig, axes = plt.subplots(1, len(d_activations), figsize=(16, 3))

for i, f in enumerate(d_activations.keys()):
    
    axes[i].plot(t_x.numpy(), d_activations[f](t_x).numpy(), '.');
    axes[i].set_title(f'{f} activation function');
    axes[i].set_xlabel('Input x'); axes[i].set_ylabel('Output y')
    axes[i].set_ylim([-1.5, 1.5]);
plt.tight_layout()
