2017-05-30 18 views
0

私はマルチクラスの分類でうまく動作する深いCNNを持っています。チャレンジを「アップグレード」して、マルチラベル分類の問題を練習したいと思います。複数ラベルの分類:閾値の学習方法は?

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred) 

しかし、私は奇妙な予測で終わる:

Prediction for Im1 : [ 0.59275776 0.08751075 0.37567005 0.1636796 0.42361438 0.08701646 

0.38991812 0.54468459 0.34593087 0.82790571]

私はシグモイドで私のソフトマックスを交換し、最小限に抑えるために自分のネットワークを訓練しようとしたそうするには

Prediction for Im1 : [ 0.52609032 0.07885984 0.45780018 0.04995904 0.32828355 0.07349177 

0.35400775 0.36479294 0.30002621 0.84438241

Prediction for Im1 : [ 0.58714485 0.03258472 0.3349618 0.03199361 0.54665488 0.02271551 

0.43719986 0.54638696 0.20344526 0.88144571]

だから、私は、試料がクラスをしteに属しているかどうかを決定するために私のネットワークは、各クラスのしきい値を学ぶようにしようと思いました。

だから私は私のコードにこれを追加しました:

initial = tf.truncated_normal([numberOfClasses], stddev=0.1) 
W_thresh = tf.Variable(initial) 

y_predict_thresh = int(y_predict > W_thresh) 

しかし、私はエラーがあります?

TypeError: int() argument must be a string or a number, not 'Tensor'. 

誰もが、このエラーを回避する方法(前進私を助けるためにどんな考えを持ってい私のデータセットが本当に不均衡であるという事実は、これらの「一定の」予測を引き起こすのでしょうか?マルチラベル分類のための他の提案?あなたはまだそれを必要としていますが、変換関数tf.to_int32, tf.to_int64をtensorflow使用できるかどうか分からない/

答えて

2

:私はちょうどしきい値を行うことはバックプロパゲーションのために本当にクールではないかもしれないことに気づい:

はあなたに

EDITありがとうございます。評価する前に、式はPythonのオブジェクトであるため、単にそれをint()にキャストすることはできません。

これは、何が必要ありません:

with tf.Session() as sess: 
    check = sess.run([tf.to_int64(W1 > W2)]) 
+0

はあなたansewerあちこちありがとうございます。私は実際に私の問題に対する解決策は閾値ではないが、不均衡なデータセットがあるという事実に取り組んでいることに気づいた。しかし、あなたの答えは私の "閾値"の問題では問題ありません。もう一度ありがとう:) –

+0

あなたは大歓迎です! ;-) – gionni

関連する問題