2017-06-23 11 views
0

Tensorflowを初めて使用しています。私は、PythonでTensorflowを使って、疎な行列入力で動作する関数を記述しようとしています。通常はテンソルフローのプレースホルダーを定義しますが、スパース行列のプレースホルダーはありません。Tensorflow関数でスパース行列引数を使用する

テンソルフローのスパースデータで動作し、値を渡す関数を定義する適切な方法は何ですか?

具体的には、ここではhttps://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/multilayer_perceptron.pyにある多層パーセプトロンの基本的な例を、密度の代わりに疎な入力を受け入れるように書き直そうとしています。

ダミーの例として、どのようにこのような関数を記述しますか?リンクhttps://github.com/tensorflow/tensorflow/issues/342

import tensorflow as tf 


x = tf.placeholder("sparse") 
y = tf.placeholder("float", [None, n_classes]) 

# Create model 
def sparse_multiply(x, y): 

    outlayer = tf.sparse_tensor_dense_matmul(x, y) 

    return out_layer 

pred = multiply(x, y) 

# Launch the graph 
with tf.Session() as sess: 
    result = sess.run(pred, feed_dict={x: x_input, y: y_input}) 

誰かがスパース行列を構築するために必要な要素を渡し、その後、関数内でその場で疎行列を作成、回避策として、推奨されます。それはちょっとハッキリしているように見えます。そのように構築しようとするとエラーが発生します。

何か助け、特にコードによる回答をいただければ幸いです。

答えて

0

私はそれを理解したと思います。私がリンクした提案は実際にはうまくいきました。一貫したタイプを持つようにすべての入力を修正する必要がありました。ここでは、問題に記載されているダミーの例を正しく示しています。

import tensorflow as tf 

import sklearn.feature_extraction 
import numpy as np 


def convert_csr_to_sparse_tensor_inputs(X): 
    coo = X.tocoo() 
    indices = np.mat([coo.row, coo.col]).transpose() 
    return indices, coo.data, coo.shape 


X = ____ #Some sparse 2 x 8 csr matrix 

y_input = np.asarray([1, 1, 1, 1, 1, 1, 1, 1]) 
y_input.shape = (8,1) 


x_indices, x_values, x_shape = convert_csr_to_sparse_tensor_inputs(X) 

# tf Graph input 
y = tf.placeholder(tf.float64) 
values = tf.placeholder(tf.float64) 
indices = tf.placeholder(tf.int64) 
shape = tf.placeholder(tf.int64) 

# Create model 
def multiply(values, indices, shape, y): 

    x_tensor = tf.SparseTensor(indices, values, shape)  

    out_layer = tf.sparse_tensor_dense_matmul(x_tensor, y) 


    return out_layer 

pred = multiply(values, indices, shape, y) 

# Launch the graph 
with tf.Session() as sess: 
    result = sess.run(pred, feed_dict={values: x_values, indices: x_indices, shape: x_shape, y: y_input}) 
関連する問題