2017-12-05 9 views
0

predictiontf.estimatorにコールする方法の明確な例を挙げることはできますか?入力ファイルは、ラベル列を持っていない場合はTensorflow Estimatorでディープワイドモデルを予測する

2つのことは私が予測された結果に

をエクスポートする方法input_fn

  • を調整する方法、

    1. 私には不明ですdeep_wideモデルで作業しています。

  • +1

    この[記事](HTTPSを予測する新しいinput_fnを使用します://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html)はすばらしい読書になるはずです。 – Lescurel

    答えて

    0

    私は自分自身に答え、これは質問1の場合と同様の質問

    で誰かを助けることができると思います、はい、私たちは、新たな入力機能を構築する必要があり、最も可能性が高いことがあるため、あなたの入力CSVファイルとして異なる数の列を持つことになります我々は、質問2のため

    def parse_csv(value): 
        print('Parsing', data_file) 
        columns = tf.decode_csv(value, record_defaults=_PREDICT_COLUMNS_DEFAULTS) 
        features = dict(zip(_PREDICT_COLUMNS, columns)) 
    
        return features 
    
    def predict_input_fn(data_file): 
        assert tf.gfile.Exists(data_file), ('%s not found. Please make sure the path is correct.' % data_file) 
    
        dataset = tf.data.TextLineDataset(data_file) 
        dataset = dataset.map(parse_csv, num_parallel_calls=5) 
        dataset = dataset.batch(1) # this is very important to keep the rank right 
        iterator = dataset.make_one_shot_iterator() 
        features = iterator.get_next() 
        return features 
    

    そして、ラベル列を削除するには、結果

    def predict(model): 
        start_from_id = 892 
        test_csv = [] 
    
        results = model.predict(
         input_fn=lambda: predict_input_fn(data_file='test.csv') 
        ) 
    
        # for result in results: 
        #  print 'result: {}'.format(result) 
    
    関連する問題