2017-08-04 4 views
0

これは一般的なエラーですが、この問題を理解できません。ここに私のコードです:Tensorflow:dtype floatでプレースホルダテンソル 'input_image'の値を入力する必要があります

あなたが見ることができるように、私はFC7特徴を抽出するVGG_16 pretrainedモデルを使用してい
def convert_image(url): 

    checkpoint_file = './vgg_16.ckpt' 

    input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image') 
    scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor) 
    scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5) 
    scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0) 

    #Load the model 
    sess = tf.Session() 
    arg_scope = vgg_arg_scope() 
    with slim.arg_scope(arg_scope): 
     logits, end_points = vgg_16(scaled_input_tensor, is_training=False) 
    saver = tf.train.Saver() 
    saver.restore(sess, checkpoint_file) 

    response = requests.get(url) 
    img = Image.open(BytesIO(response.content)) 
    im = np.array(img, dtype='float32') 
    im = im.reshape(-1,224,224,3) 

    features = sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im}) 
    sess.close() 
    return np.squeeze(features) 

。コードの約50%がURLからイメージを取得し、それを224x224x3に変換します。他の50%は実際に特徴表現を得るためにテンソルフローを行う。

問題は、このコードを初めて実行するときに問題なく動作します。しかし、2回目に、私は上記のエラーを取得します。もちろん、 "im"はfloat32です(このエラーが発生しても)。だから私はこの問題が2回目にこの機能を実行したときに何かを混乱させることと関係していると思います。私が推測しなければならないのは、「保護者」の働き方と関係がありますが、私は何を正確に把握することができませんでした。

アイデア?

答えて

1

エラーは、おそらく、VGGモデル内の入力プレースホルダを使用する代わりに、input_tensorを再定義したことが原因です。入力画像imをネットワークに送る直前に変換を適用することができます。

さらに、それぞれの画像ごとにモデルを読み込みます。 代わりに、モデルを一度読み込み、ループ内の画像のリストを繰り返し処理します。 このようなもの:

def convert_images(url_list): 
    # Load the TF model 
    #..... 
    # Session, etc. 

    # Now, go over the list of images one by one 
    for url in url_list: 
     image = ... # get image 
     features = session.run(...) # extract features 
+0

これはコード作成の方がはるかに優れていることに同意します。実際にコード全体をリファクタリングして、モデルを複数回読み込む必要はありません。 なぜ私のコードが書かれたとおりに間違っていたのか分かりました。 – anon

+0

実際の問題は、モデルですでに使用可能な入力プレースホルダを使用する代わりに、入力(input_tensor)を再定義することです。ですから、あなたのモデルの中の 'input'を見て、それに画像を与えるべきです。 – Blackberry

関連する問題