2017-05-12 6 views
1

私は、形状が[batch_size, sentence_length, word_dim]のプレースホルダテンソルとshape=[batch_size, num_indices]のインデックスのリストを持っています。インデックスは第2の軸上にあり、文中の単語のインデックスです。 Batch_size & sentence_lengthは、実行時にのみ認識されます。Tensorflow:2番目の軸に沿ってインデックスのリストを含む3Dテンソルをスライスします。

[batch_size, len(indices), word_dim]の形状のテンソルはどのように抽出しますか?

私は約tensorflow.gatherを読んでいましたが、最初の軸に沿ってスライスだけを集めるようです。私は正しいですか?

編集:私はそれが一定

def tile_repeat(n, repTime): 
    ''' 
    create something like 111..122..2333..33 ..... n..nn 
    one particular number appears repTime consecutively. 
    This is for flattening the indices. 
    ''' 
    print n, repTime 
    idx = tf.range(n) 
    idx = tf.reshape(idx, [-1, 1]) # Convert to a n x 1 matrix. 
    idx = tf.tile(idx, [1, int(repTime)]) # Create multiple columns, each column has one number repeats repTime 
    y = tf.reshape(idx, [-1]) 
    return y 

def gather_along_second_axis(x, idx): 
    ''' 
    x has shape: [batch_size, sentence_length, word_dim] 
    idx has shape: [batch_size, num_indices] 
    Basically, in each batch, get words from sentence having index specified in idx 
    However, since tensorflow does not fully support indexing, 
    gather only work for the first axis. We have to reshape the input data, gather then reshape again 
    ''' 
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices] 
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx 
    y = tf.gather(tf.reshape(x, [-1,int(tf.shape(x)[2])]), # flatten input 
       idx_flattened) 
    y = tf.reshape(y, tf.shape(x)) 
    return y 

x = tf.constant([ 
      [[1,2,3],[3,5,6]], 
      [[7,8,9],[10,11,12]], 
      [[13,14,15],[16,17,18]] 
    ]) 
idx=tf.constant([[0,1],[1,0],[1,1]]) 

y = gather_along_second_axis(x, idx) 
with tf.Session(''): 
    print y.eval() 
    print tf.Tensor.get_shape(y) 

と協力を得るために管理し、出力は次のとおりです。

[[[ 1 2 3] 
    [ 3 5 6]] 
[[10 11 12] 
    [ 7 8 9]] 
[[16 17 18] 
    [16 17 18]]] 

形状:(3, 2, 3)

しかし、入力がプレースホルダているとき、それは動作しません。返品エラー:

idx = tf.tile(idx, [1, int(repTime)]) 
TypeError: int() argument must be a string or a number, not 'Tensor' 

Python 2.7、テンソルフロー0.12

ありがとうございます。 AllenLavoieのコメント@に

+0

。それが終わると 'tf.gather'はうまくいくでしょう。 –

+0

こんにちは@AllenLavoie:ありがとうございます。私はまた、ここからそのアイディアを得ました[https://github.com/tensorflow/tensorflow/issues/206]。しかし、私はそれをプレースホルダー入力のために働かせることはできません。編集した質問をすばやくご覧いただけますか? –

+0

'tf.shape'を使って次元を整数のTensorとして得ることができます。これは、静的形状情報が利用できない場合でも機能します。 –

答えて

1

感謝、私は最終的に解決策を考え出すことができます:私はプレースホルダの最初の2次元の平坦化し、平坦化の次元でインデックスを計算したい

def tile_repeat(n, repTime): 
    ''' 
    create something like 111..122..2333..33 ..... n..nn 
    one particular number appears repTime consecutively. 
    This is for flattening the indices. 
    ''' 
    print n, repTime 
    idx = tf.range(n) 
    idx = tf.reshape(idx, [-1, 1]) # Convert to a n x 1 matrix. 
    idx = tf.tile(idx, [1, repTime]) # Create multiple columns, each column has one number repeats repTime 
    y = tf.reshape(idx, [-1]) 
    return y 

def gather_along_second_axis(x, idx): 
    ''' 
    x has shape: [batch_size, sentence_length, word_dim] 
    idx has shape: [batch_size, num_indices] 
    Basically, in each batch, get words from sentence having index specified in idx 
    However, since tensorflow does not fully support indexing, 
    gather only work for the first axis. We have to reshape the input data, gather then reshape again 
    ''' 
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices] 
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx 
    y = tf.gather(tf.reshape(x, [-1,tf.shape(x)[2]]), # flatten input 
       idx_flattened) 
    y = tf.reshape(y, tf.shape(x)) 
    return y 

x = tf.constant([ 
      [[1,2,3],[3,5,6]], 
      [[7,8,9],[10,11,12]], 
      [[13,14,15],[16,17,18]] 
    ]) 
idx=tf.constant([[0,1],[1,0],[1,1]]) 

y = gather_along_second_axis(x, idx) 
with tf.Session(''): 
    print y.eval() 
    print tf.Tensor.get_shape(y) 
関連する問題