2016-08-25 10 views
2

同じ変数名が現在のモデルにある場合、チェックポイントファイルからいくつかの変数を復元しようとしています。
そして私は、だから私は以下のようにhas_tensor("variable.name")を使用して、チェックポイントファイル内の変数名をチェックして何をしたいのかTensorfow GithubTensorflowでCheckpointReaderを使用して変数を復元する方法

のように、いくつかの方法があることを発見し、

...  
reader = tf.train.NewCheckpointReader(ckpt_path) 
for v in tf.trainable_variables(): 
    print v.name 
    if reader.has_tensor(v.name): 
     print 'has tensor' 
... 

しかし、私が見つかりました。v.nameリターンの両方の変数ことnameおよびcolon+number。たとえば、私は変数名W_ob_oを持っていて、次にv.nameW_o:0, b_o:0を返します。

しかしreader.has_tensor()colonW_o, b_oなどnumberなしnameが必要です。

私の質問は、変数を読み込むために変数名の最後にcolonnumberを削除するにはどうすればいいですか?
このような変数を復元する方法はありますか?

答えて

4

あなたはテンソル名を取得するためにstring.split()を使用することができます。

...  
reader = tf.train.NewCheckpointReader(ckpt_path) 
for v in tf.trainable_variables(): 
    tensor_name = v.name.split(':')[0] 
    print tensor_name 
    if reader.has_tensor(tensor_name): 
     print 'has tensor' 
... 

次に、私が.cpktファイルから可能なすべての変数を復元する方法を示すために、例を使用してみましょう。まずは、tmp.ckptv2v3を保存してみましょう:

import tensorflow as tf 

v1 = tf.Variable(tf.ones([1]), name='v1') 
v2 = tf.Variable(2 * tf.ones([1]), name='v2') 
v3 = tf.Variable(3 * tf.ones([1]), name='v3') 

saver = tf.train.Saver({'v2': v2, 'v3': v3}) 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    saver.save(sess, 'tmp.ckpt') 

私はtmp.ckptに現れて(新しいグラフに属している)すべての変数を復元する方法をだこと:

with tf.Graph().as_default(): 
    assert len(tf.trainable_variables()) == 0 
    v1 = tf.Variable(tf.zeros([1]), name='v1') 
    v2 = tf.Variable(tf.zeros([1]), name='v2') 

    reader = tf.train.NewCheckpointReader('tmp.ckpt') 
    restore_dict = dict() 
    for v in tf.trainable_variables(): 
     tensor_name = v.name.split(':')[0] 
     if reader.has_tensor(tensor_name): 
      print('has tensor ', tensor_name) 
      restore_dict[tensor_name] = v 

    saver = tf.train.Saver(restore_dict) 
    with tf.Session() as sess: 
     sess.run(tf.initialize_all_variables()) 
     saver.restore(sess, 'tmp.ckpt') 
     print(sess.run([v1, v2])) # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)] 

また、あなたが確実にすることをお勧めしますそのシェイプとdtypeが一致します。

+0

ありがとうrvinas !!私は、私が見たものの中から所望の変数を復元する方法が最も明確な例であると思います。 – user270700

+0

あなたは大歓迎です! – rvinas

1

tf.train.NewCheckpointReaderは、CheckpointReaderオブジェクトを作成する便利なメソッドです。 CheckpointReaderには、いくつかの非常に便利なメソッドがあります。あなたの質問に最も関連するメソッドは、get_variable_to_shape_map()です。

  • get_variable_to_shape_map()変数名と形状を持つ辞書を提供します。

saved_shapes = reader.get_variable_to_shape_map() 
 
print 'fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels']

以下、この簡単なチュートリアルをご覧ください: Loading Variables from Existing Checkpoints

+0

答えをありがとう!私はあなたのチュートリアルをチェックし、あなたの方法を試してみます。 – user270700

関連する問題