それとも、単にモデルの定義のためのupdated Estimator API of TensorFlow r1.1
APIを使用するには、パラメータ、戻り値の型や関数名のみで、いくつかの小さな変化と非常に似ています。ここで私が使用している例です。
def model_fn():
def _build_model(features, labels, mode, params):
# 1. Configure the model via TensorFlow operations
# Connect the first hidden layer to input layer (features) with relu activation
y = tf.contrib.layers.fully_connected(features, num_outputs=64, activation_fn=tf.nn.relu,
weights_initializer=tf.contrib.layers.xavier_initializer())
y = tf.contrib.layers.fully_connected(y, num_outputs=64, activation_fn=tf.nn.relu,
weights_initializer=tf.contrib.layers.xavier_initializer())
y = tf.contrib.layers.fully_connected(y, num_outputs=1, activation_fn=tf.nn.sigmoid,
weights_initializer=tf.contrib.layers.xavier_initializer())
predictions = y
# 2. Define the loss function for training/evaluation
if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
loss = tf.reduce_mean((predictions - labels) ** 2)
else:
loss = None
if mode != tf.estimator.ModeKeys.PREDICT:
eval_metric_ops = {
"rmse": tf.metrics.root_mean_squared_error(tf.cast(labels, tf.float32), predictions),
"accuracy": tf.metrics.accuracy(tf.cast(labels, tf.float32), predictions),
"precision": tf.metrics.precision(tf.cast(labels, tf.float32), predictions)
}
else:
eval_metric_ops = None
# 3. Define the training operation/optimizer
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = tf.contrib.layers.optimize_loss(
loss=loss,
global_step=tf.contrib.framework.get_global_step(),
learning_rate=0.001,
optimizer="Adam")
else:
train_op = None
if mode == tf.estimator.ModeKeys.PREDICT:
predictions_dict = {"pred": predictions}
else:
predictions_dict = None
# 5. Return predictions/loss/train_op/eval_metric_ops in ModelFnOps object
return tf.estimator.EstimatorSpec(mode=mode,
predictions=predictions_dict,
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops)
return _build_model
そして、あなたはこのように、このモデルを使用することができます。
e = tf.estimator.Estimator(model_fn=model_fn(), params=None)
e.train(input_fn=input_fn(), steps=1000)
TensorFlowのR1.1のための入力機能の例は、私の中に見つけることができます答えhere
[cnn_mnistチュートリアル](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/layers/cnn_mnist.py)に従うと同様の問題が発生します。エラー情報に基づいて、私は 'tensorflow.contrib.learn.SKCompat import SKCompat'から' SKCompat() 'とEstimatorをラップしました。しかし、それは動作しません...エラー: "SKCompatというモジュールはありません"。あまりにも助けが必要です! – user3768495