suzuzusu日記

(´・ω・`)

kerasのLSTMのstatelessとstatefulの切り替え

kerasのLSTMなどのRNN系のモデルは状態を保持するstatefulなモデルと状態を保持しないstatelessなモデルがあります.その切り替え方法をメモしておきます.

切り替え方法

あらかじめ同じネットワークのstatelessなモデルとstatefulなモデルを別々に作成する.

from keras.models import Sequential
from keras.layers import Dense, LSTM


stateless_model = Sequential()
stateless_model.add(LSTM(hidden_unit, input_shape=(seq_len, 1)))
stateless_model.add(Dense(1))

stateful_model = Sequential()
stateful_model.add(LSTM(hidden_unit, batch_input_shape=(1, 1, 1), stateful=True))
stateful_model.add(Dense(1))

以下のようにネットワークが同じなので重みの共有が可能となる.これによって切り替えることができる.

# stateless to stateful
stateful_model.load_weights(stateless_model.get_weights())

# hdf5に保存して重みの共有も可能
stateful_model.load_weights('model.hdf5')

Example

statelessなモデルでトレーニングをしてstatefulなモデルでテストをするコードを以下に示します.再帰的に予測したいときはstatefulなモデルの方が便利です.

入力は以下のような三角波です.

f:id:suzuzusu:20191213011127p:plain

import numpy as np
import random
from keras.models import Sequential
from keras.layers import Dense, LSTM
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt

def hankel_matrix(x,seq_len):
    n = x.shape[0]
    stride = x.strides[0]
    return np.lib.stride_tricks.as_strided(x, shape=(n-seq_len+1, seq_len), strides=(stride,stride)).copy()

def triangle_wave(t, m=10000):
    arange = np.arange(1, m+1)
    return 8/(np.pi**2)*np.sum(np.sin(arange*np.pi/2)*np.sin(arange*t)/(arange**2))

seed = 0
hidden_unit = 100
random.seed(seed)
np.random.seed(seed)

# dataset
N = 500
x_n = []
ts = np.linspace(0, 500, N)
for t in ts:
    tmp = triangle_wave(t)
    x_n.append(tmp)
x_n = np.asarray(x_n)
seq_len = 100
H = hankel_matrix(x_n, seq_len=seq_len)
X = H[:-1,:].reshape(-1, seq_len, 1)
Y = H[1:,-1]

# plot data
plt.plot(x_n[:50])
plt.savefig('data.png')
plt.show()

cp_cb = ModelCheckpoint(filepath = 'model.hdf5', monitor='val_loss', verbose=1, save_best_only=True, mode='auto')

# stateless
train_model = Sequential()
train_model.add(LSTM(hidden_unit, input_shape=(seq_len, 1)))
train_model.add(Dense(1))
train_model.compile(loss='mean_squared_error', optimizer='adam')
train_model.fit(X, Y, epochs=20, callbacks = [cp_cb], validation_split=0.3, shuffle=True)

# stateful
predict_model = Sequential()
predict_model.add(LSTM(hidden_unit, batch_input_shape=(1, 1, 1), stateful=True))
predict_model.add(Dense(1))
predict_model.load_weights('model.hdf5')


input_len = 30
x_n = []
xs = []
ts = np.linspace(500, 1000, N)
for t in ts:
    tmp = triangle_wave(t)
    x_n.append(tmp)
    xs.append(tmp)
xs = xs[:input_len]

# test input
for o in xs:
    p = predict_model.predict(np.asarray([o]).reshape(1, 1, 1))[0,0]

# predict
predict_len = 100
xs.append(p)
for i in range(predict_len-1):
    p = predict_model.predict(np.asarray([p]).reshape(1, 1, 1))[0,0]
    xs.append(p)

# plot predict
plt.plot(xs[:input_len], label='input value')
plt.plot(np.arange(predict_len) + input_len, x_n[input_len:][:predict_len], label='true value')
plt.plot(np.arange(predict_len) + input_len, xs[input_len:], label='lstm predict value')
plt.xlabel('time')
plt.legend(loc='upper right')
plt.savefig('fig.png')
plt.show()

f:id:suzuzusu:20191213011913p:plain

Colab

gist.github.com

参考