2017-11-13 11 views
0

私はTensorflowの使い方が新しく、これはグーグルで見つけたコードのサンプルです。私はグラフをフリーズしようとしましたが、私は正しい出力ノードを入力する必要があると言いました。私はこれに新しいので、私はそれを理解するのに苦労している。このコードで出力ノードを見つけるにはどうしたらいいですか、グラフ全体をフリーズする必要はありますか?TensorFlow:私のTensorflow訓練モデルで出力ノードを見つけるにはどうすればいいですか?

輸入tensorflow NP 輸入SYS

クラスSeq2Seq(オブジェクト)として 輸入numpyのTFとして:あなたは最初にあなたのすでに凍結されたグラフをロード/読む必要が

def __init__(self, xseq_len, yseq_len, 
     xvocab_size, yvocab_size, 
     emb_dim, num_layers, ckpt_path, 
     lr=0.0001, 
     epochs=200, model_name='seq2seq_model'): 
    # attach these arguments to self 
    self.xseq_len = xseq_len 
    self.yseq_len = yseq_len 
    self.ckpt_path = ckpt_path 
    self.epochs = epochs 
    self.model_name = model_name 


    # build thy graph 
    # attach any part of the graph that needs to be exposed, to the self 
    def __graph__(): 

     # placeholders 
     tf.reset_default_graph() 
     # encoder inputs : list of indices of length xseq_len 
     self.enc_ip = [ tf.placeholder(shape=[None,], 
         dtype=tf.int64, 
         name='ei_{}'.format(t)) for t in range(xseq_len) ] 

     # labels that represent the real outputs 

     self.labels = [ tf.placeholder(shape=[None,], 
         dtype=tf.int64, 
         name='ei_{}'.format(t)) for t in range(yseq_len) ] 

     # decoder inputs : 'GO' + [ y1, y2, ... y_t-1 ] 
     self.dec_ip = [ tf.zeros_like(self.enc_ip[0], dtype=tf.int64, name='GO') ] + self.labels[:-1] 


     # Basic LSTM cell wrapped in Dropout Wrapper 
     self.keep_prob = tf.placeholder(tf.float32) 
     # define the basic cell 
     basic_cell = tf.contrib.rnn.core_rnn_cell.DropoutWrapper(
       tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(emb_dim, state_is_tuple=True), 
       output_keep_prob=self.keep_prob) 
     # stack cells together : n layered model 
     stacked_lstm = tf.contrib.rnn.core_rnn_cell.MultiRNNCell([basic_cell]*num_layers, state_is_tuple=True) 


     # for parameter sharing between training model 
     # and testing model 
     with tf.variable_scope('decoder') as scope: 
      # build the seq2seq model 
      # inputs : encoder, decoder inputs, LSTM cell type, vocabulary sizes, embedding dimensions 
      self.decode_outputs, self.decode_states = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(self.enc_ip,self.dec_ip, stacked_lstm, 
               xvocab_size, yvocab_size, emb_dim) 
      # share parameters 
      scope.reuse_variables() 
      # testing model, where output of previous timestep is fed as input 
      # to the next timestep 
      self.decode_outputs_test, self.decode_states_test = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
       self.enc_ip, self.dec_ip, stacked_lstm, xvocab_size, yvocab_size,emb_dim, 
       feed_previous=True) 

     # now, for training, 
     # build loss function 

     # weighted loss 
     # TODO : add parameter hint 
     loss_weights = [ tf.ones_like(label, dtype=tf.float32) for label in self.labels ] 
     self.loss = tf.contrib.legacy_seq2seq.sequence_loss(self.decode_outputs, self.labels, loss_weights, yvocab_size) 
     # train op to minimize the loss 
     self.train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.loss) 

    sys.stdout.write('<log> Building Graph ') 
    # build comput graph 
    __graph__() 
    sys.stdout.write('</log>') 



''' 
    Training and Evaluation 

''' 

# get the feed dictionary 
def get_feed(self, X, Y, keep_prob): 
    feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)} 
    feed_dict.update({self.labels[t]: Y[t] for t in range(self.yseq_len)}) 
    feed_dict[self.keep_prob] = keep_prob # dropout prob 
    return feed_dict 

# run one batch for training 
def train_batch(self, sess, train_batch_gen): 
    # get batches 
    batchX, batchY = train_batch_gen.__next__() 
    # build feed 
    feed_dict = self.get_feed(batchX, batchY, keep_prob=0.5) 
    _, loss_v = sess.run([self.train_op, self.loss], feed_dict) 
    return loss_v 

