2016-11-27 14 views
0

テンソルフローでニューラルネットワークを作成しました。このネットワークはマルチラベルです。 Ergo:1つの入力セット(この場合は3つ)の複数の出力ラベルを予測しようとします。現在、このコードを使用して、ネットワークが3つのラベルをどれくらい正確に予測しているかをテストします。テンソルフローネットワークのテスト:マルチラベル分類のためのin_top_k()置換

_, indices_1 = tf.nn.top_k(prediction, 3) 
_, indices_2 = tf.nn.top_k(item_data, 3) 
correct = tf.equal(indices_1, indices_2) 
accuracy = tf.reduce_mean(tf.cast(correct, 'float')) 
percentage = accuracy.eval({champion_data:input_data, item_data:output_data}) 

このコードは正常に動作します。問題は、indices_1で見つかった上位3項目がindices_2の上位5画像に含まれているかどうかをテストするコードを作成しようとしていることです。私はtensorflowがin_top_k()メソッドを持っていることを知っていますが、私が知っている限り、マルチラベルを受け付けません。現在私はforループを使ってそれらを比較しようとしています:

_, indices_1 = tf.nn.top_k(prediction, 5) 
_, indices_2 = tf.nn.top_k(item_data, 3) 
indices_1 = tf.unpack(tf.transpose(indices_1, (1, 0))) 
indices_2 = tf.unpack(tf.transpose(indices_2, (1, 0))) 
correct = [] 
for element in indices_1: 
    for element_2 in indices_2: 
     if element == element_2: 
      correct.append(True) 
     else: 
      correct.append(False) 
accuracy = tf.reduce_mean(tf.cast(correct, 'float')) 
percentage = accuracy.eval({champion_data:input_data, item_data:output_data}) 

しかし、それは動作しません。コードは実行されますが、精度は常に0.0です。

だから私は二つの質問の1持っている:

1)in_top_kのための簡単な交換が(あり)私は、カスタムコードを書くのではなく、使用することができるマルチラベル分類を受け入れていますか?

2)そうでない場合、私は何が間違っているのですか?0.0の精度が得られますか?

答えて

0

あなたがチェックしている

correct = tf.equal(indices_1, indices_2) 

を行うとだけでなく、これらの2つの指標は、同じ要素が含まれているが、かどうか、彼らは同じ位置に同じ要素が含まれているかどうか。これはあなたが望むようには聞こえません。

setdiff1d opは、どのインデックスがindices_1にあるのかを示しますが、indices_2にはありません。これを使用してエラーをカウントできます。

私は間違った結果を得る原因となっている可能性があります正確性のチェックで厳格すぎると思う。

+0

ありがとうございました!これは正しい方向への大きな一歩です。私のバージョンにはまだsetdiff1dがないので、テンソルフローを更新する必要がありました。エラーをどのように数えることができますか?私はいくつかのことを試しましたが、setdif1dの多くの違いを知る方法を理解できないようです。 –

関連する問題