2017-06-08 5 views
1

私はチュートリアルでTensorboardに埋め込みを視覚化する方法を学んでいます。コード全体はhereです。彼は畳み込みニューラルネットワークを使ってモデルを訓練する。ここで埋め込みを構築するコアの一部である:embedding_inputネットワークの完全に接続された層であるDNNCLassifierを使用するときに埋め込みを視覚化する方法 - Tensorflow

embedding = tf.Variable(tf.zeros([1024, embedding_size]), name="test_embedding") 
assignment = embedding.assign(embedding_input) 

config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig() 
embedding_config = config.embeddings.add() 
embedding_config.tensor_name = embedding.name 
embedding_config.sprite.image_path = LOGDIR + 'sprite_1024.png' 
embedding_config.metadata_path = LOGDIR + 'labels_1024.tsv' 
# Specify the width and height of a single thumbnail. 
embedding_config.sprite.single_image_dim.extend([28, 28]) 
tf.contrib.tensorboard.plugins.projector.visualize_embeddings(writer, config) 

私はDNNCLassifierオブジェクトを使用して、サイバネティック攻撃を分類しています。これは私のコードです。

def input_fn(data_set): 
    continuous_cols = {k: tf.constant(data_set[k].values) 
        for k in CONTINUOUS_COLUMNS} 
    categorical_cols = {k: tf.SparseTensor(
        indices=[[i, 0] for i in range(data_set[k].size)], 
        values=data_set[k].values, 
        dense_shape=[data_set[k].size, 1]) 
         for k in CATEGORICAL_COLUMNS} 
    feature_cols = dict(continuous_cols.items() + categorical_cols.items()) 
    label = tf.constant(data_set[LABEL_COLUMN].values) 

    return feature_cols, label 


def train_input_fn(): 
    return input_fn(training_set) 

def eval_input_fn(): 
    return input_fn(test_set) 

classifier = tf.contrib.learn.DNNClassifier(model_dir=model_dir, 
     feature_columns=deep_columns, 
     n_classes=5, 
     hidden_units=[12, 4], 
     label_keys=label_keys_values, 
     optimizer=tf.train.ProximalAdagradOptimizer(
      learning_rate=0.1, 
      l1_regularization_strength=0.001 
     )) 

classifier.fit(input_fn=train_input_fn, steps=200) 
results = classifier.evaluate(input_fn=eval_input_fn, steps=1) 

embedding_inputとしてそれを使用するDNNClassifierからニューラルネットワークの出力層を取得する方法はありますか?私はTensorboardに埋め込みとして攻撃の種類を視覚化したい。

答えて

1

一つの方法は、あなたがあなたのDNNClassifier

を初期化するときには、モニター用のparamsを渡すことができtf.contrib.learn.monitors.ValidationMonitor

を使用しているhttps://www.tensorflow.org/get_started/monitors

+0

ありがとう参照してください。しかし、私のケースではどのようなメトリックを使うのか分かりません(テンソルボードの埋め込みセクションで、nnの出力に基づいてクラスタ化された攻撃を可視化する) – Andrei

関連する問題