テンソルフロー(Caffeのもの)を使用し始めていますが、私は損失sparse_softmax_cross_entropy_with_logits
を使用しています。この関数は、onehotエンコーディングの代わりに0,1,...C-1
のようなラベルを受け入れます。さて、クラスラベルに応じて重み付けを使いたいと思います。 softmax_cross_entropy_with_logits
(1つのホットエンコーディング)を使用する場合、これはおそらく行列乗算を使って行うことができることを知っています。sparse_softmax_cross_entropy_with_logits
と同じことを行う方法はありますか?sparse_softmax_cross_entropy_with_logitsを使用してテンソルフローの加重クロスエントロピー損失を実装するにはどうすればよいですか?
6
A
答えて
1
クラスウェイトにはロジットが掛けられているため、sparse_softmax_cross_entropy_with_logitsにはまだ動作します。 「Tensorフローのクラス・アンバランス・バイナリ分類器の損失関数」については、this solutionを参照してください。サイドノートとして
、この方法は、重量を損失係数として作用
tf.nn.sparse_softmax_cross_entropy_with_logits.
を用いたクロスエントロピー損失のためである直接sparse_softmax_cross_entropy
tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None)
に重みを通過することができます。スカラが提供されている場合、損失は単純に与えられた値でスケーリングされます。重みがサイズ[batch_size]のテンソルである場合、損失重みは対応する各サンプルに適用されます。
2
特にバイナリ分類では、weighted_cross_entropy_with_logits
があり、重み付きソフトマックスクロスエントロピーが計算されます。
は、高効率非重み付け操作(SparseSoftmaxXentWithLogitsOp
を参照)のためにテールになっているため、「プラグ可能」ではありません。SparseXentEigenImpl
マルチクラスのケースでは、あなたのオプションはワンホットエンコーディングに切り替えるか、または既に提案されているようにハッピーな方法でtf.losses.sparse_softmax_cross_entropy
損失関数を使用するか、現在のバッチのラベルに応じてウェイトを渡す必要があります。
2
import tensorflow as tf
import numpy as np
np.random.seed(123)
sess = tf.InteractiveSession()
# let's say we have the logits and labels of a batch of size 6 with 5 classes
logits = tf.constant(np.random.randint(0, 10, 30).reshape(6, 5), dtype=tf.float32)
labels = tf.constant(np.random.randint(0, 5, 6), dtype=tf.int32)
# specify some class weightings
class_weights = tf.constant([0.3, 0.1, 0.2, 0.3, 0.1])
# specify the weights for each sample in the batch (without having to compute the onehot label matrix)
weights = tf.gather(class_weights, labels)
# compute the loss
tf.losses.sparse_softmax_cross_entropy(labels, logits, weights).eval()
関連する問題
- 1. テンソルフローはどのように実装されていますか?
- 2. テンソルフローによるペアワイズ損失関数はどのように実装できますか?
- 3. テンソルフローの損失を計算するにはどうすればよいですか?
- 4. テンソルフローでマルチクラスヒンジ損失を実装する方法
- 5. テンソルフローでケラスのようなサブサンプルを実装するにはどうすればよいですか?
- 6. pytorch nn.CrossEntropyLoss()でのクロスエントロピー損失
- 7. テンソルフローでtf.string_split()を使用するにはどうすればよいですか?
- 8. ケラスのクロスエントロピー損失のために体重不均衡クラスを行うには?
- 9. 体重初期化の状態をテンソルフローに追加するにはどうすればよいですか?
- 10. のtypecriptでは、コールバックを使用してインターフェイスを実装するにはどうすればよいですか?
- 11. プロトコルを使用して同時実行スレッドを実装するにはどうすればよいですか?
- 12. Acceptヘッダーを尊重するようにIHttpActionResultを実装するにはどうすればよいですか?
- 13. tf.nn.softmax_cross_entropy_with_logitsはテンソルフローのsoftmaxクロスエントロピーをどのように計算できますか?
- 14. 損失関数のkerasテンソルでnumpy関数を使用するにはどうすればよいですか?
- 15. デフォルトのユーザーモデルを使用してFireLoopでresetPasswordを実装するにはどうすればよいですか?
- 16. 損失の多い接続でcapistranoタスクを実行するにはどうすればよいですか?
- 17. SWFobjectを使用して実装されたプレーヤーをカスタマイズするにはどうすればよいですか?
- 18. テンソルフローでマトリックスマスク演算を実装するにはどうすればよいですか?
- 19. パッケージを追加してSDKをホストツールとして実装するにはどうすればよいですか?
- 20. デリゲートを使用してCosmosの星評価を実装するにはどうすればよいですか?
- 21. 複数のパディングを使用してgridlayoutを実装するにはどうすればよいですか?
- 22. PHPを使用してこのシナリオを実装するにはどうすればよいですか?
- 23. フューテックスを使用してこのメソッドを実装するにはどうすればよいですか?
- 24. 以下のロジックをPythonコードを使用して実装するにはどうすればよいですか?
- 25. Scalaでvariadicメソッドを使用してJavaインターフェイスを実装するにはどうすればよいですか?
- 26. キャッシュを使用してAndroidでオフラインストレージを実装するにはどうすればよいですか?
- 27. Node.JSを使用してAzure上でCOMETを実装するにはどうすればよいですか?
- 28. drupal7でtheme_item_list()を使用してマルチレベルリストを実装するにはどうすればよいですか?
- 29. Javaを使用してRabbitMQでヘッダ交換を実装するにはどうすればよいですか?
- 30. C++で循環キューを使用してキューを実装するにはどうすればよいですか?
1つのホットラベルを回避する方法があった場合、私は思っていた。提供されるリンクでは、重みベクトルと1つのホットラベルの行列を乗算する必要があるが、まだありますので。別の方法は、長さバッチサイズのウェイトベクトルを直接使用することですが、その後、すべてのバッチに対してこのベクトルを計算する必要があります。どのように私はそれを定義することができます(それはラベルに依存する)ので、onehotラベルの行列を計算する必要はありませんか? –
私はこの答えが正しいとは思わない。 'tf.contrib.losses.sparse_softmax_cross_entropy'の重みは、クラスごとではなく、サンプルごとです。 – andong777
それは正しいです、それはちょうど迷惑です。各更新の重みを渡すと、現在の更新に含まれる特定のクラスに依存します。だから、あなたがサイズ3のバッチを持っていて、クラスが1,1,2だったら。そして、クラス1を50%で加重したければ、この損失関数を使用し、重み引数に値[0.5,0.5,1.0]のテンソルを渡します。それはあなたのクラスを効果的に重くするだろう...エレガント?いいえ。 –