2017-12-17 5 views
0

私は、和積ネットワークを実装し、このコードを実行しているtensorflow.map_fnによって発生したエラー:TensorArray dtypeはdoubleですが、Opはdtype floatを書き込もうとしています。

#SPN architecture 
D = 10 
K = 100 
N = 3 

def lang_model(sample): 
    #MODEL 

    #1st layer 
    prev_ivs = tf.one_hot(sample, K) 

    #2nd layer 
    S1 = tf.matmul(W1,prev_ivs, transpose_b=True) 

    #3rd layer 
    S2 = tf.transpose(tf.matmul(W2, tf.reshape(S1, [1, N*D]), transpose_b=True)) 

    #4th layer 
    P1 = tf.multiply(S2, S2) 

    #5th layer 
    S3 = tf.add(tf.multiply(S2, W3_1), tf.multiply(P1, W3_2)) 

    #6th layer 
    #output layer 
    prob_labels = tf.multiply(W4,S3) 
    norm_const = tf.matmul(W4,S3, transpose_b=True) 
    return (prob_labels, norm_const) 


#input 
inputs = tf.placeholder(tf.int64, shape=[None,N]) 

labels = tf.placeholder(tf.int64,shape=[None]) 

# sums weights 
W1 = tf.get_variable('W1', initializer=tf.random_normal([D, K], stddev=0.03)) 
W2 = tf.get_variable('W2', initializer=tf.random_normal([K, D * N], stddev=0.03)) 
W3_1 = tf.get_variable('W3_1', initializer=tf.random_normal([K], stddev=0.03)) 
W3_2 = tf.get_variable('W3_2', initializer=tf.random_normal([K], stddev=0.03)) 
W4 = tf.get_variable('W4', initializer=tf.random_normal([1,K], stddev=0.03)) 
tf.get_variable_scope().reuse_variables() 

batch_results = tf.map_fn(lang_model, inputs, dtype=(tf.float64, tf.float64)) 
init = tf.global_variables_initializer() 
with tf.Session() as sess: 
    sess.run(init) 
    input = np.array([[34, 59, 6],[5,95,13]]) 
    label = [5, 76] 
    feed_dict ={inputs: input, labels: label} 
    test = sess.run(batch_results, feed_dict=feed_dict) 
    print(str(test)) 

と、次のエラーが立ち上がっています:

InvalidArgumentError(see above for traceback): TensorArray dtype is double but Op is trying to write dtype float. 
[[Node: map/while/TensorArrayWrite_1/TensorArrayWriteV3 = TensorArrayWriteV3[T=DT_FLOAT, _class=["loc:@map/while/MatMul_3"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](map/while/TensorArrayWrite_1/TensorArrayWriteV3/Enter, map/while/Identity, map/while/MatMul_3, map/while/Switch_2:1)]] 
Process finished with exit code 1 

それは私のコードまたはそれの問題であれば、私は理解することはできませんtensorflowにおけるmap_fnの問題です。ダブルはtf.float64でなければなりませんか?だから私はlang_model内tf.int64するtf.float64し、それを唱えるためのプレースホルダを設定しようとしたが、ERROはまだ上昇

答えて

1

問題がlang_model戻りfloat32 sおよびmap_fnはそれがfloat64秒を返すと言うことです。問題を解決する一つの方法は、float64 sと関数の出力をキャストすることです:

prob_labels = tf.cast(tf.multiply(W4,S3), tf.float64) 
norm_const = tf.cast(tf.matmul(W4,S3, transpose_b=True), tf.float64) 
+0

[OK]を、今では動作します!ありがとう! – iXB3

関連する問題