2017-06-19 11 views
2

Tensorflowの入力パイプラインとペアになっている(すなわち、テンソルを介してネットワークの入力入力を与える)、事前に作成されたInception-V3モデルをKerasから使用します。 これは私のコードです:テンソル入力を使用するとKerasモデルの予測が変化する

my_image.jpgは私が分類したい任意の画像である
import tensorflow as tf 
from keras.preprocessing.image import load_img, img_to_array 
from keras.applications.inception_v3 import InceptionV3, decode_predictions, preprocess_input 
import numpy as np 

img_sample_filename = 'my_image.jpg' 
img = img_to_array(load_img(img_sample_filename, target_size=(299,299))) 
img = preprocess_input(img) 
img_tensor = tf.constant(img[None,:]) 

# WITH KERAS: 
model = InceptionV3() 
pred = model.predict(img[None,:]) 
pred = decode_predictions(np.asarray(pred)) #<------ correct prediction! 
print(pred) 

# WITH TF: 
model = InceptionV3(input_tensor=img_tensor) 
init = tf.global_variables_initializer() 

with tf.Session() as sess: 
    from keras import backend as K 
    K.set_session(sess) 

    sess.run(init) 
    pred = sess.run([model.output], feed_dict={K.learning_phase(): 0}) 

pred = decode_predictions(np.asarray(pred)[0]) 
print(pred)        #<------ wrong prediction! 

kerasのpredict関数を使用して予測を計算すると、結果は正しいです。しかし、イメージ配列からテンソルを作成し、そのテンソルをinput_tensor=...経由でモデルに送り、sess.run([model.output], ...)で予測を計算すると、結果は非常に間違っています。

異なる動作の理由は何ですか?このようにKerasネットワークを使用することはできませんか?

答えて

1

最後にInceptionV3コードを検索したところ、sess.run(init)InceptionV3のコンストラクタにロードされた重みを上書きします。 この問題が発生した-dirty-fixは、sess.run(init)の後にウェイトをリロードしています。

from keras.applications.inception_v3 import get_file, WEIGHTS_PATH 

with tf.Session() as sess: 
    from keras import backend as K 
    K.set_session(sess) 

    sess.run(init) 
    weights_path = get_file(
       'inception_v3_weights_tf_dim_ordering_tf_kernels.h5', 
       WEIGHTS_PATH, 
       cache_subdir='models', 
       md5_hash='9a0d58056eeedaa3f26cb7ebd46da564') 
    model.load_weights(weights_path) 
    pred = sess.run([model.output], feed_dict={K.learning_phase(): 0}) 

get_file()のパラメータはInceptionV3のコンストラクタから直接採取し、私の例では、image_data_format='channels_last'で完全なネットワークの重みを復元に固有のものですされています。 私はthis Github issueでこれに対応するより良い回避策があるかと尋ねました。私はこの情報を更新する必要があります。

+0

すべての変数を初期化する代わりに、変数のサブセットを初期化することができます(モデルの事前訓練された重みを含む)。 – abhinavkulkarni

関連する問題