私は、トレーニングデータセットとしてCASIA(顔認識データセット)を使用して、TensorFlowスリムモデルvggを使用して分類するモデルを訓練しました。 LFWデータセットを使用してモデルをテストしたい、それは顔マッチングタスクです。 softmaxレイヤではなく、fc7/fc8のようなネット機能を抽出し、フィーチャ間の距離を比較して、それらが同じ人であるかどうかを判断する必要があります。 スリムモデルの機能を抽出するにはどうすればよいですか?フォワードランニング時にテンソルフロースリムモデルVGGからフィーチャーを抽出する方法は?
ここでは、トレーニングコードの一部です。
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets import vgg
slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS
def tower_loss(scope):
images, labels = read_and_decode()
with slim.arg_scope(vgg.vgg_arg_scope()):
logits, end_points = vgg.vgg_16(images, num_classes=FLAGS.num_classes)
_ = cal_loss(logits, labels)
losses = tf.get_collection('losses', scope)
total_loss = tf.add_n(losses, name='total_loss')
return total_loss