Note
Go to the end to download the full example code, or to run this example in your browser via Binder.
Implementing an LSTM in TensorFlow from scratch
The purpose of this notebook is to illustrate how to build an LSTM from scratch in TensorFlow. TensorFlow already provides an LSTM implementation in Keras, but because that implementation supports many options, reading its source code can be confusing at first. Therefore, here is a vanilla implementation of an LSTM in TensorFlow. The results of this vanilla LSTM are shown to be fully reproducible with Keras' LSTM. This shows that the simple implementation of an LSTM in TensorFlow consists of just four equations and a for loop through time.
import os
import random
import numpy as np
np.__version__  # bare expression: displays the NumPy version only in an interactive/notebook session
'1.21.6'
import tensorflow as tf
tf.__version__  # bare expression: displays the TensorFlow version only in an interactive/notebook session
'2.7.0'
def seed_all(seed):
    """Reset every random number generator used in this example.

    Seeds Python's ``random`` module and NumPy's global RNG, exports
    ``PYTHONHASHSEED`` and seeds TensorFlow's global RNG with the API
    that matches the installed TensorFlow major version.

    Parameters
    ----------
    seed : int
        The seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Setting PYTHONHASHSEED at runtime does not change hash randomisation
    # of the already running interpreter; it only affects child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)

    tf_major = int(tf.__version__.split('.')[0])
    if tf_major > 1:
        tf.random.set_seed(seed)
    elif tf_major == 1:
        tf.compat.v1.random.set_random_seed(seed)
from tensorflow.keras.layers import Layer, Input, Dense
from tensorflow.keras.layers import LSTM as KLSTM
from tensorflow.keras.models import Model
from tensorflow.python.ops import array_ops
from tensorflow.python.keras import backend as K
# Compare (major, minor) numerically: a plain string comparison is
# lexicographic, e.g. "10.0" > "2.1" evaluates to False.
assert tuple(map(int, tf.__version__.split('.')[:2])) >= (2, 1), \
    "results are not reproducible with Tensorflow below 2.1"

num_inputs = 3  # number of input features
lstm_units = 32  # size of the hidden/cell state
lookback_steps = 5  # also known as time_steps or sequence length
num_samples = 10  # length of x,y (batch size of the toy input)
class SimpleLSTM(Layer):
    """A simplified implementation of an LSTM layer with Keras.

    The layer has no bias, uses ``tanh`` as activation and ``sigmoid`` as
    recurrent activation, and returns only the output of the last time
    step. With the same seed it reproduces
    ``KLSTM(units, recurrent_activation="sigmoid", use_bias=False,
    unit_forget_bias=False)``.

    Parameters
    ----------
    units : int
        Dimensionality of the hidden and cell states.
    """

    def __init__(self, units, **kwargs):
        super(SimpleLSTM, self).__init__(**kwargs)
        self.activation = tf.nn.tanh
        self.rec_activation = tf.nn.sigmoid
        self.units = units

    def call(self, inputs):
        """Run the LSTM over ``inputs`` and return the last hidden state.

        ``inputs`` is passed to ``K.rnn`` in its default (batch-major)
        layout, i.e. ``(batch, time_steps, features)``.
        """
        # Size the zero initial states from the actual batch instead of a
        # hard-coded 10, so the layer works for any batch size.
        batch_size = tf.shape(inputs)[0]
        initial_state = tf.zeros([batch_size, self.units], dtype=inputs.dtype)
        # K.rnn loops `self.cell` over the time axis; only the last output
        # is needed here.
        last_output, _, _ = K.rnn(
            self.cell,
            inputs,
            [initial_state, initial_state]
        )
        return last_output

    def cell(self, inputs, states):
        """Compute one LSTM step.

        Parameters
        ----------
        inputs : tensor
            Input of a single time step.
        states : list
            ``[h_tm1, c_tm1]``, the hidden and cell state of the previous step.

        Returns
        -------
        h, [h, c]
            The new hidden state (also the step output) and the new states.
        """
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        # The kernel stores the four gates side by side: [i, f, c, o].
        k_i, k_f, k_c, k_o = array_ops.split(self.kernel, num_or_size_splits=4, axis=1)
        x_i = K.dot(inputs, k_i)
        x_f = K.dot(inputs, k_f)
        x_c = K.dot(inputs, k_c)
        x_o = K.dot(inputs, k_o)

        # The recurrent kernel uses the same [i, f, c, o] gate layout.
        i = self.rec_activation(x_i + K.dot(h_tm1, self.rec_kernel[:, :self.units]))
        f = self.rec_activation(x_f + K.dot(h_tm1, self.rec_kernel[:, self.units:self.units * 2]))
        c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1, self.rec_kernel[:, self.units * 2:self.units * 3]))
        o = self.rec_activation(x_o + K.dot(h_tm1, self.rec_kernel[:, self.units * 3:]))
        h = o * self.activation(c)
        return h, [h, c]

    def build(self, input_shape):
        """Create the input kernel and the recurrent kernel."""
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='kernel',
            initializer="glorot_uniform")
        self.rec_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer="orthogonal")
        self.built = True
# Deterministic toy input: the numbers 0..N-1 laid out as
# (num_samples, lookback_steps, num_inputs). N is derived from the shape
# constants so the two stay in sync if any of them changes.
inputs_tf = tf.range(num_samples * lookback_steps * num_inputs, dtype=tf.float32)
inputs_tf = tf.reshape(inputs_tf, (num_samples, lookback_steps, num_inputs))
seed_all(313)
lstm = SimpleLSTM(lstm_units)
h1 = lstm(inputs_tf)
h1_sum = tf.reduce_sum(h1)
print(K.eval(h1_sum))
25.964834
Now check the results of the original Keras LSTM, created with the same seed.
seed_all(313)  # same seed as the SimpleLSTM run, so initial weights match
lstm = KLSTM(
    units=lstm_units,
    use_bias=False,
    unit_forget_bias=False,
    recurrent_activation="sigmoid",
)
h2 = lstm(inputs_tf)
h2_sum = tf.reduce_sum(h2)
print(K.eval(h2_sum))
25.964834
LSTM with bias
class LSTMWithBias(Layer):
    """A simplified implementation of an LSTM layer with Keras, with optional bias.

    Uses ``tanh`` as activation and ``sigmoid`` as recurrent activation and
    returns only the output of the last time step.

    Parameters
    ----------
    units : int
        Dimensionality of the hidden and cell states.
    use_bias : bool, default True
        Whether a (zeros-initialised) bias is added to the input projections.
        When False, no bias weight is created at all.
    """

    def __init__(self, units, use_bias=True, **kwargs):
        super(LSTMWithBias, self).__init__(**kwargs)
        self.activation = tf.nn.tanh
        self.rec_activation = tf.nn.sigmoid
        self.units = units
        self.use_bias = use_bias

    def call(self, inputs):
        """Run the LSTM over ``inputs`` and return the last hidden state.

        ``inputs`` is passed to ``K.rnn`` in its default (batch-major)
        layout, i.e. ``(batch, time_steps, features)``.
        """
        # Size the zero initial states from the actual batch instead of a
        # hard-coded 10, so the layer works for any batch size.
        batch_size = tf.shape(inputs)[0]
        initial_state = tf.zeros([batch_size, self.units], dtype=inputs.dtype)
        last_output, _, _ = K.rnn(
            self.cell,
            inputs,
            [initial_state, initial_state]
        )
        return last_output

    def cell(self, inputs, states):
        """Compute one LSTM step.

        Parameters
        ----------
        inputs : tensor
            Input of a single time step.
        states : list
            ``[h_tm1, c_tm1]``, the hidden and cell state of the previous step.

        Returns
        -------
        h, [h, c]
            The new hidden state (also the step output) and the new states.
        """
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        # The kernel stores the four gates side by side: [i, f, c, o].
        k_i, k_f, k_c, k_o = array_ops.split(self.kernel, num_or_size_splits=4, axis=1)
        x_i = K.dot(inputs, k_i)
        x_f = K.dot(inputs, k_f)
        x_c = K.dot(inputs, k_c)
        x_o = K.dot(inputs, k_o)

        if self.use_bias:
            # The bias uses the same [i, f, c, o] layout as the kernel.
            b_i, b_f, b_c, b_o = array_ops.split(
                self.bias, num_or_size_splits=4, axis=0)
            x_i = K.bias_add(x_i, b_i)
            x_f = K.bias_add(x_f, b_f)
            x_c = K.bias_add(x_c, b_c)
            x_o = K.bias_add(x_o, b_o)

        i = self.rec_activation(x_i + K.dot(h_tm1, self.rec_kernel[:, :self.units]))
        f = self.rec_activation(x_f + K.dot(h_tm1, self.rec_kernel[:, self.units:self.units * 2]))
        c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1, self.rec_kernel[:, self.units * 2:self.units * 3]))
        o = self.rec_activation(x_o + K.dot(h_tm1, self.rec_kernel[:, self.units * 3:]))
        h = o * self.activation(c)
        return h, [h, c]

    def build(self, input_shape):
        """Create the (optional) bias, the input kernel and the recurrent kernel."""
        input_dim = input_shape[-1]
        if self.use_bias:
            # Created first, as before; a zeros initializer draws no random
            # numbers, so it does not affect the seeded kernels below.
            self.bias = self.add_weight(
                shape=(self.units * 4,),
                name='bias',
                initializer="zeros")
        else:
            # Avoid an unused trainable variable when the bias is disabled.
            self.bias = None
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='kernel',
            initializer="glorot_uniform")
        self.rec_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer="orthogonal")
        self.built = True
# Seed once (the duplicated call was redundant: reseeding with the same
# value resets the generators to the same state).
seed_all(313)
lstm = LSTMWithBias(lstm_units)
h1 = lstm(inputs_tf)
h1_sum = tf.reduce_sum(h1)
print(K.eval(h1_sum))
25.964834
seed_all(313)  # same seed as the LSTMWithBias run
# Keras' LSTM uses a bias by default; unit_forget_bias=False keeps the
# forget-gate bias at its zero initial value.
lstm = KLSTM(
    units=lstm_units,
    unit_forget_bias=False,
    recurrent_activation="sigmoid",
)
h2 = lstm(inputs_tf)
h2_sum = tf.reduce_sum(h2)
print(K.eval(h2_sum))
25.964834
Implementing the temporal loop
# So far we have been using the K.rnn() function to implement the temporal (for)
# loop of the LSTM. Let's see what is inside it!
class LSTM(Layer):
    """A simplified implementation of an LSTM layer with an explicit time loop.

    Instead of ``K.rnn`` the time steps are iterated with a plain Python
    ``for`` loop. Uses ``tanh`` as activation and ``sigmoid`` as recurrent
    activation and returns only the output of the last time step.

    Parameters
    ----------
    units : int
        Dimensionality of the hidden and cell states.
    use_bias : bool, default True
        Whether a (zeros-initialised) bias is added to the input projections.
        When False, no bias weight is created at all.
    """

    def __init__(self, units, use_bias=True, **kwargs):
        super(LSTM, self).__init__(**kwargs)
        self.activation = tf.nn.tanh
        self.rec_activation = tf.nn.sigmoid
        self.units = units
        self.use_bias = use_bias

    def call(self, inputs, **kwargs):
        """Run the LSTM over ``inputs`` and return the last hidden state.

        ``inputs`` is transposed with ``[1, 0, 2]``, so it is expected as
        ``(batch, time_steps, features)``; the number of time steps must be
        known statically because it drives a Python loop.
        """
        # (batch, time, features) -> (time, batch, features), so that
        # inputs[t] selects one time step for the whole batch.
        inputs = tf.transpose(inputs, [1, 0, 2])
        lookback, _, _ = inputs.shape
        # Size the zero initial states from the actual batch instead of a
        # hard-coded 10, so the layer works for any batch size.
        batch_size = tf.shape(inputs)[1]
        initial_state = tf.zeros([batch_size, self.units], dtype=inputs.dtype)
        state = [initial_state, initial_state]

        outputs = []
        for time_step in range(lookback):
            _out, state = self.cell(inputs[time_step], state)
            outputs.append(_out)

        outputs = tf.stack(outputs)  # (time, batch, units)
        outputs = tf.transpose(outputs, [1, 0, 2])  # (batch, time, units)
        last_output = outputs[:, -1]
        return last_output

    def cell(self, inputs, states):
        """Compute one LSTM step.

        Parameters
        ----------
        inputs : tensor
            Input of a single time step.
        states : list
            ``[h_tm1, c_tm1]``, the hidden and cell state of the previous step.

        Returns
        -------
        h, [h, c]
            The new hidden state (also the step output) and the new states.
        """
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        # The kernel stores the four gates side by side: [i, f, c, o].
        k_i, k_f, k_c, k_o = array_ops.split(self.kernel, num_or_size_splits=4, axis=1)
        x_i = K.dot(inputs, k_i)
        x_f = K.dot(inputs, k_f)
        x_c = K.dot(inputs, k_c)
        x_o = K.dot(inputs, k_o)

        if self.use_bias:
            # The bias uses the same [i, f, c, o] layout as the kernel.
            b_i, b_f, b_c, b_o = array_ops.split(
                self.bias, num_or_size_splits=4, axis=0)
            x_i = K.bias_add(x_i, b_i)
            x_f = K.bias_add(x_f, b_f)
            x_c = K.bias_add(x_c, b_c)
            x_o = K.bias_add(x_o, b_o)

        i = self.rec_activation(x_i + K.dot(h_tm1, self.rec_kernel[:, :self.units]))
        f = self.rec_activation(x_f + K.dot(h_tm1, self.rec_kernel[:, self.units:self.units * 2]))
        c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1, self.rec_kernel[:, self.units * 2:self.units * 3]))
        o = self.rec_activation(x_o + K.dot(h_tm1, self.rec_kernel[:, self.units * 3:]))
        h = o * self.activation(c)
        return h, [h, c]

    def build(self, input_shape):
        """Create the (optional) bias, the input kernel and the recurrent kernel."""
        input_dim = input_shape[-1]
        if self.use_bias:
            # A zeros initializer draws no random numbers, so creating it
            # first does not affect the seeded kernels below.
            self.bias = self.add_weight(
                shape=(self.units * 4,),
                name='bias',
                initializer="zeros")
        else:
            # Avoid an unused trainable variable when the bias is disabled.
            self.bias = None
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='kernel',
            initializer="glorot_uniform")
        self.rec_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer="orthogonal")
        self.built = True
seed_all(313)  # same seed as the Keras run below, so initial weights match
lstm = LSTM(units=lstm_units)
h1 = lstm(inputs_tf)
h1_sum = tf.reduce_sum(h1)
print(K.eval(h1_sum))
25.964834
seed_all(313)  # same seed as the explicit-loop LSTM run
lstm = KLSTM(
    units=lstm_units,
    unit_forget_bias=False,
    recurrent_activation="sigmoid",
)
h2 = lstm(inputs_tf)
h2_sum = tf.reduce_sum(h2)
print(K.eval(h2_sum))
25.964834
Adding some more options
class LSTM(Layer):
    """A simplified LSTM layer with an explicit time loop and extra options.

    Uses ``tanh`` as activation and ``sigmoid`` as recurrent activation.

    Parameters
    ----------
    units : int
        Dimensionality of the hidden and cell states.
    use_bias : bool, default True
        Whether a bias is added to the input projections. When False, no bias
        weight is created at all.
    kernel_initializer, recurrent_initializer, bias_initializer
        Initializers passed to ``add_weight`` for the three weights.
    return_state : bool, default False
        If True, ``call`` returns ``(output, [h_s, c_s])``.
        NOTE: unlike Keras' LSTM, ``h_s`` and ``c_s`` hold the states of
        *every* time step (in the same time/batch layout as the outputs),
        not only the final states.
    return_sequences : bool, default False
        If True, return the outputs of all time steps instead of the last one.
    time_major : bool, default False
        If True, inputs (and outputs) are laid out as
        ``(time, batch, features)`` instead of ``(batch, time, features)``.
    """

    def __init__(
            self,
            units,
            use_bias=True,
            kernel_initializer='glorot_uniform',
            recurrent_initializer='orthogonal',
            bias_initializer='zeros',
            return_state=False,
            return_sequences=False,
            time_major=False,
            **kwargs
    ):
        super(LSTM, self).__init__(**kwargs)
        self.activation = tf.nn.tanh
        self.rec_activation = tf.nn.sigmoid
        self.units = units
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.recurrent_initializer = recurrent_initializer
        self.bias_initializer = bias_initializer
        self.return_state = return_state
        self.return_sequences = return_sequences
        self.time_major = time_major

    def call(self, inputs, **kwargs):
        """Run the LSTM over ``inputs``.

        The number of time steps must be known statically because it drives
        a Python loop.

        Returns the last output, or all outputs if ``return_sequences``;
        additionally ``[h_s, c_s]`` if ``return_state``.
        """
        if not self.time_major:
            # (batch, time, features) -> (time, batch, features)
            inputs = tf.transpose(inputs, [1, 0, 2])
        lookback, _, _ = inputs.shape
        # Size the zero initial states from the actual batch (axis 1 after the
        # optional transpose) instead of a hard-coded 10.
        batch_size = tf.shape(inputs)[1]
        initial_state = tf.zeros([batch_size, self.units], dtype=inputs.dtype)
        state = [initial_state, initial_state]

        outputs, h_list, c_list = [], [], []
        for time_step in range(lookback):
            _out, state = self.cell(inputs[time_step], state)
            outputs.append(_out)
            h_list.append(state[0])
            c_list.append(state[1])

        # All three stacks are (time, batch, units) at this point.
        outputs = tf.stack(outputs)
        h_s = tf.stack(h_list)
        c_s = tf.stack(c_list)

        if self.time_major:
            last_output = outputs[-1]
        else:
            # Back to (batch, time, units).
            outputs = tf.transpose(outputs, [1, 0, 2])
            h_s = tf.transpose(h_s, [1, 0, 2])
            c_s = tf.transpose(c_s, [1, 0, 2])
            last_output = outputs[:, -1]
        states = [h_s, c_s]

        h = outputs if self.return_sequences else last_output
        if self.return_state:
            return h, states
        return h

    def cell(self, inputs, states):
        """Compute one LSTM step.

        Parameters
        ----------
        inputs : tensor
            Input of a single time step.
        states : list
            ``[h_tm1, c_tm1]``, the hidden and cell state of the previous step.

        Returns
        -------
        h, [h, c]
            The new hidden state (also the step output) and the new states.
        """
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        # The kernel stores the four gates side by side: [i, f, c, o].
        k_i, k_f, k_c, k_o = array_ops.split(self.kernel, num_or_size_splits=4, axis=1)
        x_i = K.dot(inputs, k_i)
        x_f = K.dot(inputs, k_f)
        x_c = K.dot(inputs, k_c)
        x_o = K.dot(inputs, k_o)

        if self.use_bias:
            # The bias uses the same [i, f, c, o] layout as the kernel.
            b_i, b_f, b_c, b_o = array_ops.split(
                self.bias, num_or_size_splits=4, axis=0)
            x_i = K.bias_add(x_i, b_i)
            x_f = K.bias_add(x_f, b_f)
            x_c = K.bias_add(x_c, b_c)
            x_o = K.bias_add(x_o, b_o)

        i = self.rec_activation(x_i + K.dot(h_tm1, self.rec_kernel[:, :self.units]))
        f = self.rec_activation(x_f + K.dot(h_tm1, self.rec_kernel[:, self.units:self.units * 2]))
        c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1, self.rec_kernel[:, self.units * 2:self.units * 3]))
        o = self.rec_activation(x_o + K.dot(h_tm1, self.rec_kernel[:, self.units * 3:]))
        h = o * self.activation(c)
        return h, [h, c]

    def build(self, input_shape):
        """Create the (optional) bias, the input kernel and the recurrent kernel."""
        input_dim = input_shape[-1]
        if self.use_bias:
            # Created first, as before, to keep the weight-creation order.
            self.bias = self.add_weight(
                shape=(self.units * 4,),
                name='bias',
                initializer=self.bias_initializer)
        else:
            # Avoid an unused trainable variable when the bias is disabled.
            self.bias = None
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='kernel',
            initializer=self.kernel_initializer)
        self.rec_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer)
        self.built = True
seed_all(313)  # same seed as the Keras run below, so initial weights match
lstm = LSTM(units=lstm_units, return_sequences=True)
h1 = lstm(inputs_tf)  # outputs of every time step
h1_sum = tf.reduce_sum(h1)
print(K.eval(h1_sum))
115.15204
seed_all(313)  # same seed as the custom return_sequences run
lstm = KLSTM(
    units=lstm_units,
    return_sequences=True,
    unit_forget_bias=False,
    recurrent_activation="sigmoid",
)
h2 = lstm(inputs_tf)  # outputs of every time step
h2_sum = tf.reduce_sum(h2)
print(K.eval(h2_sum))
115.15204
Building a model and training it
# It is possible to use our vanilla LSTM as a layer in a Keras Model.
# The same seed is used for the KLSTM model below, so both models get the
# same initial weights and the same random training data; keep the order of
# the statements below unchanged, as it determines which random draws go where.
seed_all(313)
# Fixed batch size of 10; fit() below also uses batch_size=10 and the 100
# samples split evenly into 10 batches.
inp = Input(batch_shape=(10, lookback_steps, num_inputs))
# `LSTM` here is the most recently defined `LSTM` class (the one with extra
# options); `lstm` holds the layer's output tensor, not the layer.
lstm = LSTM(8)(inp)
out = Dense(1)(lstm)
model = Model(inputs=inp, outputs=out)
model.compile(loss='mse')
# Random toy regression data: 100 input sequences and 100 scalar targets.
xx = np.random.random((100, lookback_steps, num_inputs))
y = np.random.random((100, 1))
h = model.fit(x=xx, y=y, batch_size=10, epochs=10)
Epoch 1/10
1/10 [==>...........................] - ETA: 5s - loss: 0.2666
10/10 [==============================] - 1s 806us/step - loss: 0.2250
Epoch 2/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1784
10/10 [==============================] - 0s 688us/step - loss: 0.1670
Epoch 3/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1832
10/10 [==============================] - 0s 676us/step - loss: 0.1329
Epoch 4/10
1/10 [==>...........................] - ETA: 0s - loss: 0.0520
10/10 [==============================] - 0s 838us/step - loss: 0.1106
Epoch 5/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1114
10/10 [==============================] - 0s 668us/step - loss: 0.0976
Epoch 6/10
1/10 [==>...........................] - ETA: 0s - loss: 0.0948
10/10 [==============================] - 0s 695us/step - loss: 0.0920
Epoch 7/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1118
10/10 [==============================] - 0s 679us/step - loss: 0.0903
Epoch 8/10
1/10 [==>...........................] - ETA: 0s - loss: 0.0413
10/10 [==============================] - 0s 697us/step - loss: 0.0900
Epoch 9/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1029
10/10 [==============================] - 0s 715us/step - loss: 0.0893
Epoch 10/10
1/10 [==>...........................] - ETA: 0s - loss: 0.0631
10/10 [==============================] - 0s 717us/step - loss: 0.0886
# Sum of the per-epoch training losses: one number to compare between runs.
print(np.sum(h.history['loss']))
1.183198258280754
# Now compare the results by using the original Keras LSTM, i.e. KLSTM.
# Same seed and same statement order as the custom-LSTM model, so the
# initial weights and the random training data are identical.
seed_all(313)
inp = Input(batch_shape=(10, lookback_steps, num_inputs))
# `lstm` holds the layer's output tensor, not the layer.
lstm = KLSTM(8,
recurrent_activation="sigmoid",
unit_forget_bias=False
)(inp)
out = Dense(1)(lstm)
model = Model(inputs=inp, outputs=out)
model.compile(loss='mse')
# Random toy regression data: 100 input sequences and 100 scalar targets.
xx = np.random.random((100, lookback_steps, num_inputs))
y = np.random.random((100, 1))
h = model.fit(x=xx, y=y, batch_size=10, epochs=10)
Epoch 1/10
1/10 [==>...........................] - ETA: 6s - loss: 0.2666
10/10 [==============================] - 1s 1ms/step - loss: 0.2250
Epoch 2/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1784
10/10 [==============================] - 0s 978us/step - loss: 0.1670
Epoch 3/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1832
10/10 [==============================] - 0s 943us/step - loss: 0.1329
Epoch 4/10
1/10 [==>...........................] - ETA: 0s - loss: 0.0520
10/10 [==============================] - 0s 921us/step - loss: 0.1106
Epoch 5/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1114
10/10 [==============================] - 0s 911us/step - loss: 0.0976
Epoch 6/10
1/10 [==>...........................] - ETA: 0s - loss: 0.0948
10/10 [==============================] - 0s 941us/step - loss: 0.0920
Epoch 7/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1118
10/10 [==============================] - 0s 936us/step - loss: 0.0903
Epoch 8/10
1/10 [==>...........................] - ETA: 0s - loss: 0.0413
10/10 [==============================] - 0s 897us/step - loss: 0.0900
Epoch 9/10
1/10 [==>...........................] - ETA: 0s - loss: 0.1029
10/10 [==============================] - 0s 911us/step - loss: 0.0893
Epoch 10/10
1/10 [==>...........................] - ETA: 0s - loss: 0.0631
10/10 [==============================] - 0s 905us/step - loss: 0.0886
# Sum of the per-epoch training losses of the KLSTM model, to compare with
# the custom-LSTM model's sum printed earlier.
print(np.sum(h.history['loss']))
1.183198258280754
Total running time of the script: (0 minutes 2.033 seconds)