2017-07-31 6 views
1

TensorflowのEstimator APIで無限にランダムデータを生成するinput_fnを作成することは可能ですか?Tensorflowのinput_fnで無限ランダムトレーニングデータを生成

これは私が望むものを基本的には次のとおりです。

def create_input_fn(function_to_generate_one_sample_with_label): 
    def _input_fn(): 
     ### some code ### 
     return feature_cols, labels 

私は、このようなEstimatorインスタンスに機能を使用したい:

def data_generator(): 
    features = ... generate a (random) feature vector ... 
    lablel = ... create suitable label ... 
    return features, labels 

input_fn = create_input_fn(data_generator) 
estimator.train(input_fn=input_fn, steps=ANY_NUMBER_OF_STEPS) 

ポイントがために訓練することができることです必要な数のステップを実行し、必要なトレーニングデータをオンザフライで生成します。これはモデル調整のためのもので、さまざまな複雑な訓練データを試すことができ、訓練データに合わせてモデルの能力を知ることができます。


編集 JKMが示唆したように、私はこのように、実際の発電機を使用してみました:

def create_input_fn(function, batch_size=100): 
    def create_generator(): 
     while True: 
      features = ... generate <batch_size> feature vectors ... 
      lablel = ... create <batch_size> labels ... 
      yield features, label 
    g = create_generator() 
    def _input_fn(): 
     return next(g) 
    return _input_fn 

私はそれを実行するために取得するバッチサイズを追加する必要がありました。今は実行されますが、input_fnは1回だけ呼び出されるため、新しいデータは生成されません。生成された最初の<batch_size>サンプルを列車に入れるだけです。提供されたinput_fnを使用してデータをリフレッシュするようエスティメータに指示する方法はありますか?

答えて

0

警告の言葉 - 私はではありません自分自身Tensorflowで動作します、私はAPIのドキュメントを外しています。

言われましたが、そこに邪魔がなければ、必要なことをすることができるはずです。ジェネレータをジェネレータ(返されるのではなく、機能とラベルを返す)にして、世代を無限ループにするだけです。たとえば、次のようになります。

def data_generator(): 
    while True: 
     #do generatey things here 
     yield feature, labels 

この関数は、呼び出しごとに新しい値を生成する度に繰り返し呼び出すことができます。

+0

ありがとうございます。それは助けられましたが、それでも私が望むことをやっていません。編集された質問を参照してください:) –

1

は、私はあなたが必要とする、あなたが最近TfのデータセットのAPIを使用して、所望の動作を得ることができると思いtensorflow> = 1.2.0

# Define number of samples and input shape for each iteration 
# you can set minval or maxval as per you data distribution and label distributon requirements 
num_samples = [20000,] 
input_shape = [32, 32, 3] 
dataset = tf.contrib.data.Dataset.from_tensor_slices((tf.random_normal([num_examples+input_shape]), tf.random_uniform([num_samples], minval=0, maxval=5))) 
# Define batch_size 
batch_size = 128 
dataset = dataset.batch(batch_size) 
# Define iterator 
iterator = dataset.make_initializable_iterator() 
# Get one batch 
next_example, next_label = iterator.get_next() 
# calculate loss from the estimator fucntion you are using 
estimator_loss = some_estimator(next_example, next_label) 
# Set number of Epochs here 
num_epochs = 100 
for _ in range(num_epochs): 
    sess.run(iterator.initializer) 
    while True: 
     try: 
      _loss = sess.run(estimator_loss) 
     except tf.errors.OutOfRangeError: 
      break 
+0

新しいランダム値を得るには 'sess.run(iterator.initializer)'で 'num_samples'のステップごとにイテレータを再初期化することが理にかなっていると思います。 – bodokaiser

0

は、私が聞いても:あなたがあなたのデータを生成するデータの増強を行っていますか?そうであれば、テンソルフローフレームワークのランダム関数を使用する限り、呼び出されるinput_fn関数は無限個のランダムサンプルを生成するはずです。 (tf.random_uniformの代わりに、numpyなどの対応するメソッドがあります)これは私にとっては役に立ちます。

関連する問題