Tensorflowセッションを通じてKerasモデルを使用しようとしています。しかし、結果はmodel.predict
とsess.run
と異なります。 Tensorflowセッションを通じてKersモデルと連携する方法はありますか?tf.session.runのネットワーク出力は、keras.Model.predictで取得したネットワーク出力と大きく異なります。
Tensorflowバージョン:1.4.0
Kerasバージョン:2.1.1
from sklearn.datasets.samples_generator import make_circles
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
import numpy as np
import tensorflow as tf
from keras import backend as K
sess = tf.Session()
K.tensorflow_backend.set_session(sess)
X, y = make_circles(n_samples=1000,
noise=0.1,
factor=0.2,
random_state=0)
model = Sequential()
model.add(Dense(4, input_shape=(2,), activation='tanh'))
model.add(Dense(1, activation='sigmoid'))
model.compile(SGD(lr=0.5), 'binary_crossentropy', metrics=['accuracy'])
model.fit(X, y, epochs=20)
print("Keras model. First prediction: " + str(model.predict(np.c_[0, 0])))
print("Keras model. Second prediction: " + str(model.predict(np.c_[1.5, 1.5])))
with sess.as_default():
y_tensor = model.outputs[0]
x_tensor = model.inputs[0]
sess.run(tf.global_variables_initializer())
print("TF model. First prediction: " + str(sess.run(y_tensor, feed_dict={x_tensor: np.c_[0, 0]})))
print("TF model. Second prediction: " + str(sess.run(y_tensor, feed_dict={x_tensor: np.c_[1.5, 1.5]})))
https://www.tensorflow.org/api_docs/python/tf/keras – scarecrow