2016-05-10 11 views
1

Tensorflowフレームワークを回帰問題でテストしています。私の入力には多くの機能があり、出力は1つしかありません。 今のところ、それは働いていますが、私のコスト関数は私が望むものに対して正確ではないと思うので、私はかなり低い精度を持っています。予測出力ベクトルのためにそうニューラルネットワーク:コスト関数 "not differentiable"の設計/ Tensorflow

if (y_actual - y_predicted) < 0.1 * y_actual : 
    cost = 0 
else : 
    cost = 1 

[130, 35, 65, 43]実際の出力VS

私の目標は、実際の出力の10%の範囲内の出力を予測することで、それは、各出力のために言うことです[125, 10, 75, 40]のベクトル、私は2のコストを取得する必要があります(130と40は10%、35と65ではありません)

今のところ、私は二乗差を使用します。

cost = tf.reduce_sum(tf.squared_difference(y_predicted,y_)/(2*nb_batches)) 

だから実際に、私は、出力の和のためのかなり良い予測を得るが、私はそれが実際の出力の10%である場合の予測が正確であると考えるならば、私はわずか60%の精度を持っています。 私の精度は、そのように書かれている:

with tf.name_scope('accuracy'): 
    with tf.name_scope('correct_prediction'): 
     correct_prediction = tf.less_equal(tf.abs((y_predicted-y_)), 0.1*tf.abs((y_))) # True if y_diff <= 10% * y_actual 
    with tf.name_scope('accuracy'): 
     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # count 1 if correct prediction is True ; else False ; Do the sum on all the data tested 
    tf.scalar_summary('accuracy', accuracy) 

私のような他のコスト関数書き込もうとしました:

cost = tf.reduce_sum(tf.cast((tf.greater_equal(tf.abs((y_predicted-y_)), 0.1*tf.abs(y_))),tf.float32) 

をしかし、私はエラーを取得:No gradients provided for any variableを。 tf.castの微分を得ることができないからだと私は思う。

私のニーズに合った損失機能を作成する適切な方法はありますか? 私が望む損失関数は「導出可能」ではないことを知っています。それを導出可能な関数に変換することが解決策になると思いますか?

私は...あなたの助けを

おかげで非常にはっきりしていたことを確認していません!

ポール

+0

私は、変数から損失関数へのパスがないときにエラーが発生すると思います。グラフ定義のバグである可能性があります。 – shekkizh

+0

はい、私はtf.castで、Tensorflowは変数に戻ることができないと思います...しかし、コストは浮動小数点にする必要があるので... –

+0

'tf.cast '? –

答えて

1

代わりに比較演算子を使用するのは、あなたが
tf.nn.relu(tf.abs(y_predicted-y_)-tf.abs(y_))
が微分コスト関数を取得するために使用することができます。

reluオペレーションはmax(0、x)を計算するので、実行していた作業の連続バージョンが得られます。主な違いは、あなたが離れていくにつれてペナルティを受けるため、バイナリプロパティを持たないことです(しかし、これは通常、回帰の目的です)。
これは、二乗損失よりも良い結果をもたらすはずです。

+0

あなたの答えをありがとう。 私はちょうどタイプミスがあると思う:私はそれがあるべきだと思う tf.nn.relu(tf.abs(y_predicted-y _) - 0.1 * tf.abs(y)) あなたが言ったように、 27 "は" 1 "の27個の誤差としてカウントされるが、精度はそうではない。 –

+0

実際には、私は二乗損失を使用したときよりも小さい精度を得ました...可能性は、非ヌル値の数を "数え"ますが、それを行い、npをテストするTF関数はわかりません.count_nonzero、それは動作していません。 また、エラー関数、つまりErf(alpha * x)を使用するのが適切かどうかは疑問でした。 –

関連する問題