2017-06-28 9 views
3

私はTFに迷惑メールだので、簡単に行く。TensorFlowエスティメータに入力パイプラインを渡す

ラベル付きのディレクトリにある一連の画像から簡単なCNNを訓練する必要があります。多くを見回した後、私はTF入力パイプラインを準備するこのコードを調理し、イメージアレイを印刷することができました。 TFのドキュメントに与えられた

image_list, label_list = load_dataset() 

    imagesq = ops.convert_to_tensor(image_list, dtype=dtypes.string) 
    labelsq = ops.convert_to_tensor(label_list, dtype=dtypes.int32) 

    # Makes an input queue 
    input_q = tf.train.slice_input_producer([imagesq, labelsq], 
               shuffle=True) 

    file_content = tf.read_file(input_q[0]) 
    train_image = tf.image.decode_png(file_content,channels=3) 
    train_label = input_q[1] 

    train_image.set_shape([120,120,3]) 

    # collect batches of images before processing 
    train_image_batch, train_label_batch = tf.train.batch(
     [train_image, train_label], 
     batch_size=5 
     # ,num_threads=1 
    ) 

    with tf.Session() as sess: 
     # initialize the variables 
     sess.run(tf.global_variables_initializer()) 
     # initialize the queue threads to start to shovel data 
     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 
     # print "from the train set:" 
     for i in range(len(image_list)): 
      print sess.run(train_image_batch) 
     # sess.run(train_image) 
     # sess.run(train_label) 
     # classifier.fit(input_fn=lambda: (train_image, train_label), 
     #    steps=100, 
     #    monitors=[logging_hook]) 

     # stop our queue threads and properly close the session 
     coord.request_stop() 
     coord.join(threads) 
     sess.close() 

しかしMNISTの例を見て、私は彼らが見積もりクラスと一緒に cnn_model_fn を使用して参照してください。

私は自分自身を定義しましたcnn_model_fn 2つを組み合わせたいと思います。これで前進する方法を教えてください。このコードは、

「は同じではありません、入力グラフおよびレイヤグラフ」それは、パイプラインがそうでなければ、その空に、セッションが実行された場合にのみ移入されると、それはとValueErrorを与えると思われる

classifier = learn.Estimator(model_fn=cnn_model_fn, model_dir='./test_model') 
classifier.fit(input_fn=lambda: (train_image, train_label), 
steps=100, 
monitors=[logging_hook]) 

してください動作しません。助けて。

答えて

2

私は自分自身テンソルフローに新しいので、塩の穀物でこれを取る。

AFAICTでは、「テンソル」または「操作」を作成するtf APIのいずれかを呼び出すと、Graphというコンテキストに作成されます。

さらに、Estimatorを実行すると、実行ごとに新しい空きが作成されます。Graphmodel_fninput_fnを実行することによってGraphが生成されます。tf APIは、この新しいGraphのコンテキストで「テンソル」と「操作」を追加します。

model_fninput_fnの戻り値は、部品が正しく接続されるように参照を提供するだけです。Graphにはすでにそれらが含まれています。

が作成され、関連する操作が暗黙的なデフォルトのGraphに追加される前に入力操作が既に作成されています(自動的に作成されます)。したがって、Estimatorが新しいモデルを作成し、モデルをmodel_fnに設定すると、入力とモデルは2つの異なるグラフに表示されます。

これを修正するには、input_fnを変更する必要があります。 (image, labels)のペアをlambdaにラップするだけでなく、入力の構造全体を関数にラップして、EstimatorがすべてのAPI呼び出しの副作用としてinput_fnを実行すると、すべての入力操作とテンソルがコンテキストで作成されるようにします正しいGraph

関連する問題