2017-06-05 6 views
4

転送学習では、フィーチャのデータセットを作成するためにネットワークを使用することがよくあります。 SVM)。TensorFlow:tf.contrib.data APIの「ステートフルノードを値でキャプチャできません」

# feature_extractor will create a CNN on top of the given tensor 
def features(feature_extractor, ...): 
    dataset = inputs(...) # This creates a dataset of (image, label) pairs 

    def map_example(image, label): 
     features = feature_extractor(image, trainable=False) 
     # Leaving out initialization from a checkpoint here... 
     return features, label 

    dataset = dataset.map(map_example) 

    return dataset 

データセットのイテレータを作成するとき、これは失敗した操作:

は、私は、データセットのAPI(tf.contrib.data)とdataset.map()を使ってこれを実装したいと思います。

これは事実です。ネットワークのカーネルとバイアスは変数であり、ステートフルです。この特定の例では、彼らはそうである必要はありません。

Opsと具体的にはtf.Variableオブジェクトをステートレスにする方法はありますか?

私は単に定数としてそれらを作成することはできませんtf.layersを使用して、定数どちらを作成することはありませんtrainable=Falseを設定するが、ちょうどGraphKeys.TRAINABLE_VARIABLESコレクションに変数を追加しませんよので。

答えて

9

残念ながら、tf.Variableは本質的にステートフルです。しかし、このエラーは、イテレータを作成するためにDataset.make_one_shot_iterator()を使用した場合にのみ発生します。この問題を回避するには、代わりにオブジェクト入力パイプラインで使用されます。


*この制限の理由は、Dataset.make_one_shot_iterator()の実装の詳細及び仕掛品TensorFlow機能(Defun)は、データセットの定義をカプセル化するために使用する支持体です。ルックアップテーブルや変数などのステートフルなリソースを使用することは当初想定していたよりも一般的であるため、この制限を緩和する方法を検討しています。

関連する問題