2017-07-11 11 views
0

次のコードをtensorflow 1.2に変換しようとしていますが、お返事ありがとうございます。"tf.nn.dynamic_rnn"を実行すると、TypeErrorが発生します。予想されるint32、代わりにタイプ '_Message'のテンソルが含まれています

out_gt, controller_final_state_gt = tf.nn.dynamic_rnn(
cell=cell_with_ground_truth, 
inputs=rnn_inputs_with_ground_truth, 
sequence_length=[SEQ_LEN]*BATCH_SIZE, 
initial_state=controller_initial_state_gt, 
dtype=tf.float32, 
swap_memory=True, 
time_major=False) 

すべての入力変数:

cell_with_ground_truth:<__main__.SamplingRNNCell object at 0x7f3f88383250> 
rnn_inputs_with_ground_truth:(<tf.Tensor 'dropout/mul:0' shape=(4, 10, 128) dtype=float32>, <tf.Tensor 'div:0' shape=(4, 10, 3) dtype=float32>) 
[SEQ_LEN]:[10] 
BATCH_SIZE:4 
controller_initial_state_gt:(<tf.Tensor 'Identity_3:0' shape=(4, 3) dtype=float32>, LSTMStateTuple(c=<tf.Tensor 'Identity_4:0' shape=(4, 32) dtype=float32>, h=<tf.Tensor 'Identity_5:0' shape=(4, 32) dtype=float32>)) 

エラーMessgae:

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-9-1911ca77ee35> in <module>() 
    53   out_gt, controller_final_state_gt = tf.nn.dynamic_rnn(cell=cell_with_ground_truth, inputs=rnn_inputs_with_ground_truth, 
    54       sequence_length=[SEQ_LEN]*BATCH_SIZE, initial_state=controller_initial_state_gt, dtype=tf.float32, 
---> 55       swap_memory=True, time_major=False) 
    56 
    57  with tf.variable_scope("predictor", reuse=True): 

/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.pyc in dynamic_rnn(cell, inputs, sequence_length, initial_state, dtype, parallel_iterations, swap_memory, time_major, scope) 
    572   swap_memory=swap_memory, 
    573   sequence_length=sequence_length, 
--> 574   dtype=dtype) 
    575 
    576  # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 

。 。

/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_util.pyc in _AssertCompatible(values, dtype) 
    300  else: 
    301  raise TypeError("Expected %s, got %s of type '%s' instead." % 
--> 302      (dtype.name, repr(mismatch), type(mismatch).__name__)) 
    303 
    304 

TypeError: Expected int32, got list containing Tensors of type '_Message' instead. 

答えて

0

この問題は解決しました。 TF 1.4 コンテキスト= tf.concat(1、[prev_output、visual_feats])、

:cell_with_ground_truthのオブジェクトが1つのSamplingRNNCell関数ここで

から来ているので

これは、私は1本のラインを見つけましたこのsolution

と、

同じ(1、[prev_output、visual_feats])= tf.concat 文脈:、それは次のように修正されるべきです

関連する問題