2017-11-11 17 views
3

手続き的に生成された(無限の)データソースがあり、これをハイレベルのTensorflow Estimatorの入力として使用して、イメージベースの3Dオブジェクト検出器。"TypeError:テンソルフローのエラーTensor 'エラーEstimator

私はちょうどTensorflor見積もりQuickstartのようにデータセットを設定し、私のdataset_input_fnは、機能のタプルを返し、Tensor年代にラベルを付け、Estimator.train関数は指定と同じように、そしてどのようにこのtutorial showsが、私はときにエラーを取得しています列車の関数を呼び出すしようとしている:

TypeError: 'Tensor' object is not iterable.

私が間違って何をしているのですか?ここで


def data_generator(): 
     """ 
     Generator for image (features) and ground truth object positions (labels) 

     Sample an image and object positions from a procedurally generated data source 
     """ 
     while True: 
      source.step() # generate next data point 

      object_ground_truth = source.get_ground_truth() # list of 9 floats 
      cam_img = source.get_cam_frame() # image (224, 224, 3) 
      yield (cam_img, object_ground_truth) 

    def dataset_input_fn(): 
     """ 
     Tensorflow `Dataset` object from generator 
     """ 

     dataset = tf.data.Dataset.from_generator(data_generator, (tf.uint8, tf.float32), \ 
      (tf.TensorShape([224, 224, 3]), tf.TensorShape([9]))) 
     dataset = dataset.batch(16) 

     iterator = dataset.make_one_shot_iterator() 

     features, labels = iterator.get_next() 
     return features, labels 

    def main(): 
     """ 
     Estimator [from Keras model](https://www.tensorflow.org/programmers_guide/estimators#creating_estimators_from_keras_models) 

     Try to call `est_vgg.train()` leads to the error 
     """ 
     .... 
     est_vgg16 = tf.keras.estimator.model_to_estimator(keras_model=keras_vgg16) 
     est_vgg16.train(input_fn=dataset_input_fn, steps=10) 
     .... 

full code

されています。ここで(ノート物事は異なり、この質問から命名されている)

は、スタックトレースです:

Traceback (most recent call last): 
    File "./rock_detector.py", line 155, in <module> 
    main() 
    File "./rock_detector.py", line 117, in main 
    est_vgg16.train(input_fn=dataset_input_fn, steps=10) 
    File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train 
    loss = self._train_model(input_fn, hooks, saving_listeners) 
    File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model 
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config) 
    File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn 
    model_fn_results = self._model_fn(features=features, **kwargs) 
    File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 145, in model_fn 
    labels) 
    File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 92, in _clone_and_build_model 
    keras_model, features) 
    File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 58, in _create_ordered_io 
    for key in estimator_io_dict: 
    File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__ 
    raise TypeError("'Tensor' object is not iterable.") 
TypeError: 'Tensor' object is not iterable. 
+0

私はあなたが望むと思う: 'get_next = iterator.get_next(); est_vgg16.train(input_fn = get_next、steps = 10'ですが、私はkerasを使用していませんので、そこで使用されている '.train'関数には完全に慣れていません) –

+0

エラーの完全なスタックトレースを教えてください? – mrry

+0

スタックトレースでポストが更新されました。上位レベルのAPIを使って何が起こっているのか分かりにくいです。tfの低レベルインターフェイスに切り替えることで多くの努力をしました。しかし、手動でジェネレーターを "フィード"するだけです。上位レベルのapiについての良い点は、すべてのトレーニングと詳細を処理し、おそらく処理を最適化できるということです。 – matwilso

答えて

3

は、あなたの入力機能を作ります次のような機能の辞書を返します。

def dataset_input_fn(): 
    ... 
    features, labels = iterator.get_next() 
    return {'image': features}, labels 
+0

これは問題を解決しました。 'dataset_input_fn'の' tf.uint8'を 'tf.float32'に変更しなければなりませんでした。 – matwilso

関連する問題