2017-11-10 12 views
0

BaseEstimatorとClassifierMixinを拡張して "Scikit-Learn準拠"の分類子を作成しようとしています。私はdocumentation on their websiteを読んでいて、私はthis oneClassifyMixinのscikit-learn、check_estimatorで

のようなオンラインガイドに従おうとしました。check_estimator()テストに合格する見積もりを作ることができます。ただし、クラシファイアを作成しようとするたびに、テストに合格することはありません。でもこれは誤りであるので、私はエラーを引き上げるテストのいくつかの種類を忘れていているようです

from sklearn.base import BaseEstimator, ClassifierMixin 
from sklearn.utils.validation import check_X_y, check_is_fitted, check_array 
from sklearn.utils.estimator_checks import check_classifiers_classes 
from sklearn.metrics import euclidean_distances 
import numpy as np 

class MyCustomClassifier(BaseEstimator, ClassifierMixin): 
    def __init__(self, param1=2): 
     self.param1 = param1 

    def fit(self, X, y=None, **kwargs): 
     # Check that X and y have correct shape 
     X, y = check_X_y(X, y) 

     # Store the classes seen during fit 
     self.classes_ = np.unique(y) 

     self.X_ = X 
     self.y_ = y 

     return self 


    def predict(self, X): 
     # Check is fit had been called 
     check_is_fitted(self, ['X_', 'y_', 'classes_']) 

     # Input validation 
     X = check_array(X) 

     closest = np.argmin(euclidean_distances(X, self.X_), axis=1) 

     return self.y_[closest] 

MyCustomClassifier() 

from sklearn.utils.estimator_checks import check_estimator 

check_estimator(MyCustomClassifier) 

テストに合格しない提供しScikit・ラーンtemplateは...私が取得:

Traceback (most recent call last): 
    File "C:/Users/vca/Google Drive/Internship/Skratch/supervised/logistic_regression.py", line 97, in <module> 
    check_estimator(MyCustomClassifier) 
    File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\estimator_checks.py", line 265, in check_estimator 
    check(name, estimator) 
    File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\testing.py", line 291, in wrapper 
    return fn(*args, **kwargs) 
    File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\estimator_checks.py", line 1729, in check_classifiers_regression_target 
    assert_raises_regex(ValueError, msg, e.fit, X, y) 
    File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 1258, in assertRaisesRegex 
    return context.handle('assertRaisesRegex', args, kwargs) 
    File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 176, in handle 
    callable_obj(*args, **kwargs) 
    File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 196, in __exit__ 
    self.obj_name)) 
    File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 134, in _raiseFailure 
    raise self.test_case.failureException(msg) 
AssertionError: ValueError not raised by fit 

テストに合格したクラシファイアを正常に作成した人は誰ですか?

+0

'' 'defのcheck_classifiers_regression_target(名前、推定):分類器が例外をスローした場合 #チェックを検出する方法がわから...エラーがスロー機能である回帰ターゲット ' '' を与えませんこれは "回帰ターゲットを与えられたとき"です。 – toine

+0

私は実際に文字通りこれを修正する方法を見つけました。 –

答えて

1

これを修正する方法を知りました。 1つはcheck_classification_targetsをフィットで実行する必要があります。回帰ターゲットを使用すると、明らかにエラーが発生します。

import numpy as np 
from sklearn.base import BaseEstimator, ClassifierMixin 
from sklearn.metrics import euclidean_distances 
from sklearn.utils.multiclass import check_classification_targets 
from sklearn.utils.validation import check_X_y, check_is_fitted, check_array 

class MyCustomClassifier(BaseEstimator, ClassifierMixin): 
    def __init__(self, param1=2): 
     self.param1 = param1 

    def fit(self, X, y=None, **kwargs): 
     # Check that X and y have correct shape 
     X, y = check_X_y(X, y) 
     check_classification_targets(y) 

     # Store the classes seen during fit 
     self.classes_ = np.unique(y) 

     self.X_ = X 
     self.y_ = y 

     return self 

    def predict(self, X): 
     # Check is fit had been called 
     check_is_fitted(self, ['X_', 'y_', 'classes_']) 

     # Input validation 
     X = check_array(X) 

     closest = np.argmin(euclidean_distances(X, self.X_), axis=1) 

     return self.y_[closest] 
+0

その余分な 'check_array(X)'の必要性をチェックするテストを実際に実行することによって。これは 'check_X_y()'で行われます –

+0

実際、check_X_y()で実行され、テストはそのテストなしで実行されます。しかし、新しいデータで予測を使用すると、データがそのフォーマットでうまく動作しない場合はどうしたらよいですか?たとえば、既知のテストデータに対して予測を使用しない場合はどうなりますか?余分なcheck_array()を保つことはより安全ではないでしょうか? –

+0

はい。 predict()はokです。私はfit()について話していました。誤解をおかけして申し訳ありません。私はcheck_X_yがそれを処理し、fit()でのみ行われるとコメントしました。だから私はfit()について話していることを明確にした。 –

関連する問題