2017-10-31

LightGBM: Sklearn and native API equivalence

I am trying out LightGBM with the Training API (http://lightgbm.readthedocs.io/en/latest/Python-API.html#training-api) and the Scikit-learn API (http://lightgbm.readthedocs.io/en/latest/Python-API.html#scikit-learn-api).

I could not establish a clear mapping between the two APIs, as the example below highlights. The basic idea is to train on 50% of a synthetic data set and predict on the other half.

import numpy as np 
import lightgbm as lgbm 

# Generate Data Set 
xs = np.linspace(0, 10, 100).reshape((-1, 1)) 
ys = xs**2 + 4*xs + 5.2 
ys = ys.reshape((-1,)) 

# LGBM configuration 
alg_conf = { 
    "num_boost_round":25, 
    "max_depth" : 3, 
    "num_leaves" : 31, 
    'learning_rate' : 0.1, 
    'boosting_type' : 'gbdt', 
    'objective' : 'regression_l2', 
    "early_stopping_rounds": None, 
} 

# Calling Regressor using scikit-learn API 
sk_reg = lgbm.sklearn.LGBMRegressor(
    num_leaves=alg_conf["num_leaves"], 
    n_estimators=alg_conf["num_boost_round"], 
    max_depth=alg_conf["max_depth"], 
    learning_rate=alg_conf["learning_rate"], 
    objective=alg_conf["objective"] 
) 
sk_reg.fit(xs[::2], ys[::2]) 

print("Scikit-learn API results") 
print(sk_reg.predict(xs[1::2])) 


# Calling Regressor using native API 
train_dataset = lgbm.Dataset(xs[::2], ys[::2]) 
lg_reg = lgbm.train(alg_conf.copy(), train_dataset) 

print("Native API results") 
print(lg_reg.predict(xs[1::2])) 

Output

Scikit-learn API results 
[ 14.35693851 14.35693851 14.35693851 14.35693851 14.35693851 
    14.35693851 14.35693851 14.35693851 14.35693851 14.35693851 
    25.37944751 25.37944751 25.37944751 25.37944751 25.37944751 
    35.10572544 35.10572544 35.10572544 35.10572544 35.10572544 
    46.50667974 46.50667974 46.50667974 46.50667974 46.50667974 
    59.44952419 59.44952419 59.44952419 59.44952419 59.44952419 
    75.42846332 75.42846332 75.42846332 75.42846332 75.42846332 
    109.4610814 109.4610814 109.4610814 109.4610814 109.4610814 
    109.4610814 109.4610814 109.4610814 109.4610814 109.4610814 
    109.4610814 109.4610814 109.4610814 109.4610814 109.4610814 ] 
Native API results 
[ 22.55947971 22.55947971 22.55947971 22.55947971 22.55947971 
    22.55947971 22.55947971 22.55947971 22.55947971 22.55947971 
    22.55947971 22.55947971 22.55947971 22.55947971 22.55947971 
    22.55947971 22.55947971 22.55947971 22.55947971 22.55947971 
    45.33537795 45.33537795 45.33537795 45.33537795 45.33537795 
    91.6376959 91.6376959 91.6376959 91.6376959 91.6376959 
    91.6376959 91.6376959 91.6376959 91.6376959 91.6376959 
    91.6376959 91.6376959 91.6376959 91.6376959 91.6376959 
    91.6376959 91.6376959 91.6376959 91.6376959 91.6376959 
    91.6376959 91.6376959 91.6376959 91.6376959 91.6376959 ] 

Question: how can I find a clear equivalence between the parameters of the two APIs?

Thank you.

Answer

I got the answer on the LightGBM GitHub, and it works. Sharing the result below:

import numpy as np 
import lightgbm as lgbm 

# Generate Data Set 
xs = np.linspace(0, 10, 100).reshape((-1, 1)) 
ys = xs**2 + 4*xs + 5.2 
ys = ys.reshape((-1,)) 

# Or you could add to your alg_conf "min_child_weight": 1e-3, "min_child_samples": 20. 

# LGBM configuration 
alg_conf = { 
    "num_boost_round":25, 
    "max_depth" : 3, 
    "num_leaves" : 31, 
    'learning_rate' : 0.1, 
    'boosting_type' : 'gbdt', 
    'objective' : 'regression_l2', 
    "early_stopping_rounds": None, 
    "min_child_weight": 1e-3, 
    "min_child_samples": 20 
} 

# Calling Regressor using scikit-learn API 
sk_reg = lgbm.sklearn.LGBMRegressor(
    num_leaves=alg_conf["num_leaves"], 
    n_estimators=alg_conf["num_boost_round"], 
    max_depth=alg_conf["max_depth"], 
    learning_rate=alg_conf["learning_rate"], 
    objective=alg_conf["objective"], 
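    # min_child_weight / min_child_samples are the wrapper's own names for 
    # the native min_sum_hessian_in_leaf / min_data_in_leaf parameters. 
    min_child_weight=alg_conf["min_child_weight"], 
    min_child_samples=alg_conf["min_child_samples"] 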
) 
sk_reg.fit(xs[::2], ys[::2]) 

print("Scikit-learn API results") 
print(sk_reg.predict(xs[1::2])) 


# Calling Regressor using native API 
train_dataset = lgbm.Dataset(xs[::2], ys[::2]) 
lg_reg = lgbm.train(alg_conf.copy(), train_dataset) 

print("Native API results") 
print(lg_reg.predict(xs[1::2])) 

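The difference is the fix: add "min_child_weight": 1e-3 and "min_child_samples": 20 to alg_conf, and pass the same values to the scikit-learn regressor, so that both APIs train with the same leaf constraints.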
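For the broader question of parameter equivalence, it may help to keep the name mapping in one place. The sketch below assumes the aliases listed in the LightGBM parameter documentation; `SKLEARN_TO_NATIVE` and `to_native_params` are hypothetical names made up for illustration, not part of LightGBM, and the table is not exhaustive.

```python
# Wrapper-style name -> native parameter name (documented LightGBM aliases).
# Hypothetical lookup table for illustration; it is not exhaustive.
SKLEARN_TO_NATIVE = {
    "n_estimators": "num_iterations",
    "min_child_weight": "min_sum_hessian_in_leaf",
    "min_child_samples": "min_data_in_leaf",
    "min_split_gain": "min_gain_to_split",
    "subsample": "bagging_fraction",
    "subsample_freq": "bagging_freq",
    "colsample_bytree": "feature_fraction",
    "reg_alpha": "lambda_l1",
    "reg_lambda": "lambda_l2",
}


def to_native_params(sklearn_params):
    """Rename wrapper-style keys to native names and drop None values."""
    native = {}
    for key, value in sklearn_params.items():
        if value is None:
            continue  # unset options are left to the library defaults
        native[SKLEARN_TO_NATIVE.get(key, key)] = value
    return native


wrapper_style = {
    "n_estimators": 25,
    "max_depth": 3,
    "num_leaves": 31,
    "learning_rate": 0.1,
    "objective": "regression_l2",
    "min_child_weight": 1e-3,
    "min_child_samples": 20,
    "early_stopping_rounds": None,
}
native_style = to_native_params(wrapper_style)
print(native_style)
```

Because LightGBM accepts the aliases directly, this renaming is mostly a readability aid. What actually changes the predictions is whether both APIs end up with the same values for every parameter, including the ones left at their defaults.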


With the LightGBM scikit-learn API we can use 'print(sk_reg)' to get the LightGBM model/params. Do you know how to do this with the native API? print(lg_reg) only returns a reference to a Booster object.