2016-08-01 4 views
3

パフォーマンスはかなり良いニューラルネットワークを構築しました。非Python環境でモデルをレプリケートしたいと思います。Python以外のレプリケーションでTensorFlowのウェイトとバイアスを保存またはエクスポートします

sess = tf.InteractiveSession() 
x = tf.placeholder(tf.float32, shape=[None, 23]) 
y_ = tf.placeholder(tf.float32, shape=[None, 2]) 
W = tf.Variable(tf.zeros([23,2])) 
b = tf.Variable(tf.zeros([2])) 
sess.run(tf.initialize_all_variables()) 
y = tf.nn.softmax(tf.matmul(x,W) + b) 

重みと偏りの解読可能な.csvまたは.txtを取得するにはどうすればよいですか?

EDIT:以下は私の完全なスクリプトです:

import csv 
import numpy 
import tensorflow as tf 

data = list(csv.reader(open("/Users/sjayaram/developer/TestApp/out/production/TestApp/data.csv"))) 
[[float(j) for j in i] for i in data] 
numpy.random.shuffle(data) 
results=data 

#delete results from data 
data = numpy.delete(data, [23, 24], 1) 
#delete data from results 
results = numpy.delete(results, range(23), 1) 

sess = tf.InteractiveSession() 
x = tf.placeholder(tf.float32, shape=[None, 23]) 
y_ = tf.placeholder(tf.float32, shape=[None, 2]) 
W = tf.Variable(tf.zeros([23,2])) 
b = tf.Variable(tf.zeros([2])) 
sess.run(tf.initialize_all_variables()) 
y = tf.nn.softmax(tf.matmul(x,W) + b) 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 

#train the model, saving 80 entries for testing 
#batch-size: 40 
for i in range(0, 3680, 40): 
    train_step.run(feed_dict={x: data[i:i+40], y_: results[i:i+40]}) 

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
print(accuracy.eval(feed_dict={x: data[3680:], y_: results[3680:]})) 

答えて

3

あなたがnumpyの配列として変数を取得し、テキストやCSVなどの内容を書き出すためにnumpy.savetxt()を使用することができます。

import numpy as np 

W_val, b_val = sess.run([W, b]) 

np.savetxt("W.csv", W_val, delimiter=",") 
np.savetxt("b.csv", b_val, delimiter=",") 

に留意されたいです。これは、distributed runtimeのTensorFlowのネイティブレプリケーションメカニズムを使用するほどパフォーマンスを向上させることはまずありません。

+0

生成された.csvファイルには数値データが含まれていません。数値ではなく "nan"だけです。 –

+0

numpy配列を単純に 'print'するとどうなりますか? (最も可能性の高い説明は、勾配の爆発などのためにあなたの重みがNaNになったということですが、あなたの小さなサンプルプログラムでは起こりません)。 – mrry

+0

"nan"の配列と同じものが得られます。 –

関連する問題