2016-09-05 15 views
11

TensorFlow docsは、TFRecordReader、TextLineReader、QueueRunnerなどのキューとキューを使用してデータを読み取る方法について説明しています。ジェネレータを使用して入力を生成するTensorFlowネットワークのトレーニング方法は?

私がしたいのは、はるかに単純です。私は(X、y)タプルとしてトレーニングデータの無限のシーケンスを生成するPythonジェネレータ関数を持っています(両方ともnumpy配列で、最初の次元はバッチですサイズ)。私はちょうどそのデータを入力として使ってネットワークを訓練したいと思う。

データを生成するジェネレータを使用してTensorFlowネットワークをトレーニングする簡単な自己完結型の例がありますか? (MNISTまたはCIFAR例の線に沿って)

+2

['tf.data.Dataset.from_generator'](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator)があります。これはあなたの場合に役立つかもしれません。 – Jakub

答えて

15

は、データを生成する関数があるとします。

def generator(data): 
    ... 
    yield (X, y) 

を今すぐあなたのモデルのアーキテクチャについて説明し、別の関数が必要です。これは、Xを処理し、yを出力(例えば、ニューラルネットワーク)として予測しなければならない任意の関数とすることができる。

(回帰の場合には、例えば、クロスエントロピーまたはMSE)関数がXと入力としてyは、何らかの方法でXからYの予測を計算受け入れ、損失関数を返すと仮定するYおよびYを予測との間:

def neural_network(X, y): 
    # computation of prediction for y using X 
    ... 
    return loss(y, y_pred) 

モデルの仕事をするために、あなたはXとYの両方のためのプレースホルダを定義し、セッションを実行する必要があります。

X = tf.placeholder(tf.float32, shape=(batch_size, x_dim)) 
y = tf.placeholder(tf.float32, shape=(batch_size, y_dim)) 

プレースホルダはSOMETですあなたはfeed_dictしてセッションを実行するときに指定する必要があり、「自由変数」のような興:

with tf.Session() as sess: 
    # variables need to be initialized before any sess.run() calls 
    tf.global_variables_initializer().run() 

    for X_batch, y_batch in generator(data): 
     feed_dict = {X: X_batch, y: y_batch} 
     _, loss_value, ... = sess.run([train_op, loss, ...], feed_dict) 
     # train_op here stands for optimization operation you have defined 
     # and loss for loss function (return value of neural_network function) 

は、あなたがそれを有用見つけるだろう願っています。ただし、これは完全に動作する実装ではなく、ほとんど詳細を指定していないため疑似コードであることに注意してください。

+0

には、次の関数を手動でコーディングするのではなく、ジェネレータ関数を推定モデルに渡す方法がありますか? – skadoosh

+0

@skadooshあなたはKerasの使用を検討すべきだと思います。 –

+2

@skadoosh - テンソルフローを使いたい場合、バージョン1.6では 'tf.data.Dataset.from_generator'で' tf.data.Dataset'を受け付ける 'Estimator.train'ができます。 – Jakub

関連する問題