:問題は、私は、次のような3クラス分類問題の損失関数を定義しようとしてい
def func_loss(y_true, y_pred):
return -K.mean(K.prod(K.cast(K.argmax(y_pred, axis=1), K.floatx()) - 1.0, K.cast(K.argmax(y_true, axis=1), K.floatx()) - 1.0))
私y
は、次のようになります。 [[1,0,0], [0,1,0], [1,0,0], [0,0,1], ...]
私の3クラスラベルは直感的にクラス "-1"、 "0"、 "+1"のワンホットエンコーディングです。私は "+/- 1"の正しいラベリングを最大限にし、 "+/- 1"の間違ったラベリングを最小限に抑え、すべての "0"ラベルを無視する。
私は、この損失関数でモデルをコンパイルすると、これは私が得るものです:
Traceback (most recent call last):
File "", line 1, in
File "/usr/local/lib/python2.7/dist-packages/keras/models.py", line 547, in compile
**kwargs)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 622, in compile
sample_weight, mask)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 324, in weighted
score_array = fn(y_true, y_pred)
File "", line 2, in func_loss
File "/usr/local/lib/python2.7/dist-packages/keras/backend/tensorflow_backend.py", line 464, in prod
axis = _normalize_axis(axis, ndim(x))
File "/usr/local/lib/python2.7/dist-packages/keras/backend/tensorflow_backend.py", line 435, in _normalize_axis
if axis is not None and axis < 0:
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 547, in nonzero
raise TypeError("Using a tf.Tensor as a Python bool is not allowed. "
TypeError: Using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
私は、この損失関数と周りに少しの微調整をしようとしてきたが、私は、コンパイル時にそれらのすべては、いくつかのエラーになりますモデル。私はこのことがどのように機能するかについて基本的な誤解があるかもしれないと思う。助けてもらえますか?
編集: 新しい損失関数:
def func_loss(y_true, y_pred):
return -K.mean((K.cast(K.argmax(y_pred, axis=1), K.floatx()) - 1.0)* (K.cast(K.argmax(y_true, axis=1), K.floatx()) - 1.0))
ありがとうございます。私は 'K.argmax(y_true、axis = -1) - 1'と' K.argmax(y_pred、axis = -1) - 1'の間で要素的な乗算を行うことを意味しました。あなたは私の新しい損失機能を見直すことができますか?私はこの損失関数でうまくコンパイルすることができますが、それは私に 'ValueError:None values supported 'を与えます。 –
@TengyuLiuあなたはこれまでの解決策を見つけましたか?私はコスト関数で 'K.argmax'を使うときに同様のエラーが出ます。私は' K.argmax'を使って問題が分かったと思います。 – adalca
@adalca私は正確な解決策を思い出すことができませんが、当時の私の間違いの1つは、私の損失機能が分化できないということでした。エラーが構文エラーから分離されていることが確かであれば、これは見るべき方向です。 –