2017-10-17 21 views
0

私は、マルチラベル分類の問題に対して、所定のしきい値で完全一致の割合を表示するeval_metric_op関数を作成しようとしています。次の関数は、与えられたしきい値に基づいて0(完全一致なし)または1(完全一致)を返します。TensorFlowの「完全一致」eval_metric_opを作成するにはどうすればよいですか?

def exact_match(y_true, y_logits, threshold): 
    y_pred = np.round(y_logits-threshold+0.5) 
    return int(np.array_equal(y_true, y_pred)) 

y_true = np.array([1,1,0]) 
y_logits = np.array([0.67, 0.9, 0.55]) 

print(exact_match(y_true, y_logits, 0.5)) 
print(exact_match(y_true, y_logits, 0.6)) 

0.5の閾値関数は0 0.6の閾値が非常に正確である[1,1,0]の予測をもたらすを返すように間違っている[1,1,1]の予測をもたらします関数は1を返します。

私はこの関数をテンソルフロー評価メトリックに変換したいのですが、これを行うには誰にでもアドバイスできますか?

私は以下のtensorflowのOPSを使用して同じロジックを取得することができますが、私は、カスタムeval_metric_opにこれを作る方法を完全にはよく分からない:上記のコードは、0のexact_match_50になります

import tensorflow as tf 

def exact_match_fn(y_true, y_logits, threshold): 
    #pred = tf.equal(tf.round(y_logits), tf.round(y_true)) 
    predictions = tf.to_float(tf.greater_equal(y_logits, threshold)) 
    pred_match = tf.equal(predictions, tf.round(y_true)) 
    exact_match = tf.reduce_min(tf.to_float(pred_match)) 
    return exact_match 

graph = tf.Graph() 
with graph.as_default(): 
    y_true = tf.constant([1,1,0], dtype=tf.float32) 
    y_logits = tf.constant([0.67,0.9,0.55], dtype=tf.float32) 
    exact_match_50 = exact_match_fn(y_true, y_logits, 0.5) 
    exact_match_60 = exact_match_fn(y_true, y_logits, 0.6) 

sess = tf.InteractiveSession(graph=graph) 
print(sess.run([exact_match_50, exact_match_60])) 

(で少なくとも1予測は間違っています)、exact_match_60は1(すべてのラベルが正しい)です。

単にtf.contrib.metrics.streaming_mean()を使用するだけで十分ですか、それとも良い方法がありますか?あなたのexact_match_fnの出力は、評価のために使用することができるオペアンプである

tf.contrib.metrics.streaming_mean(exact_match(y_true, y_logits, threshold)) 

答えて

1

:私はこれを実装します。バッチで平均値を求めたい場合は、reduce_minを変更して、関連する軸を減らしてください。

など。あなたy_true/y_logits各形状(batch_size, n)

def exact_match_fn(y_true, y_logits, threshold): 
    #pred = tf.equal(tf.round(y_logits), tf.round(y_true)) 
    predictions = tf.to_float(tf.greater_equal(y_logits, threshold)) 
    pred_match = tf.equal(predictions, tf.round(y_true)) 
    exact_match = tf.reduce_min(tf.to_float(pred_match), axis=1) 
    return exact_match 


def exact_match_prop_fn(*args): 
    return tf.reduce_mean(exact_match_fn(*args)) 

を持っている場合、これはあなたのバッチの平均を与えるだろう。データセット全体の平均を求めたいのであれば、マッチ(またはcorrecttotalカウント)を集めて、セッション/テンソルフローの外側で評価しますが、streaming_meanはおそらくそれだけではありません。

+0

はい、これは間違いなく軸を追加する= 1 - その訂正のおかげで!私はこれの一部だと思います。私は私の実験から、streaming_mean()が総計を集めることを扱うので、ここで唯一欠けていると思います。 – reese0106

関連する問題