2017-11-13 2 views
1

mx.io.DataIterクラスを使用してカスタムデータイテレータを作成しました。 Gluonインターフェイスでこのデータイテレータを使用する最も簡単な方法は何ですか?Mxnet Gluonカスタムデータイテレータ

私はドキュメントを読んだので簡単には見つかりませんでした。私の考えの1つはイテレータとしてそれを使用し、次のように各バッチからデータとラベルを取得することでした。

for e in range(epochs): 
    train_iter.reset() 
    for batch_data in train_iter: 
     data = nd.concatenate(([d for d in batch_data.data])) 
     label = nd.concatenate(([l for l in batch_data.label])) 
     with autograd.record(): 
      output = net(data) 
      loss = softmax_cross_entropy(output, label) 
     loss.backward() 
     trainer.step(batch_size) 
     print(nd.mean(loss).asscalar()) 

しかし、これはバッチごとに連結する必要があるため、最適ではない可能性があります。

  1. これを達成する最適な方法は何ですか。すなわち、グルーオンのための簡単なカスタムイテレータを書く方法はありますか?

  2. 上記のケースでコンテキスト情報を追加するにはどうすればよいですか?

答えて

1

あなたのアプローチはうまくいくと思います。基本的にdatabatch_data.datalabelからbatch_data.labelに取得し、ネットワークに供給することができます。

なぜデータとラベルを連結する必要があるのか​​分かりません。おそらくネットワーク定義と関係があります。

データを分割して複数のGPUでトレーニングする必要がある場合は、gluon.utils.split_and_load関数を使用してその作業を行うことができます。

+0

ありがとうございます。私は、バッチデータを実装するためにはるかにクリーンな方法があることを望みました。グルーオン互換のAPIを持っていないことは、非常に奇妙で驚くようです! – krishnakamathk