2017-04-05 14 views
0

タイトルが述べるように、私はtensorflowの行列から行ごとに最も高いn個の要素を抽出し、結果を疎なTensorに格納しようとしています。tf.nn.top_nの出力を疎行列に変換します

私はtf.nn.top_nでインデックスと値を抽出できましたが、インデックスはtf.SparseTensorで必要な規約に従いません。

具体的には、tf.nn.top_nは結果の値行列(行xn)と同じ形のcolインデックスの行列を返しますが、tf.SparseTensorは1行1列の(非ゼロx2)行列を求めます。 0でない要素と、行と列のインデックスを保持する列。

値は、値の行列の代わりに非ゼロ要素のリストが必要な同様の問題になります。

どのようにこれらのインデックス表記スキームを素早く変換できますか?

答えて

2

これは、モジュラ演算のビットで行うことができます。行列に作用する例がありますが、より多くの軸をループすることは可能です。

import tensorflow as tf 

def slices_to_dims(slice_indices): 
    """ 
    Args: 
    slice_indices: An [N, k] Tensor mapping to column indices. 
    Returns: 
    An index Tensor with shape [N * k, 2], corresponding to indices suitable for 
    passing to SparseTensor. 
    """ 
    slice_indices = tf.cast(slice_indices, tf.int64) 
    num_rows = tf.shape(slice_indices, out_type=tf.int64)[0] 
    row_range = tf.range(num_rows) 
    item_numbers = slice_indices * num_rows + tf.expand_dims(row_range, axis=1) 
    item_numbers_flat = tf.reshape(item_numbers, [-1]) 
    return tf.stack([item_numbers_flat % num_rows, 
        item_numbers_flat // num_rows], axis=1) 

使用例:

dense_shape = [5, 7] 
dense_matrix = tf.random_normal(shape=dense_shape) 
top_values, top_indices = tf.nn.top_k(dense_matrix, k=2) 
sparse_indices = slices_to_dims(top_indices) 
sparse_tensor = tf.sparse_reorder(tf.SparseTensor(
    indices=sparse_indices, 
    values=tf.reshape(top_values, [-1]), 
    dense_shape=dense_shape)) 
densified_top = tf.sparse_tensor_to_dense(sparse_tensor) 
with tf.Session() as session: 
    sparse_top, dense_original, dense_selected = session.run(
     [sparse_tensor, dense_matrix, densified_top]) 
    print(dense_original) 
    print(dense_selected) 
    print(sparse_top) 

プリント:魔法のように

[[ 1.44056129 -1.01790774 -0.2795608 2.34854746 -2.27528405 -0.62035948 
    3.36598897] 
[ 0.7114948 -0.42564821 -0.93446779 -0.25373486 -0.51730365 0.72331643 
    -0.75625718] 
[-0.6501748 -0.92748415 -0.95409006 -0.07157528 0.80637723 -0.32177576 
    -1.4516511 ] 
[-1.081038 -0.67226124 -1.19455576 0.44537872 -0.69019234 -0.61539739 
    0.15328468] 
[ 0.43032476 -0.11295394 0.83491379 -0.67906654 0.20325914 -0.0155068 
    0.52107805]] 
[[ 0.   0.   0.   2.34854746 0.   0. 
    3.36598897] 
[ 0.7114948 0.   0.   0.   0.   0.72331643 
    0.  ] 
[ 0.   0.   0.   -0.07157528 0.80637723 0.   0.  ] 
[ 0.   0.   0.   0.44537872 0.   0. 
    0.15328468] 
[ 0.   0.   0.83491379 0.   0.   0. 
    0.52107805]] 
SparseTensorValue(indices=array([[0, 3], 
     [0, 6], 
     [1, 0], 
     [1, 5], 
     [2, 3], 
     [2, 4], 
     [3, 3], 
     [3, 6], 
     [4, 2], 
     [4, 6]]), values=array([ 2.34854746, 3.36598897, 0.7114948 , 0.72331643, -0.07157528, 
     0.80637723, 0.44537872, 0.15328468, 0.83491379, 0.52107805], dtype=float32), dense_shape=array([5, 7])) 
+0

作品!私はこのような方法でオーバーヘッドが大きくなることを心配していましたが、それは非常にうまくいくようです。ありがとう! – zergylord

関連する問題