2017-04-10 8 views
2

私は訓練されたモデルの凍ったグラフを持っています。私はいつも同じ値を入力するtf.placeholderを持っています。プレースホルダを既存のグラフの定数で置き換えることはできますか?

代わりにtf.constantに置き換えることができるかどうか疑問に思っていました。 どういうわけか - 例があれば分かります!

編集:ここでは、私は推論を実行するために事前に訓練を受けた(他の人による)モデルを使用しています質問

を視覚化するために、コードでどのように見えるかです。モデルは、.pbの拡張子を持つ固定グラフファイルとしてローカルに格納されます。

コードは次のようになります。

# load graph 
graph = load_graph('frozen.pb') 
session = tf.Session(graph=graph) 

# Get input and output tensors 
images_placeholder = graph.get_tensor_by_name("input:0") 
output = graph.get_tensor_by_name("output:0") 
phase_train_placeholder = graph.get_tensor_by_name("phase_train:0") 

feed_dict = {images_placeholder: images, phase_train_placeholder: False} 

result = session.run(output, feed_dict=feed_dict) 

問題は、私はいつも私の目的のためにphase_train_placeholder: Falseを養うということですので、それはそのプレースホルダを排除し、tf.constant(False, dtype=bool, shape=[])

+1

コードのいくつかの構造を追加すると、情報に基づいた提案や助けが得られますか? – kmario23

+0

@ kmario23提案をいただきありがとうございます。質問を視覚化するためのコードをいくつか追加しました。 – CrowbarKZ

答えて

1
のようなものと交換することが可能です場合、私は思っていました

私は適切な方法を見つけることはできませんでしたが、グラフのdefを再構築し、代わりに必要なノードを置き換えることで、ハッキリなやり方でやりました。 thisコードに触発されています。ここで

は(スーパーハック、ご自身の責任で使用)コードです:

INPUT_GRAPH_DEF_FILE = 'path/to/file' 
OUTPUT_GRAPH_DEF_FILE = 'another/one' 

# Get NodeDef of a constant tensor we want to put in place of 
# the placeholder. 
# (There is probably a better way to do this) 
example_graph = tf.Graph() 
with tf.Session(graph=example_graph): 
    c = tf.constant(False, dtype=bool, shape=[], name='phase_train') 
    for node in example_graph.as_graph_def().node: 
     if node.name == 'phase_train': 
      c_def = node 

# load our graph 
graph = load_graph(INPUT_GRAPH_DEF_FILE) 
graph_def = graph.as_graph_def() 

# Create new graph, and rebuild it from original one 
# replacing phase train node def with constant 
new_graph_def = graph_pb2.GraphDef() 
for node in graph_def.node: 
    if node.name == 'phase_train': 
     new_graph_def.node.extend([c_def]) 
    else: 
     new_graph_def.node.extend([copy.deepcopy(node)]) 

# save new graph 
with tf.gfile.GFile(OUTPUT_GRAPH_DEF_FILE, "wb") as f: 
    f.write(new_graph_def.SerializeToString()) 
1

私は最近、上記の解答を書き換えなければなりませんでした。

import tensorflow as tf 
import sys 
from tensorflow.core.framework import graph_pb2 
import copy 


INPUT_GRAPH_DEF_FILE = sys.argv[1] 
OUTPUT_GRAPH_DEF_FILE = sys.argv[2] 

# load our graph 
def load_graph(filename): 
    graph_def = tf.GraphDef() 
    with tf.gfile.FastGFile(filename, 'rb') as f: 
     graph_def.ParseFromString(f.read()) 
    return graph_def 
graph_def = load_graph(INPUT_GRAPH_DEF_FILE) 

target_node_name = sys.argv[3] 
c = tf.constant(False, dtype=bool, shape=[], name=target_node_name) 

# Create new graph, and rebuild it from original one 
# replacing phase train node def with constant 
new_graph_def = graph_pb2.GraphDef() 
for node in graph_def.node: 
    if node.name == target_node_name: 
     new_graph_def.node.extend([c.op.node_def]) 
    else: 
     new_graph_def.node.extend([copy.deepcopy(node)]) 

# save new graph 
with tf.gfile.GFile(OUTPUT_GRAPH_DEF_FILE, "wb") as f: 
    f.write(new_graph_def.SerializeToString()) 
関連する問題