2016-03-23 10 views
4

TensorFlowに典型的なCNNモデルがあるとします。TensorFlow:異なる入力テンソルでネットワークを再実行しますか?

def inference(images): 
    # images: 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 
    conv_1 = conv_layer(images, 64, 7, 2) 
    pool_2 = pooling_layer(conv_1, 2, 2) 
    conv_3 = conv_layer(pool_2, 192, 3, 1) 
    pool_4 = pooling_layer(conv_3, 2, 2) 
    ... 
    conv_28 = conv_layer(conv_27, 1024, 3, 1) 
    fc_29 = fc_layer(conv_28, 512) 
    fc_30 = fc_layer(fc_29, 4096) 
    return fc_30 

典型的なフォワード・パスは次のように行うことができる:

images = input() 
logits = inference(images) 
output = sess.run([logits]) 

は今私のinput機能は今left_imagesright_images(ステレオカメラ)、引数のペアを返すとします。 right_imagesconv_28まで、left_imagesfc_30まで実行したいです。したがって、この

images = tf.placeholder(tf.float32, [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3]) 
left_images, right_images = input() 
conv_28, fc_30 = inference(images) 
right_images_val = sess.run([conv_28], feed_dict={images: right_images}) 
left_images_val = sess.run([fc_30], feed_dict={images: left_images}) 

のようなものしかし、これは私が当時TensorFlowにそれを養うためにinputsを評価することを避けるためにしたい

TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.

で失敗します。異なる引数でinferenceを2回呼び出すと、conv_layerなどの関数が変数を作成するため、機能しません。

異なる入力テンソルでネットワークを再実行することはできますか?

+0

"イメージ"の宣言方法がわかりません。それはtf.placeholderですか?もしそうでなければ、それはすべきです。 – RaduK

答えて

2

Tensorflow shared Variablesはあなたが探しているものです。推測ではのすべてのコールをtf.get_variable()に置き換えます。次に、実行することができます。

images_left, images_right = input() 
with tf.variable_scope("logits") as scope: 
    logits_left = inference(images_left) 
    scope.reuse_variables() 
    logits_right = inference(images_right) 
output = sess.run([logits_left, logits_right]) 

変数は、2回目の推論の呼び出しで再度作成されません。左右の画像は同じ重みを使用して処理されます。また私のTensorflow CNN training toolkitをチェックしてください(trainingコードを見てください)。このテクニックを使用して、同じTensorFlowグラフで検証とトレーニングの転送を実行します。

+1

FYIあなたの 'train.py'リンクが壊れています。 – dimo414

+0

thx!それを私が直した。 – MarvMind

+1

変更のために関連するコンテンツはもう表示されません。 Protip:Githubのリンクには常に[commit SHA](https://github.com/TensorVision/TensorVision/blob/25b2f634550d12987f247ac1d0c5db6824bb8c81/bin/tv-train)が含まれています – oseiskar

関連する問題