2017-06-04 10 views
2

カスタムモデルを作成し、tf.estimatorを使用してカプセル化することができます。しかし、私は例をもって書類を見つけることができないようです。TensorFlowでのtf.estimatorの使用方法に関するドキュメント

私はあなたのモデルを 'model_fn'の中で定義しなければならないことを知っていますが、この関数から正確に何を返すべきですか?また、私は損失と訓練のステップを 'model_fn'またはネットワーク内に置くことになっています。 tf.estimatorで動作させるには、以下のコードをどのように変更すればよいですか?本当に助けに感謝します。

def test_model(features,labels): 
    X = tf.placeholder(tf.float32,shape=(None,1),name="Data_Input") 
    #Output 
    Y = tf.placeholder(tf.float32,shape=(None,1),name="Target_Labels") 
    W = tf.Variable(tf.random_normal([0],stddev=stddev0)) 
    b = tf.Variable(tf.random_normal([0],stddev=stddev0)) 

    Ypredict = W*X + b 
    return Ypredict 

estimator = tf.estimator.Estimator(model_fn = test_model) 

答えて

1

tf.estimator.EstimatorSpecオブジェクトを返す必要があります。効果に何か:

def model_fn(features, labels, mode, params): 
    /* 
    Your marvelous model 
    */ 
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels_onehot, logits=logits) 
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001) 
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step()) 
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) 

良くウォークスルーのために、hereを参照してくださいので、それまでよりあります。

関連する問題