2016-04-05 7 views
0

seq2seq.rnn_decoder()を使用してTensorflowで順次RNNを設定しようとしています。 rnn_decoder()が必要とする入力はテンソルのリストなので、これを生成するためにrank-3テンソルを渡し、tf.unpack()を使ってリストにします。この問題は、私が渡したfloat32配列がtf.unpack()によってfloat64テンソルに変わり、残りのモデルと互換性がなくなったときに発生します。ここで私は)犯人(tf.unpackであることを私に納得させるために一緒に入れたコードは次のとおりです。Tensorflow unpack()メソッドのfloat64へのキャストを防ぐ方法

inputDat = loader.getSequential(BATCH_SIZE) 
print(inputDat.shape) 

出力(BATCH_SIZEは、系列長が10で、5である):

(10, 5, 3) 

その後、私はこれを読み込むことができますTensorflowセッションでデータ:

sess = tf.InteractiveSession() 
input_tensor = tf.constant(inputDat.astype('float32'), dtype=tf.float32) 
print "Input tensor type: " + str(type(input_tensor.eval()[0,0,0])) 
input_tensor = tf.unpack(inputDat) 
print "Input tensor shape: " + str(len(input_tensor)) + "x" + str(input_tensor[0].eval().shape) 
print "Input tensor type: " + str(type(input_tensor[0].eval()[0,0])) 

出力:

Input tensor type: <type 'numpy.float32'> 
Input tensor shape: 10x(5, 3) 
Input tensor type: <type 'numpy.float64'> 

ここには何が起こっていますか? FORループを使用して順次エントリのそれぞれを反復処理して再キャストすると、これを行う方法が間違っているように見えます。リストのすべてのメンバーをキャストするTensorflow内のメソッドを見つけることができません。

答えて

0

forループは必要ありません。tf.cast()を使用できます。

例:

input_tensor = tf.unpack(inputDat) # result is 64-bit 
input_tensor = tf.cast(input_tensor, tf.float32) # now it's 32-bit 
関連する問題