2017-10-24 11 views
1

は、次のコードを考えてみましょう:Estimatorを何度も練習するときにテンソルがこのグラフの要素ではないと言うのはなぜですか?

import tensorflow as tf 

from tensorflow.python.estimator.model_fn import EstimatorSpec 
from tensorflow.contrib.keras.api.keras.layers import Dense 


def model_fn_1(features, labels, mode): 
    x = [[1]] 
    labels = [[10]] 
    m = tf.constant([[1, 2], [3, 4]], tf.float32) 
    lookup = tf.nn.embedding_lookup(m, x, name='embedding_lookup') 

    preds = Dense(1)(lookup) 
    loss = tf.reduce_mean(labels - preds) 
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, tf.train.get_global_step()) 

    eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, preds)} 
    return EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops) 


model_1 = tf.estimator.Estimator(model_fn_1) 
model_1.train(input_fn=lambda: None, steps=1) 

予想したように、私はmodel_1.train(input_fn=lambda: None, steps=1)複数回実行することができますし、トレーニングが前回の実行から継続されます。

さて、次のコードを考えてみます。この場合

import tensorflow as tf 
import numpy as np 

from tensorflow.python.estimator.model_fn import EstimatorSpec 
from tensorflow.contrib.keras.api.keras.layers import Embedding, Dense 

def model_fn_2(features, labels, mode): 
    x = tf.constant([[1]]) 
    labels = [[10]] 
    m = np.array([[1, 2], [3, 4]]) 
    m = Embedding(2, 2, weights=[m], input_length=1, name='embedding_lookup') 
    lookup = m(x) 

    preds = Dense(1)(lookup) 
    loss = tf.reduce_mean(labels - preds) 
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, tf.train.get_global_step()) 

    eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, preds)} 
    return EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops) 


model_2 = tf.estimator.Estimator(model_fn_2) 
model_2.train(input_fn=lambda: None, steps=1) 

を、私は一度だけmodel_2.train(input_fn=lambda: None, steps=1)を実行することができますし、私は再びそれを実行しようとしたとき、私は次のエラーを取得する:

ValueError: Fetch argument cannot be interpreted as a Tensor. (Tensor Tensor("embedding_lookup/embeddings:0", shape=(2, 2), dtype=float32_ref) is not an element of this graph.)

なぜこれが起こり、どのように修正できますか?

答えて

2

テンソルフローケラのバックエンドでは、バグまたはサポートされないケースがあります。セッションはグローバルにキャッシュされ、クリアされません。

from tensorflow.contrib.keras.python.keras.backend import clear_session 
clear_session() 

trainの呼び出しで手動でクリアすることができます。短い理由:2番目のtrainコールは、新しいノードを持つ新しいグラフを作成しますが、フードの下のセッションは前のグラフを保持しているため、互換性がありません。

+0

ボンネット下のセッションで常に前のグラフが保持されている場合、最初のスニペットのコードに対してセッションをクリアする必要はありません。 – mauna

関連する問題