0
私は、和積ネットワークを実装し、このコードを実行しているtensorflow.map_fnによって発生したエラー:TensorArray dtypeはdoubleですが、Opはdtype floatを書き込もうとしています。
#SPN architecture
D = 10
K = 100
N = 3
def lang_model(sample):
#MODEL
#1st layer
prev_ivs = tf.one_hot(sample, K)
#2nd layer
S1 = tf.matmul(W1,prev_ivs, transpose_b=True)
#3rd layer
S2 = tf.transpose(tf.matmul(W2, tf.reshape(S1, [1, N*D]), transpose_b=True))
#4th layer
P1 = tf.multiply(S2, S2)
#5th layer
S3 = tf.add(tf.multiply(S2, W3_1), tf.multiply(P1, W3_2))
#6th layer
#output layer
prob_labels = tf.multiply(W4,S3)
norm_const = tf.matmul(W4,S3, transpose_b=True)
return (prob_labels, norm_const)
#input
inputs = tf.placeholder(tf.int64, shape=[None,N])
labels = tf.placeholder(tf.int64,shape=[None])
# sums weights
W1 = tf.get_variable('W1', initializer=tf.random_normal([D, K], stddev=0.03))
W2 = tf.get_variable('W2', initializer=tf.random_normal([K, D * N], stddev=0.03))
W3_1 = tf.get_variable('W3_1', initializer=tf.random_normal([K], stddev=0.03))
W3_2 = tf.get_variable('W3_2', initializer=tf.random_normal([K], stddev=0.03))
W4 = tf.get_variable('W4', initializer=tf.random_normal([1,K], stddev=0.03))
tf.get_variable_scope().reuse_variables()
batch_results = tf.map_fn(lang_model, inputs, dtype=(tf.float64, tf.float64))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
input = np.array([[34, 59, 6],[5,95,13]])
label = [5, 76]
feed_dict ={inputs: input, labels: label}
test = sess.run(batch_results, feed_dict=feed_dict)
print(str(test))
と、次のエラーが立ち上がっています:
InvalidArgumentError(see above for traceback): TensorArray dtype is double but Op is trying to write dtype float.
[[Node: map/while/TensorArrayWrite_1/TensorArrayWriteV3 = TensorArrayWriteV3[T=DT_FLOAT, _class=["loc:@map/while/MatMul_3"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](map/while/TensorArrayWrite_1/TensorArrayWriteV3/Enter, map/while/Identity, map/while/MatMul_3, map/while/Switch_2:1)]]
Process finished with exit code 1
それは私のコードまたはそれの問題であれば、私は理解することはできませんtensorflowにおけるmap_fnの問題です。ダブルはtf.float64でなければなりませんか?だから私はlang_model内tf.int64するtf.float64し、それを唱えるためのプレースホルダを設定しようとしたが、ERROはまだ上昇
[OK]を、今では動作します!ありがとう! – iXB3