'Wide & Deep Learning'モデルを自分のデータセットでトレーニングしようとしています。モデルをトレーニングセットに適合させるとこのエラーが発生します。ValueError:ロジットとターゲットの形状が同じでなければなりません。
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-15-8f5351c1fdf8> in <module>()
----> 1 m.fit(input_fn=train_input_fn, steps=200)
/Users/prisma/anaconda/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.pyc in fit(self, x, y, input_fn, steps, batch_size, monitors, max_steps)
331 steps=steps,
332 monitors=monitors,
--> 333 max_steps=max_steps)
334 logging.info('Loss for final step: %s.', loss)
335 return self
/Users/prisma/anaconda/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.pyc in _train_model(self, input_fn, steps, feed_fn, init_op, init_feed_fn, init_fn, device_fn, monitors, log_every_steps, fail_on_nan_loss, max_steps)
660 features, targets = input_fn()
661 self._check_inputs(features, targets)
--> 662 train_op, loss_op = self._get_train_ops(features, targets)
663
664 # Add default monitors.
/Users/prisma/anaconda/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.pyc in _get_train_ops(self, features, targets)
188 logits = self._logits(features, is_training=True)
189 if self._enable_centered_bias:
--> 190 centered_bias_step = [self._centered_bias_step(targets, features)]
191 else:
192 centered_bias_step = []
/Users/prisma/anaconda/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.pyc in _centered_bias_step(self, targets, features)
272 with ops.name_scope(None, "centered_bias", (targets, features)):
273 training_loss = self._target_column.training_loss(
--> 274 logits, targets, features)
275 # Learn central bias by an optimizer. 0.1 is a convervative lr for a
276 # single variable.
/Users/prisma/anaconda/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/target_column.pyc in training_loss(self, logits, target, features, name)
204 """
205 target = target[self.name] if isinstance(target, dict) else target
--> 206 loss_unweighted = self._loss_fn(logits, target)
207
208 weight_tensor = self.get_weight_tensor(features)
/Users/prisma/anaconda/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/target_column.pyc in _log_loss_with_two_classes(logits, target)
387 target = array_ops.expand_dims(target, dim=[1])
388 loss_vec = nn.sigmoid_cross_entropy_with_logits(logits,
--> 389 math_ops.to_float(target))
390 return loss_vec
391
/Users/prisma/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/nn.pyc in sigmoid_cross_entropy_with_logits(logits, targets, name)
432 except ValueError:
433 raise ValueError("logits and targets must have the same shape (%s vs %s)"
--> 434 % (logits.get_shape(), targets.get_shape()))
435
436 # The logistic loss formula from above is
ValueError: logits and targets must have the same shape ((?, 1) vs (13647309, 24))
(13647309,24)の代わりにロジットの形状が(?、1)の理由を理解できません。 input_fn関数は、サイズ(13647309,24)と形状のラベルテンソル(13647309,24)のフィーチャーディクティックを返します。私の言う限り、ロジットはモデルの出力でなければなりませんが、DNNLinearCombinedClassifierに出力サイズを指定する場所はありませんので、出力サイズはラベルサイズと同じになるように自動的に調整されます。それは(13647309、24)です。なぜこのエラーが発生するのかわかりませんが、私のモデルには何か問題があると思います。コード全体が貼り付けるには時間がかかりすぎるので、ここでモデル構築部分を貼り付けます。
model_dir = tempfile.mkdtemp()
m = tf.contrib.learn.DNNLinearCombinedClassifier(
model_dir=model_dir,
linear_feature_columns=wide_columns,
dnn_feature_columns=deep_columns,
dnn_hidden_units=[100, 50])
私はテンソルフローチュートリアルからモデルのパラメータを変更しませんでした。私は自分のデータセットの観点から 'wide_columns'と 'deep_columns'を定義しました。モデルや入力機能に問題がありますか? tf.learn apiのウェブサイトでDNNLinearCombinedClassifierのリファレンスを見つけることができません。
更新:入力機能のためのコード
def input_fn(df):
continuous_cols = {k: tf.constant(df[k].values)
for k in CONTINUOUS_COLUMNS}
categorical_cols = {k: tf.SparseTensor(
indices=[[i, 0] for i in range(df[k].size)],
values=df[k].values,
shape=[df[k].size, 1])
for k in CATEGORICAL_COLUMNS}
feature_cols = dict(continuous_cols.items() + categorical_cols.items())
label = tf.constant(df[Label_COLUMNS].values)
return feature_cols, label
'Label_COLUMNS' 内の24個のチャンネルがあります。
input_fnのurコードを表示できますか? –
確かに。私はそれを追加しました。 –