2016-11-14 4 views
2

私はCifar10 example model from TensorFlow's repositoryとbatch_size 128を訓練しました。その後、私はfroze graphと、彼らはC++ label image exampleでそれを行うようにC + +でそれを実行するために管理されます。TensorFlowで異なるバッチサイズで保存されたモデルを実行する最善の方法は何ですか?

唯一の問題は、保存されたモデルがバッチ内の128サンプルの入力を期待するため、テンソルの形状[128、image_height、image_width、channels]を人工的に生成しなければならないということでした。キューから来ます。

私はbatch_size = 1でCifar10の例を練習しようとしましたが、その後C + +でモデルを実行すると1つずつ例を分類することができましたが、それは素晴らしい解決策のようには見えません。私はまた、保存されたグラフファイルで手動でテンソルシェイプを変更しようとしましたが、動作しませんでした。

固定のバッチサイズ(32,64,128など)でモデルを訓練し、任意の長さのバッチサイズで使用できるようにモデルを保存する最良の方法は何ですか?それが不可能な場合は、モデルを保存してサンプルを1つずつ分類できるようにする方法です。

答えて

6

TensorFlowがバッチサイズをグラフ内の他のテンソルに「ベーキング」していると思われます(たとえば、グラフがバッチサイズに依存する形状のtのテンソルtf.shape(t)の場合は、バッチサイズがグラフとして定数として記憶される)。解決策は、プログラムを少し変更して、tf.train.batch()が可変バッチサイズのテンソルを返すようにすることです。

引数の場合は、tf.Tensorを受け入れます。おそらく、可変サイズのバッチのあなたのプログラムを修正する最も簡単な方法は、バッチサイズのためのプレースホルダを定義するには、次のようになります。次のことができるようにする必要がありますので

# Define a scalar tensor for the batch size, so that you can alter it at 
# Session.run()-time. 
batch_size_tensor = tf.placeholder(tf.int32, shape=[]) 
input_tensors = tf.train.batch(..., batch_size=batch_size_tensor, ...) 

これは、あなたのGraphDefに焼かれるのバッチサイズを妨げます任意のバッチサイズのC++フィード値。ただし、この変更により、すべてのステップでバッチサイズの値を入力する必要がありますが、これはやや面倒です。

あなたは常にバッチサイズ128で訓練したいと仮定すると、それ以降のバッチサイズを変更するための柔軟性を維持する、あなたがいないフィードANを行うときのバッチサイズは128であることを指定するには、tf.placeholder_with_default()を使用することができます別の値:

# Define a scalar tensor for the batch size, so that you can alter it at 
# Session.run()-time. 
batch_size_tensor = tf.placeholder_with_default(128, shape=[]) 
input_tensors = tf.train.batch(..., batch_size=batch_size_tensor, ...) 
+0

batch_size_tensorを使用して別のテンソルの形状を定義する方法は? –

0

グラフに固定バッチサイズが必要な理由はありますか?

可変のバッチサイズのグラフを作成することをお勧めします。最初の次元としてNoneを設定します。トレーニング中にバッチサイズフラグをデータプロバイダに渡すことで、各繰り返しで必要な量のデータをフィードできます。

モデルを訓練した後、tf.train.Saver()を使用してexport the graphを使用すると、メタデータをエクスポートできます。推論を行うには、エクスポートされたファイルを読み込み、任意の数の例を評価するだけで、1つだけでも評価できます。
これはフリーズしたグラフとは異なります。

+0

私が固定バッチサイズでトレーニングしているのは、バッチ処理に使用するtf.train.batch()がパラメータbatch_sizeを受け取り、その番号を指定する必要があるということです。 フリーズグラフが行く方法だったようでした。私はtf.train.Saver()でモデルを保存しようとすると、単一のサンプルで推論を実行し、そのグラフをフリーズする別のセッションで復元しようとします。そのグラフをC++で実行した後、おそらく私は単一サンプルの推論を実行できます。私はまだTensorFlowとちょっと混乱しています。その多くは、適切に文書化されていないか、正式にサポートされていないため、スクリプトです。 – Drag0

+0

また、後でカメラフィードからリアルタイムで来る画像を分類したいので、私は単一のサンプルで推論を実行したいと思います。 – Drag0

関連する問題