2017-04-26 9 views
1

私はtf.train.import_meta_graphでインポートされた保存されたグラフ定義を持っています。グラフにはpy_func opが含まれていますが、これはシリアライズできません。グラフを最初から構築せずに、このオペレーションにPython関数を定義して割り当てることはできますか?TensorFlowでGraphDefをインポートした後にpy_func opを設定する

答えて

4

これは可能ですが、少し壊れやすいかもしれません。特に、pyfuncは元のグラフで定義されているのと同じ順序で再定義する必要があります(FuncRegistryで同じ識別子を持つように)。

例。我々の場合、新しいPythonのセッションでは、その後、

1 37350.4 -0.0934748 0.193026 
1001 19.2717 12.3749 5.40368 
2001 0.108373 12.9532 2.2548 
3001 8.28227e-06 12.9996 2.00222 
4001 3.77258e-09 13.0 2.00004 

:これは、いくつかの基本的なトレーニング(py_funcで見つかった線形関数をマッチング)を行い

import tensorflow as tf 

def my_py_func(x): 
    return 13. * x + 2. 

def train_model(): 
    with tf.Graph().as_default(): 
    some_input = tf.constant([[1., 2., 3., 4.], 
           [5., 6., 7., 8.]]) 
    after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32, 
           name="my_py_func") 
    coefficient = tf.get_variable(
     "coefficient", 
     shape=[]) 
    bias = tf.get_variable(
     "bias", 
     shape=[]) 
    loss = tf.reduce_sum((coefficient * some_input + bias - after_py_func) ** 2) 
    global_step = tf.contrib.framework.get_or_create_global_step() 
    train_op = tf.group(tf.train.AdamOptimizer(0.1).minimize(loss), 
         tf.assign_add(global_step, 1)) 
    # Make it easy to retreive things we care about when the metagraph is reloaded. 
    tf.add_to_collection('useful_ops', bias) 
    tf.add_to_collection('useful_ops', coefficient) 
    tf.add_to_collection('useful_ops', loss) 
    tf.add_to_collection('useful_ops', train_op) 
    tf.add_to_collection('useful_ops', global_step) 
    tf.add_to_collection('useful_ops', some_input) 
    init_op = tf.global_variables_initializer() 
    saver = tf.train.Saver() 
    with tf.Session() as session: 
     session.run(init_op) 
     for i in range(5000): 
     (_, evaled_loss, evaled_coefficient, evaled_bias, 
     evaled_global_step) = session.run(
      [train_op, loss, coefficient, bias, global_step]) 
     if i % 1000 == 0: 
      print(evaled_global_step, evaled_loss, evaled_coefficient, 
       evaled_bias) 
     saver.save(session, "./trained_pyfunc_model", global_step=global_step) 

:私たちはpy_funcが含まれてグラフを定義することができます再定義pyfuncなしメタグラフをロードしようと、我々はエラーを取得:

def load_model(): 
    with tf.Graph().as_default(): 
    saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta") 
    bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops') 
    #after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32, 
    #       name="my_py_func") 
    with tf.Session() as session: 
     saver.restore(session, "./trained_pyfunc_model-5000") 
     (_, evaled_loss, evaled_coefficient, evaled_bias, 
     evaled_global_step) = session.run(
      [train_op, loss, coefficient, bias, global_step]) 
     print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias) 

UnknownError (see above for traceback): KeyError: 'pyfunc_0'

しかし、番目の限りEのpy_funcsが同じ順序で定義されており、同じ実装を持っている、我々は問題ないはずです。

Restored: 5001 1.77897e-09 13.0 2.00003 

def load_model(): 
    with tf.Graph().as_default(): 
    saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta") 
    bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops') 
    after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32, 
           name="my_py_func") 
    with tf.Session() as session: 
     saver.restore(session, "./trained_pyfunc_model-5000") 
     (_, evaled_loss, evaled_coefficient, evaled_bias, 
     evaled_global_step) = session.run(
      [train_op, loss, coefficient, bias, global_step]) 
     print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias) 

これは、私たちはトレーニングを続ける、あるいは我々が復元モデルをどうしたいのですが、他のものは何でもすることができます

ステートフルなpy_funcsは扱いにくいでしょう:TensorFlowは、関連付けられたPython変数を保存していません!

関連する問題