2017-07-14 10 views
2

私は句読点を予測するためにtf.estimator APIを使用しています。私はTFRecordsとtf.train.shuffle_batchを使って前処理されたデータでそれを訓練しました。今私は予測をしたい。静的なNumPyデータをtf.constantに送り、これをinput_fnから返すことができます。TensorFlow Estimator APIを使用して非同期予測を実行するにはどうすればよいですか?

しかし、私はシーケンスデータを扱っており、一度に1つの例をフィードする必要があり、次の入力は以前の出力に依存しています。私はまた、HTTPリクエストを介してデータ入力を処理できるようにしたい。

estimator.predictが呼び出されるたびに、チェックポイントを再ロードしてグラフ全体を再作成します。これは遅くて高価です。ですから、input_fnにデータを動的にフィードできるようにする必要があります。

私の現在の試みは、ほぼこれです:

feature_input = tf.placeholder(tf.int32, shape=[1, MAX_SUBSEQUENCE_LEN]) 
q = tf.FIFOQueue(1, tf.int32, shapes=[[1, MAX_SUBSEQUENCE_LEN]]) 
enqueue_op = q.enqueue(feature_input) 

def input_fn(): 
    return q.dequeue() 

estimator = tf.estimator.Estimator(model_fn, model_dir=model_file) 
predictor = estimator.predict(input_fn=input_fn) 
sess = tf.Session() 
output = None 

while True: 
    x = get_numpy_data(x, output) 
    if x is None: 
     break 
    sess.run(enqueue_op, {feature_input: x}) 
    output = predictor.next() 
    save_to_file(output) 

sess.close() 

私は、次のエラーを取得していますが: ValueError: Input graph and Layer graph are not the same: Tensor("EmbedSequence/embedding_lookup:0", shape=(1, 200, 128), dtype=float32) is not from the passed-in graph.

どのように私は非同期的で予測の1を取得するためにinput_fnを通じて私の既存のグラフにデータをプラグインすることができます時間?

答えて

3

主な問題は、すべてのテンソルをinput_fnの内部に作成する必要があるか、同じグラフに追加しないことです。エンキュー操作を実行する必要がありましたが、入力関数から返されたものにアクセスすることは不可能でした。

私はEstimatorクラスを継承し、カスタムを作成することになった私は、動的に予測キューにデータを追加し、結果を返すことができる機能を予測する:

# async_estimator.py 

import six 
import tensorflow as tf 
from tensorflow.python.estimator.estimator import Estimator 
from tensorflow.python.estimator.estimator import _check_hooks_type 
from tensorflow.python.estimator import model_fn as model_fn_lib 
from tensorflow.python.framework import ops 
from tensorflow.python.framework import random_seed 
from tensorflow.python.training import saver 
from tensorflow.python.training import training 


class AsyncEstimator(Estimator): 

    def async_predictor(self, 
       dtype, 
       shape=None, 
       predict_keys=None, 
       hooks=None, 
       checkpoint_path=None): 
     """Returns a tuple of functions: first runs predicitons on the model, second cleans up 
     Args: 
      dtype: the dtype of the input 
      shape: the shape of the input placeholder (optional) 
      predict_keys: list of `str`, name of the keys to predict. It is used if 
      the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used 
      then rest of the predictions will be filtered from the dictionary. If 
      `None`, returns all. 
      hooks: List of `SessionRunHook` subclass instances. Used for callbacks 
      inside the prediction call. 
      checkpoint_path: Path of a specific checkpoint to predict. If `None`, the 
      latest checkpoint in `model_dir` is used. 
     Returns: 
      (predict, finish): tuple of functions 

      predict: runs a single prediction and returns the results 
       Args: 
        x: NumPy array of input 
       Returns: 
        Evaluated value of the prediction 

      finish: closes the session, allowing the program to exit 

     Raises: 
      ValueError: Could not find a trained model in model_dir. 
      ValueError: if batch length of predictions are not same. 
      ValueError: If there is a conflict between `predict_keys` and 
      `predictions`. For example if `predict_keys` is not `None` but 
      `EstimatorSpec.predictions` is not a `dict`. 
     """ 
     hooks = _check_hooks_type(hooks) 
     # Check that model has been trained. 
     if not checkpoint_path: 
      checkpoint_path = saver.latest_checkpoint(self._model_dir) 
     if not checkpoint_path: 
      raise ValueError('Could not find trained model in model_dir: {}.'.format(
       self._model_dir)) 

     with ops.Graph().as_default() as g: 
      random_seed.set_random_seed(self._config.tf_random_seed) 
      training.create_global_step(g) 
      input_placeholder = tf.placeholder(dtype=dtype, shape=shape) 
      queue = tf.FIFOQueue(1, dtype, shapes=shape) 
      enqueue_op = queue.enqueue(input_placeholder) 
      features = queue.dequeue() 
      estimator_spec = self._call_model_fn(features, None, 
               model_fn_lib.ModeKeys.PREDICT) 
      predictions = self._extract_keys(estimator_spec.predictions, predict_keys) 
      mon_sess = training.MonitoredSession(
        session_creator=training.ChiefSessionCreator(
         checkpoint_filename_with_path=checkpoint_path, 
         scaffold=estimator_spec.scaffold, 
         config=self._session_config), 
        hooks=hooks) 

      def predict(x): 
       if mon_sess.should_stop(): 
        raise StopIteration 
       mon_sess.run(enqueue_op, {input_placeholder: x}) 
       preds_evaluated = mon_sess.run(predictions) 
       if not isinstance(predictions, dict): 
        return preds_evaluated 
       else: 
        preds = [] 
        for i in range(self._extract_batch_length(preds_evaluated)): 
         preds.append({ 
          key: value[i] 
          for key, value in six.iteritems(preds_evaluated) 
         }) 
        return preds 

      def finish(): 
       mon_sess.close() 

      return predict, finish 

そして、ここではそれを使用する大まかなコードです:

import tensorflow as tf 
from async_estimator import AsyncEstimator 


def doPrediction(model_fn, model_dir, max_seq_length): 
    estimator = AsyncEstimator(model_fn, model_dir=model_dir) 
    predict, finish = estimator.async_predictor(dtype=tf.int32, shape=(1, max_seq_length)) 
    output = None 

    while True: 
     # my input is dependent on the previous output 
     x = get_numpy_data(output) 
     if x is None: 
      break 
     output = predict(x) 
     save_to_disk(output) 

    finish() 

注:これは私のニーズに対応する簡単な解決策ですが、他の場合には修正する必要があります。それはTensorFlow 1.2.1に取り組んでいます。

TFは正式にこのようなものを採用して、Estimatorで動的な予測をより簡単に行うことを望みます。

関連する問題