2017-04-13 5 views
1

文を生成する目的で次のようにモデルを訓練しました: 訓練例2としてシーケンスを供給します:xは文字のシーケンスであり、yは1 。モデルはLSTMに基づいており、テンソルフローで作成されています。
私の質問はです:モデルは入力に(私の場合は50)特定のサイズのシーケンスを取るために、どのように私はシードとして彼にのみ単一の文字を与えて予測を行うことができますか?私はいくつかの例でそれを見たことがあります。私はあなたの代わりに、実行時間中にグラフを作成し、任意の長さの入力を持つことができますstatic_rnn、のdynamic_rnn使用することをお勧め訓練された文字レベルのLSTMモデルを使ってテキストを生成

with tf.name_scope('input'): 
     x = tf.placeholder(tf.float32, [batch_size, truncated_backprop], name='x') 
     y = tf.placeholder(tf.int32, [batch_size, truncated_backprop], name='y') 

    with tf.name_scope('weights'): 
     W = tf.Variable(np.random.rand(n_hidden, num_classes), dtype=tf.float32) 
     b = tf.Variable(np.random.rand(1, num_classes), dtype=tf.float32) 

    inputs_series = tf.split(x, truncated_backprop, 1) 
    labels_series = tf.unstack(y, axis=1) 

    with tf.name_scope('LSTM'): 
     cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, state_is_tuple=True) 
     cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=dropout) 
     cell = tf.contrib.rnn.MultiRNNCell([cell] * n_layers) 

    states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, \ 
     dtype=tf.float32) 

    logits_series = [tf.matmul(state, W) + b for state in states_series] 
    prediction_series = [tf.nn.softmax(logits) for logits in logits_series] 

    losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) \ 
     for logits, labels, in zip(logits_series, labels_series)] 
    total_loss = tf.reduce_mean(losses) 

    train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss) 

答えて

2


はここに私のコードです。入力したプレースホルダは

x = tf.placeholder(tf.float32, [batch_size, None, features], name='x') 

です。次に、独自の初期状態をネットワークに入力する方法が必要です。あなたが好きな、dynamic_rnninitial_stateパラメータを渡すことによってそれを行うことができます、あなたが一度にグラフ1文字を養うことができ、単一の文字からテキストを生成するために、それと

initialstate = cell.zero_state(batch_sie, tf.float32) 
outputs, current_state = tf.nn.dynamic_rnn(cell, 
              inputs, 
              initial_state=initialstate) 

前の文字を渡すと、毎回次のような状態です:

prompt = 's' # beginning character, whatever 
inp = one_hot(prompt) # preprocessing, as you probably want to feed one-hot vectors 
state = None 
while True: 
    if state is None: 
     feed = {x: [[inp]]} 
    else: 
     feed = {x: [[inp]], initialstate: state} 

    out, state = sess.run([outputs, current_state], feed_dict=feed) 

    inp = process(out) # extract the predicted character from out and one-hot it 
+0

ありがとうございます。ダイナミックRNNのトリックは本当にきれいです。今はもっとはっきりしています。 – JimZer

関連する問題