Activation Functions in Neural Networks

Goal

This post introduces activation functions commonly used in neural networks, using PyTorch.


Libraries

In [47]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch  # needed for torch.tensor below
import torch.nn as nn

Activation Functions

In this post, the following activation functions are introduced.

  • Sigmoid
  • Tanh
  • ReLU
  • Leaky ReLU

Create input data

In [48]:
x = np.linspace(-5, 5, 100)
t_x = torch.tensor(x)
d_activations = {}

Sigmoid

$$\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}$$
In [49]:
d_activations['Sigmoid'] = nn.Sigmoid()
y = d_activations['Sigmoid'](t_x)

plt.plot(t_x.numpy(), y.numpy(), '.');
plt.title('Sigmoid activation function');
plt.xlabel('Input x'); plt.ylabel('Output y')
plt.ylim([-1.5, 1.5]);
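As a cross-check of the formula above, here is a minimal NumPy sketch that evaluates Sigmoid directly, without PyTorch. The helper name sigmoid is hypothetical and not part of any library. This naive form may emit an overflow warning for very large negative inputs, which does not arise on the [-5, 5] grid used in this post.

```python
import numpy as np

def sigmoid(x):
    """Hypothetical helper: Sigmoid(x) = 1 / (1 + exp(-x)), applied element-wise."""
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-5, 5, 100)  # same input grid as the post
y = sigmoid(x)

print(sigmoid(0.0))              # 0.5
print(y.min() > 0, y.max() < 1)  # True True: outputs stay strictly inside (0, 1)
# Symmetry around the origin: Sigmoid(-x) = 1 - Sigmoid(x)
print(np.allclose(sigmoid(-x), 1 - sigmoid(x)))  # True
```

The symmetry check is why the plotted curve is centered at (0, 0.5).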

Tanh

$$\text{Tanh}(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
In [50]:
d_activations['Tanh'] = nn.Tanh()
y = d_activations['Tanh'](t_x)

plt.plot(t_x.numpy(), y.numpy(), '.');
plt.title('Tanh activation function');
plt.xlabel('Input x'); plt.ylabel('Output y')
plt.ylim([-1.5, 1.5]);
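A minimal NumPy sketch that evaluates the Tanh formula term by term and checks it against np.tanh. It also checks the identity Tanh(x) = 2 · Sigmoid(2x) − 1, which explains why the two curves have the same S shape while Tanh spans (−1, 1) instead of (0, 1). The helper names tanh_from_formula and sigmoid are hypothetical.

```python
import numpy as np

def tanh_from_formula(x):
    """Hypothetical helper: Tanh(x) = (e^x - e^-x) / (e^x + e^-x), element-wise."""
    return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))

def sigmoid(x):
    """Hypothetical helper: Sigmoid(x) = 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-5, 5, 100)  # same input grid as the post
y = tanh_from_formula(x)

print(np.allclose(y, np.tanh(x)))              # True
# Tanh is a rescaled and shifted sigmoid
print(np.allclose(y, 2 * sigmoid(2 * x) - 1))  # True
print(tanh_from_formula(0.0))                  # 0.0
```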

ReLU

ReLU stands for Rectified Linear Unit:

$$\text{ReLU}(x) = \max(0, x)$$

In [51]:
act_function = 'ReLU'
d_activations[act_function] = nn.ReLU()
y = d_activations[act_function](t_x)

plt.plot(t_x.numpy(), y.numpy(), '.');
plt.title(f'{act_function} activation function');
plt.xlabel('Input x'); plt.ylabel('Output y')
plt.ylim([-1.5, 1.5]);
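The max(0, x) rule maps directly onto NumPy's element-wise np.maximum. The sketch below (relu is a hypothetical helper name) shows that negative inputs are clipped to zero while positive inputs pass through unchanged. On the post's 100-point grid, exactly half of the inputs are negative, so half of the outputs are zero.

```python
import numpy as np

def relu(x):
    """Hypothetical helper: ReLU(x) = max(0, x), applied element-wise."""
    return np.maximum(0, x)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(relu(x).tolist())  # [0.0, 0.0, 0.0, 0.5, 2.0]

grid = np.linspace(-5, 5, 100)  # same input grid as the post
print(np.count_nonzero(relu(grid) == 0))  # 50
```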

Leaky ReLU

$$\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} \cdot \min(0, x)$$

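negative_slope is a small positive constant (0.01 by default in nn.LeakyReLU) that multiplies the input when x < 0. Negative inputs therefore produce a small negative output, negative_slope * x, instead of zero.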

In [52]:
act_function = 'Leaky ReLU'
d_activations[act_function] = nn.LeakyReLU()
y = d_activations[act_function](t_x)

plt.plot(t_x.numpy(), y.numpy(), '.');
plt.title(f'{act_function} activation function');
plt.xlabel('Input x'); plt.ylabel('Output y')
plt.ylim([-1.5, 1.5]);
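A minimal NumPy sketch of the Leaky ReLU formula, assuming negative_slope = 0.01 to mirror the nn.LeakyReLU() default used above (leaky_relu is a hypothetical helper name). With such a small slope, the negative side is barely distinguishable from ReLU on the plot; a larger slope such as 0.2 makes the effect visible.

```python
import numpy as np

def leaky_relu(x, negative_slope=0.01):
    """Hypothetical helper: max(0, x) + negative_slope * min(0, x), element-wise.
    The 0.01 default is chosen to match torch.nn.LeakyReLU()."""
    return np.maximum(0, x) + negative_slope * np.minimum(0, x)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(leaky_relu(x).tolist())       # [-0.02, -0.005, 0.0, 0.5, 2.0]
print(leaky_relu(x, 0.2).tolist())  # [-0.4, -0.1, 0.0, 0.5, 2.0]
```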

Summary

In [63]:
fig, axes = plt.subplots(1, len(d_activations), figsize=(16, 3))

for ax, (name, act) in zip(axes, d_activations.items()):
    # Evaluate each stored nn.Module on the same input tensor
    ax.plot(t_x.numpy(), act(t_x).numpy(), '.')
    ax.set_title(f'{name} activation function')
    ax.set_xlabel('Input x'); ax.set_ylabel('Output y')
    ax.set_ylim([-1.5, 1.5])
plt.tight_layout()
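To complement the side-by-side plots with numbers, the sketch below evaluates NumPy versions of the four activations on the same [-5, 5] grid and prints each one's minimum, maximum and value at 0. It assumes negative_slope = 0.01 for Leaky ReLU, matching the nn.LeakyReLU() default used above; the activations dictionary is a hypothetical stand-in for d_activations that avoids a PyTorch dependency.

```python
import numpy as np

# NumPy stand-ins for the four activations (Leaky ReLU slope assumed 0.01)
activations = {
    'Sigmoid': lambda x: 1.0 / (1.0 + np.exp(-x)),
    'Tanh': np.tanh,
    'ReLU': lambda x: np.maximum(0, x),
    'Leaky ReLU': lambda x: np.maximum(0, x) + 0.01 * np.minimum(0, x),
}

x = np.linspace(-5, 5, 100)  # same input grid as the post
for name, f in activations.items():
    y = f(x)
    print(f'{name:>10}: min={y.min():+.4f}  max={y.max():+.4f}  f(0)={f(0.0):+.4f}')
```

Sigmoid and Tanh saturate near their bounds at the edges of the grid, while ReLU and Leaky ReLU keep growing linearly for positive inputs.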
