2017-05-16 8 views
0

私は、トレーニングデータセットとしてCASIA(顔認識データセット)を使用して、TensorFlowスリムモデルvggを使用して分類するモデルを訓練しました。 LFWデータセットを使用してモデルをテストしたい、それは顔マッチングタスクです。 softmaxレイヤではなく、fc7/fc8のようなネット機能を抽出し、フィーチャ間の距離を比較して、それらが同じ人であるかどうかを判断する必要があります。 スリムモデルの機能を抽出するにはどうすればよいですか?フォワードランニング時にテンソルフロースリムモデルVGGからフィーチャーを抽出する方法は?

ここでは、トレーニングコードの一部です。

import tensorflow as tf 
from tensorflow.contrib.slim.python.slim.nets import vgg 
slim = tf.contrib.slim 
FLAGS = tf.app.flags.FLAGS 

def tower_loss(scope): 
    images, labels = read_and_decode() 
    with slim.arg_scope(vgg.vgg_arg_scope()): 
     logits, end_points = vgg.vgg_16(images, num_classes=FLAGS.num_classes) 
    _ = cal_loss(logits, labels) 
    losses = tf.get_collection('losses', scope) 
    total_loss = tf.add_n(losses, name='total_loss') 
    return total_loss 

答えて

0

あなたが抽出したい特定の機能のtf.get_default_graph().get_tensor_by_name("VGG16/fc16:0")または任意のテンソル名を使用して試すことができます。あなたが抽出されているテンソルの名前を確認するには

、あなたは彼らがあなたが取得している項目はテンソルであることを示すのような名前の末尾に:0を置くことを忘れないでください

を試すことができます。

0

スリムモデルのend_pointsを取得して機能を抽出します。

関連する問題