2017-05-18 61 views
5

イメージを入力として取得し、出力として新しいイメージを生成するオートエンコーダーがあります。extract_image_patchesを使用した後の画像の再構築

入力イメージ(1x1024x1024x3)は、ネットワークに供給される前にパッチ(1024x32x32x3)に分割されます。

出力があると、パッチのサイズが1024x32x32x3になっても、1024x1024x3イメージを再構成できるようにしたいと考えています。私はこれを単純に形を変えることで悩まされたと思ったが、ここで起こったことがある。

まず、Tensorflowで読み取るよう画像:私はここでは、次のコード

patch_size = [1, 32, 32, 1] 
patches = tf.extract_image_patches([image], 
    patch_size, patch_size, [1, 1, 1, 1], 'VALID') 
patches = tf.reshape(patches, [1024, 32, 32, 3]) 

で画像をパッチし Input image

はこのイメージからパッチのカップルです:

Patched input #168Patched input #169

しかし、私はこのパッチデータを再構築して物事は洋ナシ形になります。この例では

reconstructed = tf.reshape(patches, [1, 1024, 1024, 3]) 
converted = tf.image.convert_image_dtype(reconstructed, tf.uint8) 
encoded = tf.image.encode_png(converted) 

Reconstructed output

、何も処理は、パッチと再構成との間で行われていません。この動作をテストするために使用できるversion of the codeを作成しました。これを使用するには、次のコマンドを実行します。

echo "/path/to/test-image.png" > inputs.txt 
mkdir images 
python3 image_test.py inputs.txt images 

コードは、各入力画像1024個のパッチごとに1枚の入力画像、1枚のパッチの画像、および1つの出力画像を作成するので、入力を作成する行をコメントアウトしますすべてのパッチの保存のみに関心がある場合は、画像を出力します。 (TF 1.0)

サイズの画像(4,4,1)を考慮に変換:あなたのタスクのための一つの小さな例 -

