0
畳み込みレイヤの変数を取得して視覚化したいと考えています。 そして、私のコードはテンソルフロー内の畳み込みレイヤの変数を取得したときにエラーが発生しました
d3 = de_conv(d2, weights2['wc2'], biases2['bc2'], out_shape=[batch_size , c2, c2, 128])
d3 = batch_norm(d3, epsilon=1e-5, decay=0.9)
d3 = tf.nn.relu(d3)
tf.add_to_collection('weight_2', weights2['wc3'])
とテストに
with tf.Session() as sess:
saver.restore(sess , model_path)
conv_weights = sess.run([tf.get_collection('weight_2')])
#visualize the weights
conv_weights = np.array(conv_weights)
print(conv_weights.shape)
vis_square(conv_weights)
あるしかし、私はconv_weightsがコンバージョン層のための
(1, 1, 5, 5, 1, 128)