2017-08-31 40 views
0

これは、訓練されたモデルを保存し復元する方法の例です。 これは初心者に役立つことを願っています。tensorflow:訓練されたモデルを保存/復元する方法

relu活性化機能を備えた1つの隠れ層ニューラルネットワークを生成する。 (聞いたreluはSigmoidよりもはるかに良く、特に隠れたレイヤーが多いニューラルネットワークの方が優れていることが証明されています)

トレーニングデータは明らかにXORです。

電車や "tf_train_save.py"

import tensorflow as tf 
import numpy as np 

x = np.matrix([[0, 0], [0, 1], [1, 0], [1, 1]]) 
y = np.matrix([[0], [1], [1], [0]]) 

n_batch = x.shape[0] 
n_input = x.shape[1] 
n_hidden = 5 
n_classes = y.shape[1] 

X = tf.placeholder(tf.float32, [None, n_input], name="X") 
Y = tf.placeholder(tf.float32, [None, n_classes], name="Y") 

w_h = tf.Variable(tf.random_normal([n_input, n_hidden], stddev=0.01), tf.float32, name="w_h") 
w_o = tf.Variable(tf.random_normal([n_hidden, n_classes], stddev=0.01), tf.float32, name="w_o") 

l_h = tf.nn.relu(tf.matmul(X, w_h)) 
hypo = tf.nn.relu(tf.matmul(l_h, w_o), name="output") 

cost = tf.reduce_mean(tf.square(Y-hypo)) 
train = tf.train.GradientDescentOptimizer(0.1).minimize(cost) 

init = tf.global_variables_initializer() 

with tf.Session() as sess: 
    sess.run(init) 

    for epoch in range(1000): 
     for i in range(4): 
      sess.run(train, feed_dict = {X:x[i,:], Y:y[i,:]}) 

    result = sess.run([hypo, tf.floor(hypo+0.5)], feed_dict={X:x}) 

    print(*result[0]) 
    print(*result[1]) 

    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"]) 
    tf.train.write_graph(output_graph_def, "./logs/mp_logs", "test.pb", False) 

ロード "tf_load.py"

import tensorflow as tf 
from tensorflow.python.platform import gfile 
import numpy as np 

x = np.matrix([[0, 0], [0, 1], [1, 0], [1, 1]]) 
y = np.matrix([[0], [1], [1], [0]]) 

with gfile.FastGFile("./logs/mp_logs/test.pb",'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
    tf.import_graph_def(graph_def, name='') 

with tf.Session() as sess: 
    X = sess.graph.get_tensor_by_name("X:0") 
    print(X) 
    output = sess.graph.get_tensor_by_name("output:0") 
    print(output) 

    tf.global_variables_initializer().run() 

    result = sess.run([output, tf.floor(output+0.5)], feed_dict={X:x}) 

    print(*result[0]) 
    print(*result[1]) 

が簡単な方法があるだろう救いますか?

+0

あなたの質問のタイトルはあなたが求めているものと一致していないようです。タイトルの質問を想定して、あなたのコーディングはあなたが期待していることをしていますか?私はロードスクリプトの初期化について不思議です。 –

+0

あなたが読み込んだウェイト変数netherを保存するので、コードが正しくありません。これを見てくださいhttps://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model?rq=1 –

+0

@ EricPlatonそれは動作します。もっと簡単な方法があれば私はちょうど興味があった。テンソル名も保存しているようです。 –

答えて

0

あなたはconvert_variables_to_constantsを使用しています。あなたは本当にトレーニング面で良いです。通行人のための注意、そのAPIはv1.0に登場しました(もし私が少しのAPIを追跡したと誤解されていなければ)。

ロード側では、最小のコードは1つのコマンドが短いと思います。すべての変数を定数に変換した場合、復元時に初期化する変数はありません。したがって、行:

tf.global_variables_initializer().run() 

何もしません。しかし、v1.3のdocsから:

ただし、var_listが空の場合、関数は実行可能なOpを返します。 Opは効果がありません。

負荷スクリプトにはグローバルvaribaleを有していない、とtf.global_variables_initializer()tf.variables_initializer(tf.global_variables())と同等であるため、操作は何もしません。

+1

私はテンソルの名前を 'input'や 'output'のように扱わずにリストアすることを期待していました。例が見つかりませんでした。 VGGishのソースコードを読んでいる間に可能だと思いました。しかし、私はそれを誤解しました。グラフを定義する関数を作成し、関数の生成と復元の両面で使用しました。 私は同じことをしなければならず、グラフファイルとpyファイルを一緒に扱わなければなりません –

関連する問題