誰かが:(

+0

CJxDは、私の答えを見ています。パッチを印刷し、第1部で再構築して、コードで起こっている間違いを見ることができます。私の答えがあなたに役立つことを願っています – hars

+0

「ものが梨地になる」。私はその表現を借りるつもりです! –

答えて

3

使用更新#2は何が起こったかを説明してくださいサイズ(4,2,2,1)のパッチは、イメージに戻ってそれらを再構築

import tensorflow as tf 
image = tf.constant([[[1], [2], [3], [4]], 
       [[5], [6], [7], [8]], 
       [[9], [10], [11], [12]], 
       [[13], [14], [15], [16]]]) 

patch_size = [1,2,2,1] 
patches = tf.extract_image_patches([image], 
    patch_size, patch_size, [1, 1, 1, 1], 'VALID') 
patches = tf.reshape(patches, [4, 2, 2, 1]) 
reconstructed = tf.reshape(patches, [1, 4, 4, 1]) 
rec_new = tf.space_to_depth(reconstructed,2) 
rec_new = tf.reshape(rec_new,[4,4,1]) 

sess = tf.Session() 
I,P,R_n = sess.run([image,patches,rec_new]) 
print(I) 
print(I.shape) 
print(P.shape) 
print(R_n) 
print(R_n.shape) 

出力:

[[[ 1][ 2][ 3][ 4]] 
    [[ 5][ 6][ 7][ 8]] 
    [[ 9][10][11][12]] 
    [[13][14][15][16]]] 
(4, 4, 1) 
(4, 2, 2, 1) 
[[[ 1][ 2][ 3][ 4]] 
    [[ 5][ 6][ 7][ 8]] 
    [[ 9][10][11][12]] 
    [[13][14][15][16]]] 
(4,4,1) 
-

更新P = SQRT(H)

import tensorflow as tf 
import numpy as np 
c = 3 
h = 1024 
p = 32 

image = tf.random_normal([h,h,c]) 
patch_size = [1,p,p,1] 
patches = tf.extract_image_patches([image], 
    patch_size, patch_size, [1, 1, 1, 1], 'VALID') 
patches = tf.reshape(patches, [h, p, p, c]) 
reconstructed = tf.reshape(patches, [1, h, h, c]) 
rec_new = tf.space_to_depth(reconstructed,p) 
rec_new = tf.reshape(rec_new,[h,h,c]) 

sess = tf.Session() 
I,P,R_n = sess.run([image,patches,rec_new]) 
print(I.shape) 
print(P.shape) 
print(R_n.shape) 
err = np.sum((R_n-I)**2) 
print(err) 

出力のみ働い3つのチャネルのための(デバッグ..)

(1024, 1024, 3) 
(1024, 32, 32, 3) 
(1024, 1024, 3) 
0.0 

アップデート2から再構築

extract_image_patchesの出力が難しいようです。他の関数を使ってパッチを抽出し、プロセスを逆にして再構築して、より簡単に見えるようにします。

import tensorflow as tf 
import numpy as np 
c = 3 
h = 1024 
p = 128 


image = tf.random_normal([1,h,h,c]) 

# Image to Patches Conversion 
pad = [[0,0],[0,0]] 
patches = tf.space_to_batch_nd(image,[p,p],pad) 
patches = tf.split(patches,p*p,0) 
patches = tf.stack(patches,3) 
patches = tf.reshape(patches,[(h/p)**2,p,p,c]) 

# Do processing on patches 
# Using patches here to reconstruct 
patches_proc = tf.reshape(patches,[1,h/p,h/p,p*p,c]) 
patches_proc = tf.split(patches_proc,p*p,3) 
patches_proc = tf.stack(patches_proc,axis=0) 
patches_proc = tf.reshape(patches_proc,[p*p,h/p,h/p,c]) 

reconstructed = tf.batch_to_space_nd(patches_proc,[p, p],pad) 

sess = tf.Session() 
I,P,R_n = sess.run([image,patches,reconstructed]) 
print(I.shape) 
print(P.shape) 
print(R_n.shape) 
err = np.sum((R_n-I)**2) 
print(err) 

出力:

(1, 1024, 1024, 3) 
(64, 128, 128, 3) 
(1, 1024, 1024, 3) 
0.0 

あなたがここに他のクールなテンソル変換関数を見ることができる:それはバックグラウンドで多くのものがそうであるように、https://www.tensorflow.org/api_guides/python/array_ops

+0

異なるパッチサイズを設定すると、2番目の例がクラッシュします。例えばp = 64例えば – Temak

+0

パッチを作り直しながら "h"を調整する必要があります。今更新する予定です。 – hars

+0

'tf.reshape(patches、[(h/p)** 2、p、p、c])' – Temak

2

tf.extract_image_patchesは使いにくい静かです。

ノンオーバーラップするだけであれば、それを自分自身で書く方がはるかに簡単です。 image_to_patchesですべての操作を反転すると、フルイメージを再構築できます。

コードサンプル(プロット原画像とパッチ):

import tensorflow as tf 
from skimage import io 
import matplotlib.pyplot as plt 


def image_to_patches(image, patch_height, patch_width): 
    # resize image so that it's dimensions are dividable by patch_height and patch_width 
    image_height = tf.cast(tf.shape(image)[0], dtype=tf.float32) 
    image_width = tf.cast(tf.shape(image)[1], dtype=tf.float32) 
    height = tf.cast(tf.ceil(image_height/patch_height) * patch_height, dtype=tf.int32) 
    width = tf.cast(tf.ceil(image_width/patch_width) * patch_width, dtype=tf.int32) 

    num_rows = height // patch_height 
    num_cols = width // patch_width 
    # make zero-padding 
    image = tf.squeeze(tf.image.resize_image_with_crop_or_pad(image, height, width)) 

    # get slices along the 0-th axis 
    image = tf.reshape(image, [num_rows, patch_height, width, -1]) 
    # h/patch_h, w, patch_h, c 
    image = tf.transpose(image, [0, 2, 1, 3]) 
    # get slices along the 1-st axis 
    # h/patch_h, w/patch_w, patch_w,patch_h, c 
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1]) 
    # num_patches, patch_w, patch_h, c 
    image = tf.reshape(image, [num_rows * num_cols, patch_width, patch_height, -1]) 
    # num_patches, patch_h, patch_w, c 
    return tf.transpose(image, [0, 2, 1, 3]) 


image = io.imread('http://www.petful.com/wp-content/uploads/2011/09/slow-blinking-cat.jpg') 
print('Original image shape:', image.shape) 
tile_size = 200 
image = tf.constant(image) 
tiles = image_to_patches(image, tile_size, tile_size) 

sess = tf.Session() 
I, tiles = sess.run([image, tiles]) 
print(I.shape) 
print(tiles.shape) 


plt.figure(figsize=(1 * (4 + 1), 5)) 
plt.subplot(5, 1, 1) 
plt.imshow(I) 
plt.title('original') 
plt.axis('off') 
for i, tile in enumerate(tiles): 
    plt.subplot(5, 5, 5 + 1 + i) 
    plt.imshow(tile) 
    plt.title(str(i)) 
    plt.axis('off') 
plt.show() 
関連する問題