TensorflowのEstimator APIで無限にランダムデータを生成するinput_fn
を作成することは可能ですか?Tensorflowのinput_fnで無限ランダムトレーニングデータを生成
これは私が望むものを基本的には次のとおりです。
def create_input_fn(function_to_generate_one_sample_with_label):
def _input_fn():
### some code ###
return feature_cols, labels
私は、このようなEstimator
インスタンスに機能を使用したい:
def data_generator():
features = ... generate a (random) feature vector ...
lablel = ... create suitable label ...
return features, labels
input_fn = create_input_fn(data_generator)
estimator.train(input_fn=input_fn, steps=ANY_NUMBER_OF_STEPS)
ポイントがために訓練することができることです必要な数のステップを実行し、必要なトレーニングデータをオンザフライで生成します。これはモデル調整のためのもので、さまざまな複雑な訓練データを試すことができ、訓練データに合わせてモデルの能力を知ることができます。
編集 JKMが示唆したように、私はこのように、実際の発電機を使用してみました:
def create_input_fn(function, batch_size=100):
def create_generator():
while True:
features = ... generate <batch_size> feature vectors ...
lablel = ... create <batch_size> labels ...
yield features, label
g = create_generator()
def _input_fn():
return next(g)
return _input_fn
私はそれを実行するために取得するバッチサイズを追加する必要がありました。今は実行されますが、input_fn
は1回だけ呼び出されるため、新しいデータは生成されません。生成された最初の<batch_size>
サンプルを列車に入れるだけです。提供されたinput_fn
を使用してデータをリフレッシュするようエスティメータに指示する方法はありますか?
ありがとうございます。それは助けられましたが、それでも私が望むことをやっていません。編集された質問を参照してください:) –