2017-02-09 14 views
1

TensorFlowのHashTable検索実装を使用して、私のSparseTensorをデフォルト値で戻します。私はそれをきれいにして、デフォルト値なしで最終的なSparseTensorを取得したいと思います。Sparse Tensorからゼロ以外の値を取得する方法

このデフォルト値を消去するにはどうすればよいですか?私はこれを実現させるためにデフォルト値が何であるか気にしません。 0はうまく、-1も-1です。

答えて

0

tf.sparse_retainは動作するはずです:

import tensorflow as tf 

a = tf.SparseTensor(indices=[[1, 2], [2, 2]], values=[0., 1.], shape=[3, 3]) 
with tf.Session() as session: 
    print(session.run([a, sparse_remove(a)])) 

プリントが(私は少しそれを再フォーマットしました):

[SparseTensorValue(indices=array([[1, 2], [2, 2]]), values=array([ 0., 1.], dtype=float32), shape=array([3, 3])), 
SparseTensorValue(indices=array([[2, 2]]), values=array([ 1.], dtype=float32), shape=array([3, 3]))] 
例として

def sparse_remove(sparse_tensor, remove_value=0.): 
    return tf.sparse_retain(sparse_tensor, tf.not_equal(a.values, remove_value)) 

関連する問題