2017-10-03 6 views
3

からgraph_defを解析するとき、私は.pbファイルとしてTensorflowグラフの非常にシンプルな節約を実行しようとしていますが、それをバック解析するとき、私はこのエラーを持っている:エラー文字列

Traceback (most recent call last): 
    File "test_import_stripped_bm.py", line 28, in <module> 
    graph_def.ParseFromString(fileContent) 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/message.py", line 185, in ParseFromString 
    self.MergeFromString(serialized) 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1069, in MergeFromString 
    if self._InternalParse(serialized, 0, length) != length: 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse 
    pos = field_decoder(buffer, new_pos, end, self, field_dict) 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 633, in DecodeField 
    if value._InternalParse(buffer, pos, new_pos) != new_pos: 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse 
    pos = field_decoder(buffer, new_pos, end, self, field_dict) 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField 
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse 
    pos = field_decoder(buffer, new_pos, end, self, field_dict) 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 743, in DecodeMap 
    if submsg._InternalParse(buffer, pos, new_pos) != new_pos: 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1095, in InternalParse 
    new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 850, in SkipField 
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 799, in _SkipGroup 
    new_pos = SkipField(buffer, pos, end, tag_bytes) 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 850, in SkipField 
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) 
    File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 814, in _SkipFixed32 
    raise _DecodeError('Truncated message.') 
google.protobuf.message.DecodeError: Truncated message. 

これは私のコードです.pbするためにそれを書くために使用します。

import tensorflow as tf 

builder = tf.saved_model.builder.SavedModelBuilder('models/TEST-3') 

w1 = tf.Variable(tf.random_normal((2,2)), name="w1") 
w2 = tf.Variable(tf.random_normal((2,2)), name="w2") 

sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 

builder.add_meta_graph_and_variables(sess, tags=[tf.saved_model.tag_constants.SERVING], clear_devices = True) 

builder.save() 
sess.close() 

そして、これはそれを解析するためのコードです:

import tensorflow as tf 
import os 

model_path = os.path.join('models/TEST-3', 'saved_model.pb') 
with open(model_path, mode='rb') as f: 
    fileContent = f.read() 
graph_def = tf.GraphDef() 
graph_def.ParseFromString(fileContent) 

私が持っていた正確なエラーを表示するにはdo

export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 

実行前。 また、テンソルフローのバージョンが異なる2と3のPythonでこれを試しました。私はUbuntu 16.04で動作しています。 tensorflow 0.9.0rc0でのpython 2.7に私は少し異なるエラーを取得するために管理:

Traceback (most recent call last): 
    File "<stdin>", line 1, in <module> 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py", line 185, in ParseFromString 
    self.MergeFromString(serialized) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1091, in MergeFromString 
    if self._InternalParse(serialized, 0, length) != length: 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse 
    pos = field_decoder(buffer, new_pos, end, self, field_dict) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 633, in DecodeField 
    if value._InternalParse(buffer, pos, new_pos) != new_pos: 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse 
    pos = field_decoder(buffer, new_pos, end, self, field_dict) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField 
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse 
    pos = field_decoder(buffer, new_pos, end, self, field_dict) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField 
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse 
    pos = field_decoder(buffer, new_pos, end, self, field_dict) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 489, in DecodeRepeatedField 
    value.append(_ConvertToUnicode(buffer[pos:new_pos])) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 469, in _ConvertToUnicode 
    return local_unicode(byte_str, 'utf-8') 
UnicodeDecodeError: 'utf8' codec can't decode byte 0x80 in position 18: 'utf8' codec can't decode byte 0x80 in position 18: invalid start byte in field: tensorflow.FunctionDef.Node.ret 

私は例えば、事前にこの1 https://github.com/taey16/tf/blob/master/imagenet/classify_image_graph_def.pb

感謝をこのコードの他の.pbグラフを解析することができています。

答えて

3

ここでの問題は、SavedModelプロトコルバッファをあたかもGraphDefのように解析しようとしていることです。 SavedModelにはGraphDefが含まれていますが、バイナリ形式が異なります。 tf.saved_model.loader.load()を使用する次のコードは、動作するはずです。

import tensorflow as tf 

with tf.Session(graph=tf.Graph()) as sess: 
    tf.saved_model.loader.load(
     sess, [tf.saved_model.tag_constants.SERVING], "models/TEST-3") 
+1

ありがとうございました。私は私がやって管理し、シリアライズGraphDef取得したい: sess.graph_def.SerializeToStringを()し、その後でそれをロード: 'graph_def = tf.GraphDef() graph_def.ParseFromString(graph_string) tf.import_graph_def(graph_def 、name = ")'それをロードする必要があるからです。問題は、変数が初期化されていないということです。 私はこのモデルをSparkで使用したいと思っており、このガイドに従っていますhttps://databricks.com/blog/2016/01/25/deep-learning-with-apache-spark-and-tensorflow.html – vntzy

+0

'GraphDef'変数を初期化するのに十分な情報を提供していません。そのチュートリアルはかなり古くなっています。これは 'SavedModel'に先行しています。これは、このようなことをはるかに簡単にします。自分の変数を初期化するので、私の答えに 'SavedModel'をコマンドでロードすれば十分です。 – mrry

+0

はい、それはうまくいきましたが、今は問題は、ディスクを別のプロセスで実行し、同じtfセッションを再利用することができないため、毎回ディスクから読み込むので非常に遅いということです。ディレクトリから読み込むのではなく、文字列からSavedModelを読み込む方法はありますか?それとも別の方法で? – vntzy