2017-09-05 12 views
1

に行ごとのサブ配列を取得する方法:]TensorFlow:私は次のコードたテンソル

[1,2,3]、[2,3を以下のように

import numpy as np 
import tensorflow as tf 

series = tf.placeholder(tf.float32, shape=[None, 5]) 
series_length = tf.placeholder(tf.int32, shape=[None]) 
useful_series = tf.magic_slice_function(series, series_length) 

with tf.Session() as sess: 
    input_x = np.array([[1, 2, 3, 0, 0], 
         [2, 3, 0, 0, 0], 
         [1, 0, 0, 0, 0]]) 
    input_y = np.array([[3], [2], [1]]) 
    print(sess.run(useful_series, feed_dict={series: input_x, series_length: input_y})) 

期待出力、 [1]]

私はtf.gather、tf.sliceなどのいくつかの関数を試しました。それらのすべては機能しません。 magic_slice_functionとは何ですか?

+1

何を取得することはテンソルではないので、おそらく、あなたは、Tensorflow外でこれを実行する必要があります。 –

答えて

1

それは少しトリッキーです:

import numpy as np 
import tensorflow as tf 

series = tf.placeholder(tf.float32, shape=[None, 5]) 
series_length = tf.placeholder(tf.int64) 

def magic_slice_function(input_x, input_y): 
    array = [] 
    for i in range(len(input_x)): 
     temp = [input_x[i][j] for j in range(input_y[i])] 
     array.extend(temp) 
    return [array] 

with tf.Session() as sess: 
    input_x = np.array([[1, 2, 3, 0, 0], 
         [2, 3, 0, 0, 0], 
         [1, 0, 0, 0, 0]]) 

    input_y = np.array([3, 2, 1], dtype=np.int64) 

    merged_series = tf.py_func(magic_slice_function, [series, series_length], tf.float32, name='slice_func') 

    out = tf.split(merged_series, input_y) 
    print(sess.run(out, feed_dict={series: input_x, series_length: input_y})) 

出力は次のようになります。

[array([ 1., 2., 3.], dtype=float32), array([ 2., 3.], dtype=float32), array([ 1.], dtype=float32)] 
関連する問題