2017-11-29 13 views
2

私は本当にゆっくりとしているものに頭を引っ張り、最適化を望んでいます。TensorFlowでチャネルごとに完全に接続されたレイヤーを適用する方法

私はエンコーダとデコーダの間に大きな潜在的なスペースを持つTensorFlowに畳み込み自動エンコーダを実装しようとしています。現在では、通常、完全に接続されたレイヤーでエンコーダをデコーダに接続しますが、これは潜在的な空間が高い次元性を持つため、計算上実現可能な機能が多すぎるためです。

this paperでこの問題の素晴らしい解決策を見つけました。彼らはそれを「チャネルワイズ完全接続層」と呼んでいる。基本的には、チャンネルごとに完全に接続されたレイヤーです。

私はこれを実装しようとしていますが、実際には動作するようになっていますが、グラフの生成には長い時間がかかります。これまでのコードです:

def _network(self, dataset, isTraining): 
     encoded = self._encoder(dataset, isTraining) 
     with tf.variable_scope("fully_connected_channel_wise"): 
      shape = encoded.get_shape().as_list() 
      print(shape) 
      channel_wise = tf.TensorArray(dtype=tf.float32, size=(shape[-1])) 
      for i in range(shape[-1]): # last index in shape should be the output channels of the last conv 
       channel_wise = channel_wise.write(i, self._linearLayer(encoded[:,:,i], shape[1], shape[1]*4, 
            name='Channel-wise' + str(i), isTraining=isTraining)) 
      channel_wise = channel_wise.concat() 
      reshape = tf.reshape(channel_wise, [shape[0], shape[1]*4, shape[-1]]) 
     reconstructed = self._decoder(reshape, isTraining) 
     return reconstructed 

これはなぜそういうのですか?実際には範囲(2048)ですが、すべての線形レイヤーは実際には小さくなります(4x16)。私はこれに間違った方法で近づいていますか?

ありがとうございます!

答えて

1

Tensorflowでその用紙の実装を確認できます。 ここでは、「チャネルごとに完全に接続されたレイヤー」の実装を示します。

def channel_wise_fc_layer(self, input, name): # bottom: (7x7x512) 
    _, width, height, n_feat_map = input.get_shape().as_list() 
    input_reshape = tf.reshape(input, [-1, width*height, n_feat_map]) 
    input_transpose = tf.transpose(input_reshape, [2,0,1]) 

    with tf.variable_scope(name): 
     W = tf.get_variable(
       "W", 
       shape=[n_feat_map,width*height, width*height], # (512,49,49) 
       initializer=tf.random_normal_initializer(0., 0.005)) 
     output = tf.batch_matmul(input_transpose, W) 

    output_transpose = tf.transpose(output, [1,2,0]) 
    output_reshape = tf.reshape(output_transpose, [-1, height, width, n_feat_map]) 

    return output_reshape 

https://github.com/jazzsaxmafia/Inpainting/blob/8c7735ec85393e0a1d40f05c11fa1686f9bd530f/src/model.py#L60

主なアイデアは、tf.batch_matmul機能を使用しています。

ただし、tf.batch_matmulはTensorflowの最新バージョンで削除されていますが、tf.matmulを使用して置き換えることができます。

+0

このリンクは質問に答えるかもしれませんが、回答の重要な部分をここに含めて参考にしてください。リンクされたページが変更された場合、リンクのみの回答は無効になります。 - [レビューより](/レビュー/低品質の投稿/ 18677744) – aerokite

+0

@aerokiteありがとう、私はすでに私の答えを変更しました。 – Haibo

+0

うわー!ありがとう!私は彼らがオンラインで実装されているとは信じられませんし、テキストでも言及しません!また、これは私がやっていたこととはまったく離れています。私と同じ問題があるのだろうかと思います。 –

関連する問題