2017-05-01 7 views
1

私はcntk.metrics.classification_error()の正しい使い方を理解しようとしており、その真実に対して予測のバッチを確認するために使っています。CNTK classification_error()の使用方法は?

Python API docsに基づいて)以下のおもちゃ例:

array([[ 0., 0., 1.], 
     [ 0., 0., 1.], 
     [ 0., 0., 1.]], dtype=float32) 

私が表示された正方行列ではなく、ベクトルを得ることができる方法があります:

import numpy as np 
from cntk.metrics import classification_error 

predictions = np.asarray([[1., 2., 3., 4.],[1., 2., 3., 4.],[1., 2., 3., 4.]], dtype=np.float32) 
labels = np.asarray([[0., 0., 0., 1.],[0., 0., 0., 1.],[0., 0., 1., 0.]], dtype=np.float32) 
classification_error(predictions, labels).eval() 

には、以下の結果が得られます私は大きなバッチを処理したいと考えているので、非効率ですか?

classification_error()を呼び出すときに、私はaxisキーワードを使用してみましたが、私はaxis=0またはaxis=1を設定するかどうか、私は空の結果を取得します。

答えて

2

これは、CNTKがユーザーフレンドリーになるために発生します。タイプについて混乱してしまいます.--)分類エラーが正しくないために分かります。

タイピング情報を少し追加すると、セマンティクスが正しく取得されます。

p = C.input(4) 
y = C.input(4) 
classification_error(p, y).eval({p:predictions, y:labels}) 
array([[ 0.], 
     [ 0.], 
     [ 1.]], dtype=float32) 

私たちは混乱を避けるための修正を行います。

+0

ありがとうございました!私は関数の入力を作成するべきだと思っていたはずですが、ドキュメント内の例ではそれを明示的にしていないので、私の考えは変わりませんでした。 –

関連する問題