2017-03-14 6 views
4

ためinput_fn使用している場合:some_inputxy、およびbatch_sizeを持っている場合は設定BATCH_SIZE私はTFのハイレベルな見積もり使用していますtf.contrib.learn.Estimator

estim = tf.contrib.learn.Estimator(...) 
estim.fit (some_input) 

を、コードが実行されても警告付き。そこで私はinput_fnを使用しようとしましたが、xyをこのinput_fnに送信しましたが、batch_sizeは送信しませんでした。例が見つかりませんでした。

estim.fit/estim.evaluateへの入力としてinput_fnを使用し、batch_sizeも使用する単純な例を誰も共有できますか?

tf.train.batchを使用する必要がありますか?もしそうなら、高水準実装(tf.layers)にどのようにマージされますか?グラフのtf.Graph()またはセッションは分かりません。

WARNING::tensorflow:以下

は、私が得た警告です/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/monitors.pyから:657:呼び出し評価

(tensorflow.contrib.learn.python.learn.estimators.estimatorから)yは推奨されておらず、2016-12-01の後に削除されます。

アップデート手順: Estimatorは、Scikit Learnインターフェイスから別のクラスSKCompatに移動して に移動します。引数x、yおよびbatch_sizeはSKCompatクラスで利用可能な のみであり、Estimatorはinput_fnのみを受け入れます。

例変換:

EST =推定(...) - > EST = SKCompat(見積もり(...))

+1

将来的に_input_fn_に問題が発生する可能性がある人(現時点では、マニュアルはそれほど自明ではありません) - よく理解されている例を以下の[link](http:// blog。 mdda.net/ai/2017/02/25/estimator-input-fn)。 – Roi

答えて

4

link provided in Roi's own commentは確かにとても役に立ちました。私はしばらくの間だけでなく同じ質問に苦しんでいたので、私は上記の参考としてリンクが提供する答えを要約したいと思います:これは、この例のように使用することができます

def batched_input_fn(dataset_x, dataset_y, batch_size): 
    def _input_fn(): 
     all_x = tf.constant(dataset_x, shape=dataset_x.shape, dtype=tf.float32) 
     all_y = tf.constant(dataset_y, shape=dataset_y.shape, dtype=tf.float32) 
     sliced_input = tf.train.slice_input_producer([all_x, all_y]) 
     return tf.train.batch(sliced_input, batch_size=batch_size) 
    return _input_fn 

(TensorFlow V1を使用します。 1):

model = CustomModel(FLAGS.learning_rate) 
estimator= tf.estimator.Estimator(model_fn=model.build(), params=model.params()) 

estimator.train(input_fn=batched_input_fn(
     train.features, 
     train.labels, 
     FLAGS.batch_size), 
    steps=FLAGS.train_steps) 

残念ながら、このアプローチはがTen​​sorFlows低レベルAPIを使用して手差し()と比較して、又はtrain.shape[0] == batch_sizeとデータセット全体を使用して全くtrain.sliced_input_producer()train.batch()を使用していないと比較して、約10倍遅いです。少なくとも私のマシンでは(CPUのみ)。私は本当にこのアプローチがなぜそんなに遅いのだろうと思っています。何か案は?

編集:

私はtrain.batch()のパラメータとして> 1 num_threadsを使用して、少しそれをスピードアップすることができます。 2つのCPUを搭載したVMでは、このバッチング・メカニズムを使用して、デフォルトのnum_threads=1と比較してパフォーマンスを倍増できます。しかし、それでもなお、は5倍遅いです。 しかし、ネイティブシステムや、入力パイプライン用にすべてのCPUコアを使用し、モデル計算用にGPUを使用するシステムでは、結果が異なる場合があります。誰かがコメントに自分の結果を投稿できるなら、素晴らしいだろう。

関連する問題