2017-07-31 5 views
2

私はsklearnインターフェイスを使用してXGBRegressorモデルを訓練しました。次のように関連するコードは次のとおりです。XGBRegressor.predict()はpickle.load()の後にエラーを返します

def xgb_regressor_wrapper(X_train, y_train): 
    xgb_regressor = XGBRegressor(objective='reg:linear', n_estimators=1000, learning_rate=0.01, base_score=0.005) 
    xgb_regressor.fit(X=X_train, y=y_train) #, eval_set=[(X_test, y_test)], verbose=True) 
    return xgb_regressor 

def save_regressor(station, feature, regressor): 
    fname = generate_regressor_fname(station, feature) 
    pickle.dump(regressor, open(fname, "wb")) 

# regressor_list dict contains wrapper functions 
# I currently have XGBRegressor and CatBoostRegressor in the list. 
regressor_wrapper = regressor_list.get(name) 

# Create and fit XGBRegressor 
regressor = regressor_wrapper(X_train, y_train) 

# Save regressor 
save_regressor(station_id, feature, best_regressor) 

しばらくして、私は回帰をリロードするには、次のコードを使用して、予測を行います。

def load_regressor(station, feature): 
    fname = generate_regressor_fname(station, feature) 
    return pickle.load(open(fname, "rb")) 

# Load the regressor 
regressor = load_regressor(station_id, feature) 

# Do the prediction 
y_predict = regressor.predict(X_test) 

私は次のエラーを取得する:

File "regressor_stuff.py", line 169, in regressor_check_for_station_feature 
    y_predict = regressor.predict(X_test) 
    File "D:\Anaconda\envs\Deep\lib\site-packages\xgboost\sklearn.py", line 268, in predict 
    return self.booster().predict(test_dmatrix, 
TypeError: 'str' object is not callable 

デバッグ後、self.boosterは実際に文字列 'gbtree'を格納しています。洗練された機能の回帰訓練の後(これは数日かかりましたが)、これはクールではありませんでした。

なぜこのようなことが起こったのでしょうか?

私の現在の回避策は次のようにXGBBoosterを再構築することである。

# Load the regressor 
if isinstance(regressor, XGBRegressor): 
    regressor = XGBRegressor() 
    r = pickle.load(open(fname, "rb")) 
    print r.get_xgb_params() 
    regressor._Booster = r._Booster 
    regressor.set_params(**r.get_xgb_params()) 

# Do the prediction 
y_predict = regressor.predict(X_test) 

は、私はあなたがあなたのトレーニングやスコアリング環境でxgboostバージョンの不一致を持っているかもしれないと思うあなた

Kursat

答えて

1

ありがとうございます。私は同じ問題に遭遇し、xgboost == 0.6を使ってxgboost == 0.6a2を得点したことがわかりました。

+0

こんにちは@InfiniteExistence、 良いキャッチです。 XGBoostには2つの異なる環境があります。私ができることを確認します。 ありがとう K. –

関連する問題