2013-03-23 4 views
13

マルチクラス分類の問題を解決し、一般化ブーストモデル(Rのgbmパッケージ)を使用しようとしています。私が直面した問題:キャレットのtrainの機能がmethod="gbm"の場合、マルチクラスデータで正しく機能しないようです。簡単な例を以下に示します。マルチクラス分類のためのgbmメソッドを使用したキャレットの使用

library(gbm) 
library(caret) 
data(iris) 
fitControl <- trainControl(method="repeatedcv", 
          number=5, 
          repeats=1, 
          verboseIter=TRUE) 
set.seed(825) 
gbmFit <- train(Species ~ ., data=iris, 
       method="gbm", 
       trControl=fitControl, 
       verbose=FALSE) 
gbmFit 

出力は

+ Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
predictions failed for Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
- Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
+ Fold1.Rep1: interaction.depth=2, shrinkage=0.1, n.trees=150 
... 
+ Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
predictions failed for Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
- Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
Aggregating results 
Selecting tuning parameters 
Fitting interaction.depth = numeric(0), n.trees = numeric(0), shrinkage = numeric(0) on full training set 
Error in if (interaction.depth < 1) { : argument is of length zero 

ですけれども、私はキャレットラッパーなしGBMを使用しようとすると、私は素晴らしい結果を得ます。

set.seed(1365) 
train <- createDataPartition(iris$Species, p=0.7, list=F) 
train.iris <- iris[train,] 
valid.iris <- iris[-train,] 
gbm.fit.iris <- gbm(Species ~ ., data=train.iris, n.trees=200, verbose=FALSE) 
gbm.pred <- predict(gbm.fit.iris, valid.iris, n.trees=200, type="response") 
gbm.pred <- as.factor(colnames(gbm.pred)[max.col(gbm.pred)]) ##! 
confusionMatrix(gbm.pred, valid.iris$Species)$overall 

FYI、##!でマークされた行のコードは、最も可能性の高いクラスの要因にpredict.gbmによって返されたクラス確率の行列に変換します。出力は

 Accuracy   Kappa AccuracyLower AccuracyUpper AccuracyNull AccuracyPValue McnemarPValue 
    9.111111e-01 8.666667e-01 7.877883e-01 9.752470e-01 3.333333e-01 8.467252e-16   NaN 

どのような提案は、マルチクラスのデータでgbmで正しく動作するようにするには?

UPD:

sessionInfo() 
R version 2.15.3 (2013-03-01) 
Platform: x86_64-pc-linux-gnu (64-bit) 

locale: 
[1] LC_CTYPE=en_US.UTF-8  LC_NUMERIC=C    LC_TIME=en_US.UTF-8  LC_COLLATE=en_US.UTF-8  
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 LC_PAPER=C     LC_NAME=C     
[9] LC_ADDRESS=C    LC_TELEPHONE=C    LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C  

attached base packages: 
[1] splines stats  graphics grDevices utils  datasets methods base  

other attached packages: 
[1] e1071_1.6-1  class_7.3-5  gbm_2.0-8  survival_2.36-14 caret_5.15-61 reshape2_1.2.2 plyr_1.8   
[8] lattice_0.20-13 foreach_1.4.0 cluster_1.14.3 compare_0.2-3 

loaded via a namespace (and not attached): 
[1] codetools_0.2-8 compiler_2.15.3 grid_2.15.3  iterators_1.0.6 stringr_0.6.2 tools_2.15.3 
+0

を生成していますか? 825および1365? – agstudy

+1

それは問題ですか? 825 - は私がフォーム[Caret.r-forge.r-project.org](http://caret.r-forge.r-project.org/training.html)、1365 - 種を取ったサンプルコードの種です私は自分のプロジェクトで使った。 – maruan

答えて

6

これは、私が今働いている問題です。

これは、sessionInfo()の結果を投稿した場合に役立ちます。

また、最新のgbmをhttps://code.google.com/p/gradientboostedmodels/から外すと、問題が解決する場合があります。

マックス

+0

問題はhttps://code.google.com/p/gradientboostedmodels/issues/detail?id=12に関連しています。私は回避策がありますが、多項式データの問題に過ぎないので、避けたいのです。メンテナーに再度連絡して、イータがあるかどうかを確認します。 – topepo

+0

gbmをdevtools https://github.com/hadley/devtools/issues/419で更新した後でドキュメントを再読み込みすることについてよく知られている問題があるようです。 –

3

更新:キャレットはマルチクラス分類を行うことができます。

クラスラベルが英数字形式(文字で始まる)であることを確認してください。

たとえば、データにラベル「1」、「2」、「3」がある場合は、これらを「Seg1」、「Seg2」および「Seg3」に変更します。

2

更新:元のコードを実行し、ちょうど質問、なぜあなたは、2つの異なる種を使用している次の出力

+ Fold1.Rep1: shrinkage=0.1, interaction.depth=1, n.trees=150 
- Fold1.Rep1: shrinkage=0.1, interaction.depth=1, n.trees=150 
... 
... 
... 
+ Fold5.Rep1: shrinkage=0.1, interaction.depth=3, n.trees=150 
- Fold5.Rep1: shrinkage=0.1, interaction.depth=3, n.trees=150 
Aggregating results 
Selecting tuning parameters 
Fitting n.trees = 50, interaction.depth = 2, shrinkage = 0.1 on full training set 
> gbmFit 
Stochastic Gradient Boosting 

150 samples 
    4 predictor 
    3 classes: 'setosa', 'versicolor', 'virginica' 

No pre-processing 
Resampling: Cross-Validated (5 fold, repeated 1 times) 

Summary of sample sizes: 120, 120, 120, 120, 120 

Resampling results across tuning parameters: 

    interaction.depth n.trees Accuracy Kappa Accuracy SD 
    1     50  0.9400000 0.91 0.04346135 
    1     100  0.9400000 0.91 0.03651484 
    1     150  0.9333333 0.90 0.03333333 
    2     50  0.9533333 0.93 0.04472136 
    2     100  0.9533333 0.93 0.05055250 
    2     150  0.9466667 0.92 0.04472136 
    3     50  0.9333333 0.90 0.03333333 
    3     100  0.9466667 0.92 0.04472136 
    3     150  0.9400000 0.91 0.03651484 
    Kappa SD 
    0.06519202 
    0.05477226 
    0.05000000 
    0.06708204 
    0.07582875 
    0.06708204 
    0.05000000 
    0.06708204 
    0.05477226 

Tuning parameter 'shrinkage' was held constant at a value of 0.1 
Accuracy was used to select the optimal model using the 
largest value. 
The final values used for the model were n.trees = 
50, interaction.depth = 2 and shrinkage = 0.1. 
> summary(gbmFit) 
         var rel.inf 
Petal.Length Petal.Length 74.1266408 
Petal.Width Petal.Width 22.0668983 
Sepal.Width Sepal.Width 3.2209288 
Sepal.Length Sepal.Length 0.5855321 
関連する問題