def eval_step(self, sess, eval_batch_gen): 
    # get batches 
    batchX, batchY = eval_batch_gen.__next__() 
    # build feed 
    feed_dict = self.get_feed(batchX, batchY, keep_prob=1.) 
    loss_v, dec_op_v = sess.run([self.loss, self.decode_outputs_test], feed_dict) 
    # dec_op_v is a list; also need to transpose 0,1 indices 
    # (interchange batch_size and timesteps dimensions 
    dec_op_v = np.array(dec_op_v).transpose([1,0,2]) 
    return loss_v, dec_op_v, batchX, batchY 

# evaluate 'num_batches' batches 
def eval_batches(self, sess, eval_batch_gen, num_batches): 
    losses = [] 
    for i in range(num_batches): 
     loss_v, dec_op_v, batchX, batchY = self.eval_step(sess, eval_batch_gen) 
     losses.append(loss_v) 
    return np.mean(losses) 

# finally the train function that 
# runs the train_op in a session 
# evaluates on valid set periodically 
# prints statistics 
def train(self, train_set, valid_set, sess=None): 

    # we need to save the model periodically 
    saver = tf.train.Saver() 

    # if no session is given 
    if not sess: 
     # create a session 
     sess = tf.Session() 
     # init all variables 
     sess.run(tf.global_variables_initializer()) 

    sys.stdout.write('\n<log> Training started </log>\n') 
    # run M epochs 
    for i in range(self.epochs): 
     try: 
      self.train_batch(sess, train_set) 
      if i and i% (self.epochs//5) == 0: # TODO : make this tunable by the user 
       # save model to disk 
       saver.save(sess, self.ckpt_path + self.model_name + '.ckpt', global_step=i) 
       # evaluate to get validation loss 
       val_loss = self.eval_batches(sess, valid_set, 16) # TODO : and this 
       # print stats 
       print('\nModel saved to disk at iteration #{}'.format(i)) 
       print('val loss : {0:.6f}'.format(val_loss)) 
       sys.stdout.flush() 
     except KeyboardInterrupt: # this will most definitely happen, so handle it 
      print('Interrupted by user at iteration {}'.format(i)) 
      self.session = sess 
      return sess 

def restore_last_session(self): 
    saver = tf.train.Saver() 
    # create a session 
    sess = tf.Session() 
    # get checkpoint state 
    ckpt = tf.train.get_checkpoint_state(self.ckpt_path) 
    # restore session 
    if ckpt and ckpt.model_checkpoint_path: 
     saver.restore(sess, ckpt.model_checkpoint_path) 
    # return to user 
    return sess 

# prediction 
def predict(self, sess, X): 
    feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)} 
    feed_dict[self.keep_prob] = 1. 
    dec_op_v = sess.run(self.decode_outputs_test, feed_dict) 
    # dec_op_v is a list; also need to transpose 0,1 indices 
    # (interchange batch_size and timesteps dimensions 
    dec_op_v = np.array(dec_op_v).transpose([1,0,2]) 
    # return the index of item with highest probability 
    return np.argmax(dec_op_v, axis=2) 

答えて

1

。あなたがロードした後

def load_graph(frozen_graph_filename): 
    # We load the protobuf file from the disk and parse it to retrieve the 
    # unserialized graph_def 
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f: 
     graph_def = tf.GraphDef() 
     graph_def.ParseFromString(f.read()) 

    # Then, we import the graph_def into a new Graph and return it 
    with tf.Graph().as_default() as graph: 
     # The name var will prefix every op/nodes in your graph 
     # Since we load everything in a new graph, this is not needed 
     tf.import_graph_def(graph_def, name="prefix") 
    return graph 

ごすでにを凍結、あなたはこのようなあなたの内のすべての演算子を印刷することができます。

for op in graph.get_operations(): 
    print (op.name(), op.value() 

参考:上記のコードはthisチュートリアルから取られています。

出力ノードは印刷された出力にあります。これは、あなたが従っているチュートリアルで最初に設定したものと似ているはずです。

# This line is from my graph where the operator's name is "y_" and the value is ":0 "  
y = graph.get_tensor_by_name('prefix/y_:0') 

注1:あなたは、出力ノード名を見つけたら、あなたが実際にこのようなテンソルを得ることができ、あなたのコード内の変数名は、ここで言及されたものと異なる場合があります。

注2:始めにTensorFlowを操作すると、処理するさまざまなメソッドとプロパティすべてで圧倒される可能性があります。私はTensorFlowの理解を得るために小さなものを始めることをお勧めします。そして、あなたが従っているような長い例を進めてください。この解答で参照されているチュートリアルを読んで理解してみて、それに従っているサンプルコードと平行して描きたいと思うかもしれません。

+0

ご協力いただきありがとうございます。私はグラフをフリーズすることができました。しかし、私はそれを読み返そうとしました。これは、 "float32型のテンソルをfloat32_ref型の入力に変換できません。"どんな考え? –

+0

これはあなたの事例に関連するかもしれません。あなたはこれをグーグル・グーグルでやるほうが良いでしょう。 –

関連する問題