2017-10-30 6 views
0

Pythonインターフェイスでは、ミニバッチの例を使用して、予測をnet([[1,2],[3,4],[5,6]])のようにすることができます。MXNetのC++でミニバッチのサンプルを予測する方法は?

しかし、C++ではこれを行う方法が見つかりません。

1つの例を予測するためにネットを呼び出すと10ms必要とします。 10000例がある場合は予測を行う必要があり、それは予測して、それは速く作る

void OneInputOneOutputPredict(PredictorHandle pred_hnd, std::vector<mx_float> vector_data, std::vector<mx_float> &output) 
{ 
    MXPredSetInput(pred_hnd, "data", vector_data.data(), vector_data.size()); 

    // Do Predict Forward 
    MXPredForward(pred_hnd); 

    mx_uint output_index = 0; 
    mx_uint *shape = 0; 
    mx_uint shape_len; 
    MXPredGetOutputShape(pred_hnd, output_index, &shape, &shape_len); 
    size_t size = 1; 
    for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; 

    std::vector<float> data(size); 
    assert(0 == MXPredGetOutput(pred_hnd, output_index, &(data[0]), size)); 
    output = data; 
} 

//very long time 
for(int step=0;step<10000;step++) 
    OneInputOneOutputPredict(pred_hnd, vector_data, vector_label); 

は、我々はC++のコードか何かの方法ベクトル化使用することができ100Sのですか?

答えて

0

私はseq2seq model.Ifデータを使用して[[1,2],[3,4],[5,6]]のように見える場合はミニバッチ(バッチサイズ3)

const mx_uint input_shape_data[4] = {3, static_cast<mx_uint>(data_len)}; 

を予測したい場合は、もともと input_shape_dataは今、この

const mx_uint input_shape_data[4] = {1, static_cast<mx_uint>(data_len)}; 

のように見えます、今すぐリストにフラット化するだけ{1,2,3,4,5,6}、すべてはOKです

関連する問題