TensorFlowで、以下のスライシング(NumPyで示したもの)を実行する方法はありますか?
テンソルをリストでスライスする - TensorFlow

# Reference behaviour in NumPy: fancy indexing with a list of indices on the
# last axis keeps the rank and selects channels 0 and 5.
z = np.random.random((3, 7, 7, 12))
x = z[..., [0, 5]]

# The same selection built by hand: one single-channel slice per index
# (a one-element index list keeps the axis), joined along the last axis.
x_hat = np.concatenate([z[..., [0]], z[..., [5]]], 3)
# array_equal also checks shapes, unlike np.all(x == x_hat), which could
# pass through broadcasting.
assert np.array_equal(x, x_hat)
# A bare `x.shape` expression is only displayed in a REPL; assert it so the
# expected shape is actually checked when run as a script.
assert x.shape == (3, 7, 7, 2)


tfz = tf.constant(z)
# The original was missing the comma before `dtype=`, a SyntaxError.
i = np.array([0, 5], dtype=np.int32)
# NOTE: in the TensorFlow version this question targets, Tensor indexing does
# not accept a NumPy-style index array here; this line is the one that
# raises the "Shapes must be equal rank" ValueError quoted below.
tfx = tfz[..., i]


ValueError: Shapes must be equal rank, but are 0 and 1 
From merging shape 0 with other shapes. for 'strided_slice/stack_1' (op: 'Pack') with input shapes: [], [2]. 




def list_slice(tensor, indices, axis):
    """Select the entries at ``indices`` along ``axis``, keeping the rank.

    TensorFlow counterpart of NumPy fancy indexing such as ``z[..., [0, 5]]``.
    Each index is taken with a basic ``i:i+1`` slice, which keeps the sliced
    axis with length 1. The pieces are then concatenated back along ``axis``.

    Args:
        tensor (Tensor): input tensor to slice. Its rank must be statically
            known, since it is read via ``get_shape().as_list()``.
        indices ([int]): indices along ``axis`` to keep, in output order.
            Negative indices count from the end, as in Python.
        axis (int): the axis to perform the slice on.

    Returns:
        Tensor: same rank as ``tensor``, with ``len(indices)`` entries along
        ``axis``.

    Raises:
        ValueError: if ``indices`` is empty.
    """
    # Materialise so tuples, generators and NumPy arrays all work, and so the
    # emptiness check below is unambiguous.
    indices = list(indices)
    if not indices:
        raise ValueError("indices must contain at least one index")

    nd = len(tensor.get_shape().as_list())

    slices = []
    for i in indices:
        # Full slice on every axis except `axis`. There we take a length-1
        # range instead of the integer i, so the axis is kept. That means no
        # reshape is needed, and unknown (None) dimensions stay untouched.
        spec = [slice(None)] * nd
        # `(i + 1) or None`: for i == -1 the stop would be 0, giving an empty
        # slice; None means "to the end" instead.
        spec[axis] = slice(i, (i + 1) or None)
        # Index with a tuple. A list of slices is ambiguous with fancy
        # indexing (and deprecated for that reason in NumPy).
        slices.append(tensor[tuple(spec)])

    return tf.concat(slices, axis=axis)

# Check list_slice against NumPy fancy indexing on the last axis.
z = np.random.random(size=(3, 7, 7, 12))
x = z[..., [0, 5]]
tfz = tf.constant(z)
tfx_hat = list_slice(tfz, [0, 5], axis=3)

# Tensor.eval() only works when a default session is registered. Run the
# tensor in an explicit session so the check works as a standalone script.
with tf.Session() as sess:
    x_hat = sess.run(tfx_hat)

# array_equal also compares shapes; np.all(x == x_hat) could pass through
# broadcasting even if list_slice returned the wrong shape.
assert np.array_equal(x, x_hat)

この一般化はいいですね。 – greeness



z = np.arange(36)
tfz = tf.reshape(tf.constant(z), [2, 3, 2, 3])
# Integer indexing drops the last axis, so each single-channel slice is
# reshaped back to rank 4 (with a trailing length-1 axis) before being joined.
slice1 = tf.reshape(tfz[:, :, :, 1], [2, 3, -1, 1])
slice2 = tf.reshape(tfz[:, :, :, 2], [2, 3, -1, 1])
# Named `sliced`, not `slice`: a module-level `slice` would shadow the
# builtin, which other slicing code in the same module may rely on.
sliced = tf.concat([slice1, slice2], axis=3)

with tf.Session() as sess:
    # print() call syntax runs under both Python 2 and Python 3.
    print(sess.run([tfz, sliced]))

> [[[[ 0, 1, 2], 
    [ 3, 4, 5]], 

    [[ 6, 7, 8], 
    [ 9, 10, 11]], 

    [[12, 13, 14], 
    [15, 16, 17]]], 

    [[[18, 19, 20], 
    [21, 22, 23]], 

    [[24, 25, 26], 
    [27, 28, 29]], 

    [[30, 31, 32], 
    [33, 34, 35]]]] 

    # Get the last two columns 
> [[[[ 1, 2], 
    [ 4, 5]], 

    [[ 7, 8], 
    [10, 11]], 

    [[13, 14], 
    [16, 17]]], 

    [[[19, 20], 
    [22, 23]], 

    [[25, 26], 
    [28, 29]], 

    [[31, 32], 
    [34, 35]]]]