私はTensorFlowを初めて使いました。リカレントニューラルネットワークに入力するデータをフォーマットしています。私のデータは、プレースホルダx
に入力された3次元テンソルによって与えられます。私は、第三の次元に沿ってx
を分割したい、そして、このために私は(n_timesteps
は、第三の次元に沿っx
の長さに対応することに注意してください)があります。TensorFlow - Split and Squeeze
# Split the previous 3d tensor to get a list of 'n_timesteps' 2d tensors of
# shape (batch_size, features_dimension)
x = tf.split (x, n_timesteps, axis = 2)
私はnumpy
と試みているように、けれども:
x = np.split (x, n_timesteps, axis = 2)
x
が3D ndarray
ある場合、np.split
は、第3の寸法がシングルトンであるように、ディメンション3でn_timesteps
配列のリストを返します。私はTFで同じことを行うことができます
x = [np.squeeze(a, axis=2) for a in np.split(x, n_timesteps, axis=2)]
しかし、どのように:numpy
で、私は私が簡単にシングルトン次元を削除するには、リスト内包表記と一緒にnp.squeeze
を使用してこの問題を解決できることをご存知ですか?