2017-04-24 2 views
9

私はいくつかのチュートリアルに基づいて、いくつかのシンプルなモデルを作成しました。私が行ったことから、モデルは再利用するのが難しいと感じています。モデルをカプセル化するために、クラスを持つ構造をいくつか作成する必要があると感じています。テンソルフローモデルのデザインパターン

テンソルフローモデルの「標準的な」方法は何ですか?このためのコーディング規約/ベストプラクティスはありますか?

+0

質問はあまりにも広範ではありません。 GoogleのTensorflowは、そこから誰かがモデルを構造化するための規約を共有しているといいでしょう。 –

答えて

3
Tensorflow例を通して

と、三つの機能にモデルを分割するようにされた構造モデルコードの顕著なパターンのチュートリアル:

  • inference(inputs, ...)の上部に損失を追加したモデル
  • loss(logits, ...)を構築logits訓練OPS

モデルFOを作成する追加

  • train(loss, ...)このパターンは、例(codetutorial)のためCIFAR-10チュートリアルのために使用されている

    inputs = tf.placeholder(...) 
    logits = mymodel.inference(inputs, ...) 
    loss = mymodel.loss(logits, ...) 
    train = mymodel.train(loss, ...) 
    

    :Rのトレーニングは、あなたのコードは次のようになります。

    1つの事柄は、inferenceloss関数の間で(Python)変数を共有できないという事実です。しかし、これは大きな問題ではありません。Tensorflowは正確にこのユースケースにGraph collectionsを提供しています。これにより、より洗練された設計が可能になります。すべての正則化罰則は、デフォルトでは、コレクションtf.GraphKeys.REGULARIZATION_LOSSESに(source)追加されますので、あなたは、あなたはすでにあなたが必要なものを持っている(例えばtf.layers.conv2dlayersモジュールを使用している場合

    :このための一つの主要なユースケースが正則です。あなたがこれを行うとき、たとえば:あなたはレイヤーモジュールを使用していない場合は、あなただけのように(手動でコレクションを移入する必要があります

    def loss(logits, labels): 
        softmax_loss = tf.losses.softmax_cross_entropy(
         onehot_labels=labels, 
         logits=logits) 
    
        regularization_loss = tf.add_n(tf.get_collection(
         tf.GraphKeys.REGULARIZATION_LOSSES))) 
    
        return tf.add(softmax_loss, regularization_loss) 
    

    conv1 = tf.layers.conv2d(
        inputs, 
        filters=96, 
        kernel_size=11, 
        strides=4, 
        activation=tf.nn.relu, 
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), 
        kernel_regularizer=tf.contrib.layers.l2_regularizer(), 
        name='conv1') 
    

    をあなたの損失は、このようになります。リンクされたソーススニペット)。基本的にはtf.add_to_collectionを使用してコレクションに罰則を追加する:これにより

    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, reg_penalty) 
    

    は、あなただけの上記のような正則の罰則を含む損失を計算することができます。

  • +0

    いいパターンですが、ここに問題があります。損失の一部として正規化期間がある場合は機能しません。 –

    +0

    グラフコレクションを使ってこれを行う良い方法があります。答えを更新しました。 – thertweck

    関連する問題