リカレントニューラルネットワーク(LSTM)を訓練し、重みとメタグラフを保存しました。予測のためにメタグラフを取得すると、シーケンスの長さがトレーニング中のものと同じであれば、すべてが完全に機能します。Tensorflow:メタ文字を取得するときのプレースホルダの形状を変更する
LSTMの利点の1つは、入力のシーケンス長が変わる可能性があることです(たとえば、入力が文章を構成する文字である場合、文章の長さは異なる場合があります)。
メタグラフからグラフを取得するときに、入力のシーケンス長を変更するにはどうすればよいですか?コードと
詳細:トレーニングの間
は、私がデータを供給するために、プレースホルダx
とy
を使用しています。予測のために、これらのプレースホルダを取得しますが、形状を変更することはできません([None, previous_sequence_length=100, n_input]
から[None, new_sequence_length=50, n_input]
まで)。私は予測のためメタグラフを取得するファイルprediction.py
で
self.x = tf.placeholder("float32", [None, self.n_steps, self.n_input], name='x_input')
self.y = tf.placeholder("float32", [None, self.n_classes], name='y_labels')
tf.add_to_collection('x', self.x)
tf.add_to_collection('y', self.y)
...
def build_model(self):
#using the placeholder self.x to build the model
...
tf.split(0, self.n_input, self.x) # split input for RNN cell
...
:アーキテクチャとプレースホルダを定義するファイルmodel.py
で
、ここで
with tf.Session() as sess:
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir)
new_saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
new_saver.restore(sess, latest_checkpoint)
x = tf.get_collection('x')[0]
y = tf.get_collection('y')[0]
...
sess.run(..., feed_dict={x: batch_x})
は私が得るエラーです:
ValueError: Cannot feed value of shape (128, 50, 2) for Tensor u'placeholders/x_input:0', which has shape '(?, 100, 2)'
注:私はこの問題を解決することができますn メタテキストを使用せず、ゼロからモデルを再構築し、保存された重みだけをロードします(メタグラフではありません)。
EDIT:None
でself.n_steps
を交換しtf.split(0, self.x.get_shape()[1], self.x)
とtf.split(0, self.n_input, self.x)
を変更するとき、私は次のエラーを取得する:TypeError: Expected int for argument 'num_split' not Dimension(None).
事実の後にテンソルの形を変えることはできません。しかし、あなたができることの1つは、トレーニング中にすべてのディメンションの形状を修正するのではなく、それらを指定しないままにすることです。あなたが供給するテンソルの次元は、プレースホルダーの形状と互換性がなければなりませんが、最初にすべてのプレースホルダー次元を指定する必要はありません。ここでは、self.n_stepsの代わりに "None"を入れてみてください。 –
私は実際に質問を投稿する前にそれをやろうとしましたが、モデル作成中のある時点で、私は 'tf.split(0、self.n_input、self.x)'を持っています。私は 'self.n_input'を知らない/修正するときに、代わりに' self.x.get_shape()[1] '(' tf.split(0、self.x.get_shape()[1]、selfを置きます。 x) ')。しかし、私は次のエラーが発生します: 'TypeError:引数 'num_split'のための予想されるintではありません次元(なし)。 – BiBi