2017-02-28 4 views
4

私はTensorFlowを初めて使いました。リカレントニューラルネットワークに入力するデータをフォーマットしています。私のデータは、プレースホルダxに入力された3次元テンソルによって与えられます。私は、第三の次元に沿ってxを分割したい、そして、このために私は(n_timestepsは、第三の次元に沿っxの長さに対応することに注意してください)があります。TensorFlow - Split and Squeeze

# Split the previous 3d tensor to get a list of 'n_timesteps' 2d tensors of 
# shape (batch_size, features_dimension) 
x = tf.split (x, n_timesteps, axis = 2) 

私はnumpyと試みているように、けれども:

x = np.split (x, n_timesteps, axis = 2) 

xが3D ndarrayある場合、np.splitは、第3の寸法がシングルトンであるように、ディメンション3でn_timesteps配列のリストを返します。私はTFで同じことを行うことができます

x = [np.squeeze(a, axis=2) for a in np.split(x, n_timesteps, axis=2)] 

しかし、どのように:numpyで、私は私が簡単にシングルトン次元を削除するには、リスト内包表記と一緒にnp.squeezeを使用してこの問題を解決できることをご存知ですか?

答えて

0

リストの理解ではなく、Tensorflow(tf.squeeze)のスクイーズ機能とTensorflow(tf.scan)のスキャン機能を使用してください。

tf.scan(lambda a, x_i: tf.squeeze(x_i, [2]), x, initializer=tf.constant(0, shape=[n_dim0, n_dim1])) 
あなたは tf.unstack OPを探している可能性があり
2

x = tf.unstack(x, axis=2)