2017-05-12

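Tensorflow: Slice a 3D tensor with a list of indices along the second axis

I have a placeholder tensor of shape [batch_size, sentence_length, word_dim] and a list of indices of shape [batch_size, num_indices]. The indices refer to the second axis: they are the positions of words in the sentence. batch_size and sentence_length are only known at runtime.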

How do I extract a tensor of shape [batch_size, len(indices), word_dim]?
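The intended semantics can be pinned down with a plain-Python reference. This is a sketch in which nested lists stand in for tensors, and the function name `gather_along_second_axis_ref` and the example values are hypothetical. Each output row `out[b][j]` is simply `x[b][idx[b][j]]`:

```python
def gather_along_second_axis_ref(x, idx):
    """Plain-Python reference: out[b][j] = x[b][idx[b][j]].

    x:   nested list of shape [batch_size, sentence_length, word_dim]
    idx: nested list of shape [batch_size, num_indices]
    returns a nested list of shape [batch_size, num_indices, word_dim]
    """
    # For every sentence, pick the word vectors at the requested positions.
    return [[sentence[i] for i in ids] for sentence, ids in zip(x, idx)]


# Hypothetical example: batch_size=2, sentence_length=3, word_dim=2, num_indices=2
x = [[[1, 2], [3, 4], [5, 6]],
     [[7, 8], [9, 10], [11, 12]]]
idx = [[2, 0],
       [1, 1]]
print(gather_along_second_axis_ref(x, idx))
# [[[5, 6], [1, 2]], [[9, 10], [9, 10]]]
```

Any TensorFlow version of this can be checked against this reference on small inputs.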



import tensorflow as tf

def tile_repeat(n, repTime):
    '''
    Create something like 111..122..2333..33 ..... n..nn:
    each number appears repTime times consecutively.
    This is for flattening the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])  # Convert to an n x 1 matrix.
    idx = tf.tile(idx, [1, int(repTime)])  # Create multiple columns; each row repeats one number repTime times
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    '''
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    Basically, in each batch, get the words from the sentence at the indices specified in idx.
    However, since TensorFlow does not fully support indexing,
    gather only works on the first axis. We have to reshape the input data, gather, then reshape again.
    '''
    reshapedIdx = tf.reshape(idx, [-1])  # [batch_size*num_indices]
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1, int(tf.shape(x)[2])]),  # flatten input
                  idx_flattened)
    y = tf.reshape(y, tf.shape(x))
    return y

x = tf.constant([ 

y = gather_along_second_axis(x, idx) 
with tf.Session(''): 
    print y.eval() 
    print tf.Tensor.get_shape(y) 


[[[ 1  2  3]
  [ 3  5  6]]

 [[10 11 12]
  [ 7  8  9]]

 [[16 17 18]
  [16 17 18]]]

Shape: (3, 2, 3)
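For reference, the pattern that the tile_repeat helper builds can be reproduced in plain Python. This is a sketch, and the name `tile_repeat_ref` is hypothetical. Since `tf.range(n)` starts at 0, the sequence starts with zeros. Reshaping `tf.range(n)` to n x 1, tiling it to n x repTime and flattening row by row yields the same order:

```python
def tile_repeat_ref(n, rep_time):
    """Plain-Python analogue of tile_repeat: each of 0..n-1 repeated rep_time times in a row."""
    return [i for i in range(n) for _ in range(rep_time)]


print(tile_repeat_ref(3, 2))  # [0, 0, 1, 1, 2, 2]
```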


idx = tf.tile(idx, [1, int(repTime)]) 
TypeError: int() argument must be a string or a number, not 'Tensor' 

Python 2.7, TensorFlow 0.12

Thank you. Thanks to @AllenLavoie's comment:


... Once that is done, 'tf.gather' will work.
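The reshape-then-gather idea can be written out as index arithmetic. The symbols B, S, N and D stand for batch_size, sentence_length, num_indices and word_dim and are introduced here only for illustration. Reshaping x to [B*S, D] puts word i of sentence b in row b*S + i. Entry k = b*N + j of the flattened idx must then be mapped to that row:

```latex
\begin{aligned}
x'_{bS+i,\,:} &= x_{b,\,i,\,:}, && 0 \le b < B,\ 0 \le i < S \\
k &= bN + j, && 0 \le j < N \\
\mathrm{flat}_k &= \left\lfloor k / N \right\rfloor S + \mathrm{idx}_{b,\,j} = bS + \mathrm{idx}_{b,\,j} \\
y_{b,\,j,\,:} &= x'_{\mathrm{flat}_k,\,:} = x_{b,\,\mathrm{idx}_{b,j},\,:}
\end{aligned}
```

Here floor(k / N) is the batch number that tile_repeat(B, N) produces at position k. The repeat count is therefore num_indices and the multiplier is sentence_length, and the two coincide only when N = S.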


Hi @AllenLavoie, thanks. I also got that idea from here [https://github.com/tensorflow/tensorflow/issues/206]. However, I can't make it work for placeholder input. Could you take a quick look at my edited question?


You can use 'tf.shape' to get the dimensions as an integer Tensor. This works even when static shape information is not available.




import tensorflow as tf

def tile_repeat(n, repTime):
    '''
    Create something like 00..011..122..2 ... (n-1)..(n-1):
    each number in range(n) appears repTime times consecutively.
    This is used for flattening the indices.
    '''
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])  # Convert to an n x 1 matrix.
    idx = tf.tile(idx, [1, repTime])  # n x repTime: row i holds repTime copies of i
    y = tf.reshape(idx, [-1])  # Flatten row by row.
    return y

def gather_along_second_axis(x, idx):
    '''
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    In each batch, take the words of the sentence at the positions given in idx.
    tf.gather only works along the first axis, so reshape the input,
    gather, then reshape again.
    '''
    # Dynamic shapes, so this also works for placeholders with unknown dimensions.
    batch_size = tf.shape(x)[0]
    sentence_length = tf.shape(x)[1]
    word_dim = tf.shape(x)[2]
    num_indices = tf.shape(idx)[1]
    reshapedIdx = tf.reshape(idx, [-1])  # [batch_size*num_indices]
    # Row of the flattened x that holds word idx[b, j] of sentence b: b*sentence_length + idx[b, j]
    idx_flattened = tile_repeat(batch_size, num_indices) * sentence_length + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1, word_dim]),  # flatten input to [batch_size*sentence_length, word_dim]
                  idx_flattened)
    y = tf.reshape(y, [batch_size, num_indices, word_dim])
    return y

x = tf.constant([ 

y = gather_along_second_axis(x, idx)
with tf.Session():
    print(y.eval())
    print(y.get_shape())
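The index arithmetic can be sanity-checked without TensorFlow by emulating the same flatten, gather and reshape steps on nested lists. This is a plain-Python sketch, and the function name `gather_along_second_axis_flat` and the example values are hypothetical. The example deliberately uses num_indices different from sentence_length. In that case the repeat count has to come from idx rather than from x:

```python
def gather_along_second_axis_flat(x, idx):
    """Plain-Python emulation of the flatten -> gather -> reshape approach.

    x:   nested list of shape [batch_size, sentence_length, word_dim]
    idx: nested list of shape [batch_size, num_indices]
    """
    batch_size = len(x)
    sentence_length = len(x[0])
    num_indices = len(idx[0])

    # Like tf.reshape(x, [-1, word_dim]): one row per word.
    x_flat = [word for sentence in x for word in sentence]
    # Like tf.reshape(idx, [-1]).
    idx_flat = [i for row in idx for i in row]
    # Like tile_repeat(batch_size, num_indices): batch number of each flattened index.
    batch_of = [b for b in range(batch_size) for _ in range(num_indices)]
    # Row of x_flat that holds word idx[b][j] of sentence b.
    flat = [b * sentence_length + i for b, i in zip(batch_of, idx_flat)]
    # Like tf.gather along the first axis.
    gathered = [x_flat[k] for k in flat]
    # Like reshaping back to [batch_size, num_indices, word_dim].
    return [gathered[b * num_indices:(b + 1) * num_indices] for b in range(batch_size)]


# Hypothetical example with num_indices (3) != sentence_length (2)
x = [[[1, 2], [3, 4]],
     [[5, 6], [7, 8]]]
idx = [[1, 0, 1],
       [0, 0, 1]]
print(gather_along_second_axis_flat(x, idx))
# [[[3, 4], [1, 2], [3, 4]], [[5, 6], [5, 6], [7, 8]]]
```

If tile_repeat were called with sentence_length instead of num_indices here, batch_of would have length 4 instead of 6 and rows would be assigned to the wrong sentences.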