2016-12-17 10 views
2

保存されたファイルを解析することによって、訓練されたモデルのパラメータ(コンボリューションと完全に接続されたレイヤの重みとバイアス)を他のフレームワークやiOSやTorchなどの言語に渡したいと思います。TensorFlow:訓練されたモデルパラメータを他のフレームワークにインポートできるファイルに保存する方法は?

私はtf.train.write_graph(session.graph_def, '', 'graph.pb')を試しましたが、重みとバイアスのないグラフアーキテクチャしか含まれていないようです。もしそうなら、チェックポイントファイル(saver.save(session, "model.ckpt"))を作成するのが最善の方法ですか? ckptファイルタイプをSwiftや他の言語で解析するのは簡単ですか?

提案がありましたら教えてください。

答えて

1

.ckptファイルを解析する代わりに、テンソル(あなたの場合は畳み込みレイヤーのウェイト)を評価して値をnumpyの配列として取得してみることができます。ここで(r0.10上でテスト - 新しいバージョンではいくつかの小さなAPIの変更があります)迅速なおもちゃの例である:

import tensorflow as tf 
import numpy as np 

x = tf.placeholder(np.float32, [2,1]) 
w = tf.Variable(tf.truncated_normal([2,2], stddev=0.1)) 
b = tf.Variable(tf.constant(1.0, shape=[2,1])) 
z = tf.matmul(w, x) + b 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    w_val, z_val = sess.run([w, z], feed_dict={x: np.arange(2).reshape(2,1)}) 
    print(w_val) 
    print(z_val) 

出力:

[[-0.02913031 0.13549708] 
[ 0.13807134 0.03763327]] 
[[ 1.13549709] 
[ 1.0376333 ]] 

あなたのテンソルへの参照を取得悩みをお持ちの場合(それがより上位の「レイヤー」操作にネストされているとします)、名前で検索してみてください。詳細:Tensorflow: How to get a tensor by name?

あなたは重みがトレーニング中にどのように変化するか、また、あなたはtf.Summaryオブジェクトに興味を持っているすべての値を保存し、後でそれを解析しようとすることができます確認したい場合は、次のParsing `summary_str` byte string evaluated on tensorflow summary object

+0

はありがとうとしてください私を明確にさせてください。上記の出力を取得した後、TensorFlowユーザーは通常、これらの値をいくつかのファイルを使用して他のプラットフォームを通過させるために格納します(私は.datまたは.pb)。 – kangaroo

+0

あなたの他のプラットフォームがPythonベースのものであれば、最も簡単な方法はpickleファイル(またはnumpy配列)ファイルを作ることです。 HDF5のようなバイナリ形式がより適切かもしれません(私は個人的にHDF5を使っていくつかのデータセットを配布しています)。 –

+0

https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflowを参照してください。 – Cristi

関連する問題