2016-12-10 11 views
0

畳み込みレイヤの変数を取得して視覚化したいと考えています。 そして、私のコードはテンソルフロー内の畳み込みレイヤの変数を取得したときにエラーが発生しました

d3 = de_conv(d2, weights2['wc2'], biases2['bc2'], out_shape=[batch_size , c2, c2, 128]) 
    d3 = batch_norm(d3, epsilon=1e-5, decay=0.9) 
    d3 = tf.nn.relu(d3) 

    tf.add_to_collection('weight_2', weights2['wc3']) 

とテストに

with tf.Session() as sess: 

      saver.restore(sess , model_path) 
      conv_weights = sess.run([tf.get_collection('weight_2')]) 
      #visualize the weights 
      conv_weights = np.array(conv_weights) 

      print(conv_weights.shape) 
      vis_square(conv_weights) 

あるしかし、私はconv_weightsがコンバージョン層のための

(1, 1, 5, 5, 1, 128) 

答えて

0

重量混乱寸法を有して理解していない[filter height, filter width, input channels, number of filters (output channels]でなければなりません。最初の2つの次元を除いて、あなたの体重は適合します。それはちょうど2つのリストに包まれていますか?例えば。 weightsの代わりに[[weights]]となります。

関連する問題