2017-06-16 6 views
3

いくつかの転送学習アルゴリズムを開発しようと思って、私はいくつかの訓練されたニューラルネットワークを使用してレイヤーを追加します。私はTensorflowとPythonを使用しています。Tensorflowモデルで低レイヤーを追加

Tensorflowの既存のグラフを使用するのは非常に一般的です。たとえば、metaGraphsを使用してグラフをインポートした後、ノードを追加して新しい高レイヤーを設定します。

vgg_saver = tf.train.import_meta_graph(dir + '/vgg/results/vgg-16.meta') 
# Access the graph 
vgg_graph = tf.get_default_graph() 

# Retrieve VGG inputs 
self.x_plh = vgg_graph.get_tensor_by_name('input:0') 

# Choose some node 
output_conv =vgg_graph.get_tensor_by_name('conv1_2:0') 

# Build further operations 
output_conv_shape = output_conv.get_shape().as_list() 
W1 = tf.get_variable('W1', shape=[1, 1, output_conv_shape[3], 32],initializer=tf.random_normal_initializer(stddev=1e-1)) 
b1 = tf.get_variable('b1', shape=[32], initializer=tf.constant_initializer(0.1)) 
z1 = tf.nn.conv2d(output_conv, W1, strides=[1, 1, 1, 1], padding='SAME') + b1 
a = tf.nn.relu(z1) 

その後トレーニングに、あなたはあなたの層に加えて、以下のすべてのそれらを使用します。たとえば、私はこのコードhereが見つかりました。また、いくつかのレイヤーをフリーズし、セッション中に訓練された変数をインポートすることもできます。

しかし、私のアプローチでは、入力と最初のレイヤーの間に新しいレイヤーを追加し、したがって、グラフの最下部にノードを追加するだけでは、入力の直後にノードを挿入する必要があります。

今まで私はテンソルフローでそれを行う便利な方法が見つかっていません。何かお考えですか?それとも不可能なのでしょうか?

ありがとうございます。

+0

一方的な方法で最初の答えが進む可能性がありますhttps://stackoverflow.com/questions/33748552/tensorflow-how-to-replace-a-node-in-a-calculation-graph –

答えて

1

グラフの既存のレイヤー間にレイヤーを挿入することはできませんが、途中で再配線してグラフをインポートすることはできます。 Pietro Tortellaが指摘したように、Tensorflow: How to replace a node in a calculation graph?のアプローチが有効です。次に例を示します。

import tensorflow as tf 

with tf.Graph().as_default() as g1: 
    input1 = tf.placeholder(dtype=tf.float32, name="input_1") 
    l1 = tf.multiply(input1, tf.constant(2.0), name="mult_1") 
    l2 = tf.multiply(l1, tf.constant(3.0), name="mult_2") 

g1_def = g1.as_graph_def() 

with tf.Graph().as_default() as new_g: 
    new_input = tf.placeholder(dtype=tf.float32, name="new_input") 
    op_to_insert = tf.add(new_input, tf.constant(4.0), name="inserted_op") 
    mult_2, = tf.import_graph_def(g1_def, input_map={"input_1": op_to_insert}, 
            return_elements=["mult_2"]) 

元のグラフはthisのように見え、輸入グラフがthisのように見えます。

あなたがtf.train.import_meta_graphを使用したい場合、あなたはまだ

input_map={"input_1": op_to_insert} 

kwargに渡すことができます。それはimport_graph_defに渡されます。

+0

古い入力はありません接続され、実行されません。まだそれらを取り除きたい場合は、ここをクリックしてくださいhttps://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md – iga

+0

ありがとう!私はこれをすぐに試してみます。 –

関連する問題