2016-11-25 11 views
3

Kerasディープネットをトレーニングするためのカスタム目的関数を作成したい。私は不均衡なデータの分類を研究しています。私はscikit-learnでF1スコアを多く使用しています。Kerasのカスタム目的にはTensorの評価が必要

(from sklearn.metric import f1_score) 

def F1Loss(y_true, y_pred): 

    return 1. - f1_score(y_true, y_pred) 

しかし、scikit-学ぶから、このf1_score方法はnumpyのを必要とします。トレーニングしながら、最小限に抑えるためにKerasのための損失関数/目的としてそれを使用するために - 私は、したがって、F1のメトリック(F1スコア1)を反転するというアイデアを持っていました配列またはリストを使用してF1スコアを計算します。私は、Tensorsが.eval()を使用して、それらの数が少ない配列の相手に評価される必要があることを発見しました。このタスクを実行するには、TensorFlowセッションが必要です。

Kerasが使用するセッションオブジェクトがわかりません。私はKerasバックエンドがどこかで定義された独自のセッションオブジェクトを持っていると仮定して、以下のコードを試してみましたが、これもうまくいきませんでした。

from keras import backend as K 

K.eval(y_true) 

確かに、これは瞬間Tensorflow私は本当にKerasの深い仕組みを理解していないので、暗闇の中でのショットでしたか。

私の質問は次のとおりです。y_truey_predテンソルを評価するにはどうすればよいですか?

答えて

2

あなたの問題はTheanoで不連続な目的を実行すると、古典的な問題です。これは2つの理由のは不可能だbeacuse:

  1. F1-スコアは不連続である:hereあなたは、ニューラルネットワークトレーニングの目的関数から予想されなければならないものを読むことができます。 F1-得点はこの条件を満たしていないので、ニューラルネットワークの訓練には使用できません。
  2. テンソルとナンシー配列の間に同等物はありません:これは基本的な問題です。テーラーテーラーは、学校方程式の中でxのようです。代数変数は、代入可能なオブジェクトと等価であるとは期待できません。一方、計算グラフの一部として、目的を計算するためにテンソル操作を提供する必要があります。そうでなければ、それを区別することはできません。ニューラルネットワークの通常の訓練方法のほとんどを不可能にするパラメータ。
+0

これは非常に役に立ちました!私はこれらの点を自分で考えていないだろう。ありがとうございました。私は別の解決策を探します。 – TBZ92

0

あなたは予測しているとnumpyの配列形式で実際のテンソルは、私はあなたがこのコードスニペットを使用できることを推測する場合:

model.fit_generator(train_generator, validation_data=val_generator, nb_val_samples=X_val.shape[0], 
         samples_per_epoch=X_train.shape[0], nb_epoch=nb_epoch, verbose=1, 
         callbacks=[model_checkpoint, reduce_lr, tb], max_q_size=1000) 
correct_prediction = tf.equal(tf.argmax(actual_tensor,1), tf.argmax(predicted_tensor,1)) 

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

そしてkerasで

、私はあなたがこれを使用することができると思います

train_generatorとval_generatorは、トレーニング中にトレーニングと検証データを生成します。これは、トレーニング中に損失と精度をプリントします。このことができます

希望...

+0

多分私はあなたの答えを誤解しているかもしれませんが、私はそれらをnumpyの配列形式にしていません。私は実際にテンソル 'y_true'と' y_pred'をnumpy配列に変換する方法を探しています。私はnumpy配列形式でそれらを欲しいので、私はF1関数にフィードすることができます。これは単にスカラー番号を返します。 – TBZ92

+0

.eval()メソッドから返されるオブジェクトはnumpy配列で、http:// stackoverflowを確認してください。質問/ 34097281 /テンソルからテンポラリにテンソルを変換するテンソルフロー – sansingh

+0

はい、わかっています。だから私は明らかに私は私の質問の記述でこれを試していると明言した。セッションオブジェクトなしでは動作しません。Kerasがデフォルトで使用するセッションオブジェクトは何か分かりません。しかし、上記の答えで述べたように、F1メトリックを使用することは、数学的に有効なアプローチではないため、別の方法を試してみます。 – TBZ92