マイトレーニングスクリプトは、TensorFlowモデルを訓練するために、非常にわずかオンラインチュートリアルからの変更:feed_dict(要約操作)
def train(data_set_dir, train_set_dir):
data = data_input.read_data_sets(data_set_dir, train_set_dir)
with tf.Graph().as_default():
global_step = tf.Variable(0, trainable=False)
# defines placeholders (type=tf.float32)
images_placeholder, labels_placeholder = placeholder_inputs(batch_size, image_size, channels)
logits = model.inference(images_placeholder, num_classes)
loss = loss(logits, labels_placeholder, num_classes)
train_op = training(loss, global_step, batch_size)
saver = tf.train.Saver(tf.all_variables())
summary_op = tf.merge_all_summaries()
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
for step in range(max_steps):
start_time = time.time()
feed_dict = fill_feed_dict(data, images_placeholder, labels_placeholder, batch_size)
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
# ... continue to print loss_value, run summaries and save checkpoints
上記と呼ばれるplaceholder_inputs機能は次のとおりです。
def placeholder_inputs(batch_size, img_size, channels):
images_pl = tf.placeholder(tf.float32,
shape=(batch_size, img_size, img_size, channels), name='images')
labels_pl = tf.placeholder(tf.float32,
shape=(batch_size, img_size, img_size), name='labels')
return images_pl, labels_pl
明確にするために、私が扱っているデータは、セグメント化問題におけるピクセルごとの分類のためのものです。上記のように、これはバイナリ分類の問題です。
そしてfeed_dict機能は次のとおりです。
私がで立ち往生していdef fill_feed_dict(data_set, images_pl, labels_pl, batch_size):
images_feed, labels_feed = data_set.next_batch(batch_size)
feed_dict = {images_pl: images_feed, labels_pl: labels_feed}
return feed_dict
:トレースバックは私のplaceholder_inputs
機能から「ラベル」テンソルによって引き起こされていると、それを明らかにし
tensorflow.python.framework.errors.InvalidArgumentError: You must feed a value for placeholder tensor 'labels' with dtype float and shape [1,750,750]
[[Node: labels = Placeholder[dtype=DT_FLOAT, shape=[1,750,750], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
。さらに、このエラーは、2つのプレースホルダ間で、私が見る限り、ランダムに移動し続けます。一度、それは 'ラベル' [labels_pl
]テンソル、別の時間、それは私の '画像' [images_pl
]テンソルです。詳細に
エラー:
File ".../script.py", line 32, in placeholder_inputs
shape=(batch_size, img_size, img_size), name='labels')
File ".../tensorflow/python/ops/array_ops.py", line 895, in placeholder
name=name)
File ".../tensorflow/python/ops/gen_array_ops.py", line 1238, in _placeholder
name=name)
File ".../tensorflow/python/ops/op_def_library.py", line 704, in apply_op
op_def=op_def)
File ".../tensorflow/python/framework/ops.py", line 2260, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/tensorflow/python/framework/ops.py", line 1230, in __init__
self._traceback = _extract_stack()
私が試した何/確認:
は無駄にはもちろん、forループの外にfeed_dictを置きます。
batch_size要件に対応するのに十分なデータがトレーニングデータディレクトリにあることを確認しました。
プレースホルダのdtypeを指定する際に複数のバリエーションがあります。「float」がスタックトレース内のキーであったとします。
クロスチェックされたデータシェイプ。これらは、プレースホルダで指定されたものとまったく同じです。
恐らくこれは私が思うよりもはるかに単純な問題です。たぶん私はここで見ることができないマイナーなタイプミスです。提案?私はすべての選択肢を使い果たしたと信じています。問題について新しい光を当てる人を探してください。
私はthisエラーの説明を参照しました。
アップデート:私はdidnの
また{<tf.Tensor 'images:0' shape=(1, 750, 750, 3) dtype=float32>:
array([[[[-0.1556225 , -0.13209309, -0.15954407],
[-0.15954407, -0.12032838, -0.13601466],
.....
[-0.03405387, 0.04829907, 0.09535789]]]], dtype=float32),
<tf.Tensor 'labels:0' shape=(1, 750, 750) dtype=float32>:
array([[[ 0., 0., 0., ..., 0., 0., 0.],
.....
[ 0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)}
何か:(ここではコメントで提案されているように)と期待値がプレースホルダに供給されていることに気づい
はsession.run
前print feed_dict
をしました先に言及していません: ループは初めて実行されます。したがって、私はstep = 0
の最初の値の出力を得て、次にstep=0
に指定したの文を印刷した後すぐに終了します。
アップデート2:
問題があった場所私は考え出しました。 summary_op
を印刷していました。しかし、なぜこれは私を超えているのですか?これは私がforループでそれを印刷する方法です:
if step % 100 == 0:
summary_str = sess.run(summary_op)
summary_writer.add_summary(summary_str, step)
このブロックをコメントアウトすると、トリックが実行されます。なぜこれが間違っているのだろうか?
アップデート3:以下
回答を解決しました。私が気づいたのは、TensorFlow CIFAR-10 exampleは、feed_dict
の明示的な言及なしで、同様のsess.run
を実行しており、正常に動作しているということです。どのくらい正確に動作しますか?
ナンプィ配列のデフォルトは 'np.float64'ですが、' DT_FLOAT'は 'np.float32'と同じですので、' .as_type(np.float32) 'を追加してください。 –
@YaroslavBulatovつまり、プレースホルダへの入力配列の 'astype(np.float32)'を意味すると仮定します。 – mshiv
プレースホルダーに入力される実際の形状とdtypeを把握するために、各session.runコールの前にprintを追加することがあります。 –