2017-09-14 10 views
1

訓練を受けたMNIST caffe2チュートリアルモデルのC++結果を複製するのに苦労しています。私がしたのは、私は少しMNISTのpythonチュートリアル(利用可能なコードhere)を変更していると、Python側ですべて正常に動作します。Caffe2:C++でのMNISTチュートリアルモデルの読み込み方法と使用方法

私がmnist.pyを実行すると、ネット定義と初期化を持つ2つの ".pb"ファイルが得られます。

timg = np.fromfile('test_img.dat', dtype=np.uint8).reshape([28,28]) 
workspace.FeedBlob('data', (timg/256.).reshape([1,1,28,28]).astype(np.float32)) 
workspace.RunNet(net_def.name) 
workspace.FetchBlob('softmax') 
array([[ 1.23242417e-05, 6.76146897e-07, 9.01260137e-06, 
     1.60285403e-04, 9.54966026e-07, 6.82772861e-06, 
     2.20508967e-09, 9.99059498e-01, 2.71651220e-06, 
     7.47664250e-04]], dtype=float32) 

をだから、テスト画像は「7」であるかなり確信している(そしてそれは正しいです):私はpythonの側で、このネットをロードし、DBからいくつかの画像でそれを養うなら、私は正しい予測を取得します。

しかし、私はC++から同じ結果を得ることができません。私は、それが他のプロジェクト(herehere)で行われている方法を見て撮影したし、次が出ている:

C++ネットの初期化

QByteArray img_bytes; // where the raw image bytes are kept (size 28x28) 
caffe2::NetDef init_net, predict_net; 
caffe2::TensorCPU input; 
// predictor and it's input/output vectors 
std::unique_ptr<caffe2::Predictor> predictor; 
caffe2::Predictor::TensorVector input_vec; 
caffe2::Predictor::TensorVector output_vec; 
... 
QFile f("mnist_init_net.pb"); 

... 
auto barr = f.readAll(); 
if (! init_net.ParseFromArray(barr.data(), barr.size())) { 

... 
f.setFileName("mnist_predict_net.pb"); 

... 
barr = f.readAll(); 
if (! predict_net.ParseFromArray(barr.data(), barr.size())) { 

... 
predictor.reset(new caffe2::Predictor(init_net, predict_net)); 
input.Resize(std::vector<int>{{1, 1, IMG_H, IMG_W}}); 
input_vec.resize(1, &input); 

この初期化は問題なく動作します。

float* data = input.mutable_data<float>(); 
for (int i = 0; i < img_bytes.size(); ++i) 
    *data++ = float(img_bytes[i])/256.f; 

をし、最終的に私は、予測フィード::デプロイネットワークはスケーリング、私はこの(上記のPythonスニペットと同じ)を行う必要が浮かぶようにキャストしていないと私が行うので、次のことのように

を私は、同じファイルに乗る
if (! predictor->run(input_vec, &output_vec) || output_vec.size() < 1 
              || output_vec[0]->size() != 10) 
... 

結果が「7」は17%(ない99.9%)であり、残りのカテゴリが5〜10%のまわりにあるということです。

現在、私は立ち往生していますが、どこに問題があるかわからないので、ヒント/ヒント/ポインターをお読みください。

答えて

2

それはCaffe2の私の使い方ではなく、私の 前処理に問題がないことが判明。 img_bytesための基本的なタイプchar及び(GCC)にデフォルトで以来とQByteArray charは署名種類この変換&スケーリングされるいくつかの負の値(代わりの範囲内のフロートをもたらし

*data++ = float(img_bytes[i])/256.f; 

[0,1] )。正しいバージョンは次のとおりです。

​​
関連する問題