2016-11-03 30 views
2

私は入力データの3Dボリュームを持っており、各スライスにConvNetを適用したいと考えています。Tensorflow:データをスライスして各スライスにコンボリューションを適用する

この質問は重複しています。悲しいことに答えなし: How to slice a batch and apply an operation on each slice in TensorFlow

私はTimeDistributed層を使用します。 Tensorflowでは、私は直接同等のものを見つけることができませんでした。 代わりに、データを自分でスライスする必要があることがわかります。 - > ConvNet - >concat

x=tf.placeholder(tf.float32, shape=[None, 40, 40, 40, 1]) 
slices=tf.split(0,40, x) 

segmented_slices=[] 

for slice in slices: 
    # apply a ConvNet to each slice 
    reshaped=tf.reshape(slice, (40, 40, 1))   #<-------second error 
    # reshaped=tf.reshape(slice, (None, 40, 40, 1)) #<-------third error 

    # segmented_slice=conv2d(slice, 3,1,32)   #<-------first error 
    segmented_slice=conv2d(reshaped, 3,1,32) 
    segmented_slice=conv2d(segmented_slice, 3,32,32) 

    #... (more convolutions) 

    segmented_slices.append(segmented_slice) 

volume=tf.concat(0, segmented_slices) 

基本的なレイアウトはsplitです:

これは、これまでのところ、私のコードです。 しかし、splitは次元を保持します。私は単に畳み込みにsliceを渡すと、それは文句:

ValueError: Shape (?, 40, 40, 40, 1) must have rank 4 

をので、私はリシェイプを追加しました。これは確かに次元の数を減らします。しかし、それは明らかにbatch_sizeも削減しています。最初のエラーメッセージと比較して、疑問符と最初の40文字は両方とも消えています。

ValueError: Shape (40, 40, 1) must have rank 4 

私は再構成でbatch_sizeを維持する必要があるようです。タプルにNoneを追加しようとしました。これはさらに別のエラーメッセージを生成します:

TypeError: Expected int32, got None of type '_Message' instead. 

これは正しい方法ですか? 私はこれを自分で処理するつもりですか?

+0

コードの先頭にあるtf.splitは、分割するバッチディメンションを指定していますが、サイズ40ディメンションのいずれかで分割するように見えます。 tf.split(1,40、x)]でスライスされたスライス= [tf.squeeze(スライスされた、squeeze_dims = [1])のようなものでしょうか?それは私に40(?、40、40、1)テンソルのリストを与えます。さて、あなたは何かを形作ることができますが、 'tf.shape'でテンソルとして形を得ることで、未知の次元を保持します。 –

答えて

2

Keras TimeDistributedレイヤのは、それが実装されている方法を見てみましょう、何が必要である場合:

input_length = input_shape[1] # assume 2nd dim is the one to slice 
# ... 
# Shape: (num_samples * timesteps, ...) 
inputs = K.reshape(inputs, (-1,) + input_shape[2:]) 
y = self.layer.call(inputs) # (num_samples * timesteps, ...) 
# Shape: (num_samples, timesteps, ...) 
output_shape = self.compute_output_shape(input_shape) 
y = K.reshape(y, (-1, input_length) + output_shape[2:]) 

基本的な考え方は、第1および第2の寸法(バッチおよびスライシングようにテンソルを再形成することです次元)は1つに崩壊します。言い換えれば、各「スライス」はバッチ内の追加のデータポイントと考えることができます。この新しい仮想バッチに何らかの計算を適用し、最後に元の形状に戻し直します。 これらの操作はすべてTensorflowで簡単に実装できます。

関連する問題