2017-10-23 29 views
1

Tensorflowでこのスライシング方法を実行する方法はありますか(numpyを使用して表示します)?テンソルをリストでスライスする - TensorFlow

z = np.random.random((3,7,7,12)) 
x = z[...,[0,5]] 

なTensorflowで

x_hat = np.concatenate([z[...,[0]], z[...,[5]]], 3) 
assert np.all(x == x_hat) 
x.shape # (3, 7, 7, 2) 

ことを、この操作

tfz = tf.constant(z) 
i = np.array([0,5] dtype=np.int32) 
tfx = tfz[...,i] 

はあなたが持つ連結の結果の一貫性を保つためにリシェイプが必要

ValueError: Shapes must be equal rank, but are 0 and 1 
From merging shape 0 with other shapes. for 'strided_slice/stack_1' (op: 'Pack') with input shapes: [], [2]. 

答えて

1

緑色のような形状誤差です。残念ながら、私が望んだような簡単な方法はないようですが、これは私が思いついた一般化された解決策です:

def list_slice(tensor, indices, axis): 
    """ 
    Args 
    ---- 
    tensor (Tensor) : input tensor to slice 
    indices ([int]) : list of indices of where to perform slices 
    axis (int) : the axis to perform the slice on 
    """ 

    slices = [] 

    ## Set the shape of the output tensor. 
    # Set any unknown dimensions to -1, so that reshape can infer it correctly. 
    # Set the dimension in the slice direction to be 1, so that overall dimensions are preserved during the operation 
    shape = tensor.get_shape().as_list() 
    shape[shape==None] = -1 
    shape[axis] = 1 

    nd = len(shape) 

    for i in indices: 
     _slice = [slice(None)]*nd 
     _slice[axis] = slice(i,i+1) 
     slices.append(tf.reshape(tensor[_slice], shape)) 

    return tf.concat(slices, axis=axis) 



z = np.random.random(size=(3, 7, 7, 12)) 
x = z[...,[0,5]] 
tfz = tf.constant(z) 
tfx_hat = list_slice(tfz, [0, 5], axis=3) 
x_hat = tfx_hat.eval() 

assert np.all(x == x_hat) 
+0

私はあなたの一般化が好きです。 – greeness

1

エラーをスローします元の形状(th最初の3次元)。

z = np.arange(36) 
tfz = tf.reshape(tf.constant(z), [2, 3, 2, 3]) 
slice1 = tf.reshape(tfz[:,:,:,1], [2, 3, -1, 1]) 
slice2 = tf.reshape(tfz[:,:,:,2], [2, 3, -1, 1]) 
slice = tf.concat([slice1, slice2], axis=3) 

with tf.Session() as sess: 
    print sess.run([tfz, slice]) 


> [[[[ 0, 1, 2], 
    [ 3, 4, 5]], 

    [[ 6, 7, 8], 
    [ 9, 10, 11]], 

    [[12, 13, 14], 
    [15, 16, 17]]], 

    [[[18, 19, 20], 
    [21, 22, 23]], 

    [[24, 25, 26], 
    [27, 28, 29]], 

    [[30, 31, 32], 
    [33, 34, 35]]]] 

    # Get the last two columns 
> [[[[ 1, 2], 
    [ 4, 5]], 

    [[ 7, 8], 
    [10, 11]], 

    [[13, 14], 
    [16, 17]]], 

    [[[19, 20], 
    [22, 23]], 

    [[25, 26], 
    [28, 29]], 

    [[31, 32], 
    [34, 35]]]]