2017-02-03 18 views
1

私はテンソルを持って言う:テンソルフローの4次元テンソルでゼロを省略する方法は?

import tensorflow as tf 
t = tf.Variable([[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.], [77., 0., 0., 12., 0., 0., 33., 55., 0.]], 
       [[0., 132., 0., 0., 234., 0., 1., 24., 0.], [43., 0., 0., 124., 0., 0., 0., 52., 645]]]]) 

Iがゼロを省略し、形状のテンソルが残されたい:(1、2、2、4)、4は私の中の非ゼロ要素の数であるとテンソルのように

t = tf.Variable([[[[235., 1006., 23., 42], [77., 12., 33., 55.]], 
       [[132., 234., 1., 24.], [43., 124., 52., 645]]]]) 

私はブーリアンマスクを使って1次元テンソルでこれを行っています。どのようにして4次元テンソルのゼロを省略することができますか?それはより高いランクのために一般化することはできますか? TensorFlow 0.12.1を使用して

答えて

2

import tensorflow as tf 

def batch_of_vectors_nonzero_entries(batch_of_vectors): 
    """Removes non-zero entries from batched vectors. 

    Requires that each vector have the same number of non-zero entries. 

    Args: 
    batch_of_vectors: A Tensor with length-N vectors, having shape [..., N]. 
    Returns: 
    A Tensor with shape [..., M] where M is the number of non-zero entries in 
    each vector. 
    """ 
    nonzero_indices = tf.where(tf.not_equal(
     batch_of_vectors, tf.zeros_like(batch_of_vectors))) 
    # gather_nd gives us a vector containing the non-zero entries of the 
    # original Tensor 
    nonzero_values = tf.gather_nd(batch_of_vectors, nonzero_indices) 
    # Next, reshape so that all but the last dimension is the same as the input 
    # Tensor. Note that this will fail unless each vector has the same number of 
    # non-zero values. 
    reshaped_nonzero_values = tf.reshape(
     nonzero_values, 
     tf.concat(0, [tf.shape(batch_of_vectors)[:-1], [-1]])) 
    return reshaped_nonzero_values 

t = tf.Variable(
    [[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.], 
     [77., 0., 0., 12., 0., 0., 33., 55., 0.]], 
     [[0., 132., 0., 0., 234., 0., 1., 24., 0.], 
     [43., 0., 0., 124., 0., 0., 0., 52., 645]]]]) 
nonzero_t = batch_of_vectors_nonzero_entries(t) 

with tf.Session(): 
    tf.global_variables_initializer().run() 
    result_evaled = nonzero_t.eval() 
    print(result_evaled.shape, result_evaled) 

プリント:

(1, 2, 2, 4) [[[[ 2.35000000e+02 1.00600000e+03 2.30000000e+01 4.20000000e+01] 
    [ 7.70000000e+01 1.20000000e+01 3.30000000e+01 5.50000000e+01]] 

    [[ 1.32000000e+02 2.34000000e+02 1.00000000e+00 2.40000000e+01] 
    [ 4.30000000e+01 1.24000000e+02 5.20000000e+01 6.45000000e+02]]]] 

結果が今までぼろぼろされて終わる場合SparseTensorのに見することが有用です。

関連する問題