あなたはテンソル名を取得するためにstring.split()を使用することができます。
...
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
tensor_name = v.name.split(':')[0]
print tensor_name
if reader.has_tensor(tensor_name):
print 'has tensor'
...
次に、私が.cpkt
ファイルから可能なすべての変数を復元する方法を示すために、例を使用してみましょう。まずは、tmp.ckpt
でv2
とv3
を保存してみましょう:
import tensorflow as tf
v1 = tf.Variable(tf.ones([1]), name='v1')
v2 = tf.Variable(2 * tf.ones([1]), name='v2')
v3 = tf.Variable(3 * tf.ones([1]), name='v3')
saver = tf.train.Saver({'v2': v2, 'v3': v3})
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver.save(sess, 'tmp.ckpt')
私はtmp.ckpt
に現れて(新しいグラフに属している)すべての変数を復元する方法をだこと:
with tf.Graph().as_default():
assert len(tf.trainable_variables()) == 0
v1 = tf.Variable(tf.zeros([1]), name='v1')
v2 = tf.Variable(tf.zeros([1]), name='v2')
reader = tf.train.NewCheckpointReader('tmp.ckpt')
restore_dict = dict()
for v in tf.trainable_variables():
tensor_name = v.name.split(':')[0]
if reader.has_tensor(tensor_name):
print('has tensor ', tensor_name)
restore_dict[tensor_name] = v
saver = tf.train.Saver(restore_dict)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver.restore(sess, 'tmp.ckpt')
print(sess.run([v1, v2])) # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)]
また、あなたが確実にすることをお勧めしますそのシェイプとdtypeが一致します。
ありがとうrvinas !!私は、私が見たものの中から所望の変数を復元する方法が最も明確な例であると思います。 – user270700
あなたは大歓迎です! – rvinas