2017-06-26 17 views
2

tensorflowのLinearClassifier()クラスを使用してロジスティック回帰モデルモデルを訓練し、モデル訓練中にチェックポイントのメタグラフを保存する場所を指定するmodel_dirパラメータを設定しました:訓練されたLinearClassifierをtensorflow高水準APIから復元して予測を行う方法

# Create temporary directory where metagraphs will evenually be saved 
model_dir = tempfile.mkdtemp() 

logistic_model = tf.contrib.learn.LinearClassifier(
    feature_columns=feature_columns, 
    n_classes=num_labels, model_dir=model_dir) 

私はmetagraphsからモデルを復元について読んでてきたが、高レベルAPIを使用して作成されたモデルのためにこれを行う方法については何も見つかりませんでした。 LinearClassifier()にはpredict()関数がありますが、チェックポイントのメタデータで復元されたモデルのインスタンスを使用して予測を実行する方法に関するドキュメントは見つかりません。これをどうやってやりますか?

with tf.Session() as sess: 
    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') 
    new_saver.restore(sess, 'my-save-dir/my-model-10000') 
    # Run prediction algorithm... 

私は同じ予測を実行するにはどうすればよい:モデルが復元されると、私の理解では、私はこのように、LinearClassifierクラスの機能に建てられたのすべてを欠いtf.Sessオブジェクトで働いていますということです復元されたモデルの予測を行うために高水準apiによって使用されるアルゴリズム?これにアプローチするより良い方法はありますか?

ありがとうございます。

+0

を提案されたソリューションはあなたに適していますか? –

+0

潜在的に、私はあなたの提案された修正を実装しましたが、まだ動作することを確認する必要があります。今すぐスーパースワンプされ、私はチャンスを得るときに戻って報告します。ご協力いただきありがとうございます。 –

答えて

1

LinearClassifier()には、訓練されたモデルをポイントするとモデルを復元する場合、 'model_dir'パラメータがあります。トレーニングの間
は、あなたが実行します。

logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir) 
classifier.fit(X_train, y_train, steps=10) 

推論時には、LinearClassifier()は、指定されたパスから訓練されたモデルをロードします、とあなたはfit()メソッドを使用しませんが、predict()メソッドを呼び出す:

logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir) 
y_pred = classifier.predict(X_test) 
関連する問題