0
import_graph_def()
でRNNネットワークを読み込み、推論してみます。 しかし、変数を取得するのにtf.trainable_variables()
を使用することはできません。次のコードでTensorflow、モデルエラーを読み込むためにimport_graph_def()を使用
、またtf.trainable_variables()
戻り、[](何もリスト) 、私はsaver = tf.train.Saver()
を使用し、tensorflowレポート
def eval_on_test(graph_path):
batch_size = 80
train_begin = 0
train_end = 3000
with tf.Graph().as_default() as graph:
with open(graph_path, 'rb') as f:
tf_graph = tf.GraphDef()
print("Loading graph_def from {}".format(graph_path))
tf_graph.ParseFromString(f.read())
return_elements = tf.import_graph_def(tf_graph, name="", return_elements=['input_x:0', 'output_y:0', 'pred:0', 'loss:0'])
X = return_elements[0]
Y = return_elements[1]
pred = return_elements[2]
loss = return_elements[3]
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
print("graph loaded, start testing")
with tf.Session(config=tf_config) as sess:
init_op = sess.graph.get_operation_by_name('init')
sess.run(init_op)
print(tf.trainable_variables())
batch_index,train_x,train_y=get_train_data(batch_size,time_step,train_begin,train_end)
for batch in range(len(batch_index)-1):
loss_ = sess.run(loss, feed_dict={X:train_x[batch_index[batch]:batch_index[batch+1]],Y:train_y[batch_index[batch]:batch_index[batch+1]]})
print(batch, loss_)
「を保存するためのノー変数」任意の助けいただければ幸いです。 Saver
は、この問題を解決するために、あなたはまた、すべてのコレクションを復元しますtf.train.import_meta_graph
を試すことができ、グラフ内の任意の変数を見つけることができない理由
ありがとうございました。それは私の質問を解決します。私はグラフを保存してグラフを復元しようとしていました。その問題を解決して、https://www.tensorflow.org/extend/tool_developers/ – zhangc