2017-05-24 7 views
0

私はGANに取り組んでおり、HyperGANを使ってアルゴリズムを実装することに決めました。 TensorFlowを使ったDCGANのラッパーです。 HyperGANはTFのチェックポイントメソッドを使用して出力を保存します。テンソルフローでのモデルの出力値と入力値はどのようにして得られますか?

その後、私が使用して負荷にモデルを実行しようとした:

import tensorflow as tf 
sess=tf.Session()  
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 
sess.run(tf.global_variables_initializer()) 

しかし、そのGAN以来、それは入力潜在ベクトルを必要とし、画像を出力します。これは

out_image = sess.run(last_node, feed_dict(input_node: value)) 

を使用して行われます。しかし、私はモデルをロードするので、私は最後のノードの名前であり、入力ノードプレースホルダの名前が何であるかわかりません。どのようにして最初にグラフを作成するために使用された名前を取得するのですか?私はTensorBoardを使って視覚化しようとしましたが、グラフが大きくなってしまいました。

答えて

0

あなたは、グラフ内テンソルのリストを印刷しようとする必要があります:グラフの最初の50個のノードを参照するために

with tf.Graph().as_default() as graph: 
.... 

count = 0 
for op in graph.get_operations(): 
    print op.values() 
    count+=1 
    if count == 50: 
     assert False 

、あなたはこのようなものが表示されます:

(<tf.Tensor 'import/Placeholder_only:0' shape=<unknown> dtype=float32>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_max:0' shape=() dtype=float32>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_min:0' shape=() dtype=float32>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_const:0' shape=(512,) dtype=quint8>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53:0' shape=(512,) dtype=float32>,) 

通常、ターミナルはターミナルで最初の入力ノード名が消えてしまうほど多くのテンソルを出力するので、カウントを入れます。

最後に、単に使用することをカウントするための行をコメントアウト:

#count = 0 
for op in graph.get_operations(): 
    print op.values() 
    #count+=1 
    #if count == 50: 
    # assert False 

をプリントアウト最後のいくつかのノード(すなわち、あなたの出力ノード)を取得します。

関連する問題