2017-11-08 20 views
2

私はテンソルフローが新しく、単純なニューラルネットワークを訓練してきましたが、訓練を受けた後、入力の出力を得るためにNNを再利用する方法はわかりません。Tensorflow再利用ニューラルネットワーク

def train_neural_network(x,y,aDataTrain,aTargetTrain,aDataTest,aTargetTest): 
    batch_size = 500 
    prediction = neural_network_model(x,len(aDataTrain[0])) 
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y)) 
    optimizer = tf.train.AdamOptimizer().minimize(cost) 
    hm_epochs = 1 

    saver = tf.train.Saver() 
    with tf.Session() as sess: 
     sess.run(tf.global_variables_initializer()) 

     for epoch in range(hm_epochs): 
      epoch_loss = 0 
      i = 0 
      while i < len(aDataTrain): 
       start = i 
       end = i + batch_size 
       batch_x = np.array(aDataTrain[start:end]) 
       batch_y = np.array(aTargetTrain[start:end]) 

       _,c = sess.run([optimizer,cost],feed_dict={x:batch_x,y:batch_y}) 
       epoch_loss += c 
       i += batch_size 
      print ("Epoch", epoch, "completed out of", hm_epochs, "loss", epoch_loss) 

     correct =tf.equal(tf.argmax(prediction,1), tf.argmax(y,1)) 

     accurracy = tf.reduce_mean(tf.cast(correct,'float')) 
     finalAcc = accurracy.eval({x:aDataTest,y:aTargetTest}) 
     saver.save(sess, 'model/model.ckpt') 

    print("Accuracy:",finalAcc) 

だから、かつて私がモデルを保存し、それを復元しようとしましたが、私は「INPUT_DATA」からNNの出力を得るために継続するのか分かりません。あなたと仮定すると

def execute_neural_network(x,y,aDataTrain,aTargetTrain,aDataTest,aTargetTest): 
    batch_size = 1 
    y_pred = [] 

    prediction = neural_network_model(x,len(aDataTrain[0])) 
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y)) 
    optimizer = tf.train.AdamOptimizer().minimize(cost) 

    input_data = [5.0, 3.0, 1.0, 5.0, 6.0, 5.0, 2.0, 4.0, 7.0, 6.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 2.0, 3.0, 3.0, 3.0, 3.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 3.0, 4.0, 3.0, 3.0, 2.0, 4.0, 3.0, 3.0, 2.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 61.0, 21.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 75.0, 3.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 6.0, 35.0, 11.0, 10.0, 33.0, 24.0, 6.0, 2.0, 2.0, 3.0, 4.0, 3.0, 3.0, 8.0, 6.0, 5.0, 6.0, 5.0, 8.0, 9.0, 13.0, 7.0, 25.0, 11.0, 2.0, 2.0, 2.0, 2.0, 2.0] 

    saver = tf.train.Saver() 
    with tf.Session() as sess: 
     saver.restore(sess, 'model/model.ckpt') 
     #Get neural network output from input_data 

答えて

0

は次のように何らかの形であなたのグラフ/ネットワークのモデルを作成:

predictionOp
with tf.Session() as sess: 
    #do other stuff 
    predictionOp = tf.argmax(py_x, 1) 

saver.save(sess, 'model') 

がネットワークの出力である変数です。

tf.add_to_collection("predictionOp", predictionOp) のように後で追加すると、predictionOpに簡単に名前を付けることができます。その後、あなたはあなたのモデルをリロードして予測を得ることができます:

with tf.Session() as sess: 
    new_saver = tf.train.import_meta_graph('model.meta') 
    new_saver.restore(sess, 'model') 
    predictionOp = tf.get_collection("predictionOp")[0] 

    #get the prediction 
    prediction = sess.run(predictionOp, feed_dict={"x:0": input_data}) 

は、詳細については、基本についていくつかの詳細については、tensorflowdocumentationhereをご覧ください。また、同様の問題を処理する他のスレッドもあります(thisthis oneなど)。

関連する問題