BaseEstimatorとClassifierMixinを拡張して "Scikit-Learn準拠"の分類子を作成しようとしています。私はdocumentation on their websiteを読んでいて、私はthis oneClassifyMixinのscikit-learn、check_estimatorで
のようなオンラインガイドに従おうとしました。check_estimator()
テストに合格する見積もりを作ることができます。ただし、クラシファイアを作成しようとするたびに、テストに合格することはありません。でもこれは誤りであるので、私はエラーを引き上げるテストのいくつかの種類を忘れていているようです
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_is_fitted, check_array
from sklearn.utils.estimator_checks import check_classifiers_classes
from sklearn.metrics import euclidean_distances
import numpy as np
class MyCustomClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, param1=2):
self.param1 = param1
def fit(self, X, y=None, **kwargs):
# Check that X and y have correct shape
X, y = check_X_y(X, y)
# Store the classes seen during fit
self.classes_ = np.unique(y)
self.X_ = X
self.y_ = y
return self
def predict(self, X):
# Check is fit had been called
check_is_fitted(self, ['X_', 'y_', 'classes_'])
# Input validation
X = check_array(X)
closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
return self.y_[closest]
MyCustomClassifier()
from sklearn.utils.estimator_checks import check_estimator
check_estimator(MyCustomClassifier)
テストに合格しない提供しScikit・ラーンtemplateは...私が取得:
Traceback (most recent call last):
File "C:/Users/vca/Google Drive/Internship/Skratch/supervised/logistic_regression.py", line 97, in <module>
check_estimator(MyCustomClassifier)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\estimator_checks.py", line 265, in check_estimator
check(name, estimator)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\testing.py", line 291, in wrapper
return fn(*args, **kwargs)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\estimator_checks.py", line 1729, in check_classifiers_regression_target
assert_raises_regex(ValueError, msg, e.fit, X, y)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 1258, in assertRaisesRegex
return context.handle('assertRaisesRegex', args, kwargs)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 176, in handle
callable_obj(*args, **kwargs)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 196, in __exit__
self.obj_name))
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 134, in _raiseFailure
raise self.test_case.failureException(msg)
AssertionError: ValueError not raised by fit
テストに合格したクラシファイアを正常に作成した人は誰ですか?
'' 'defのcheck_classifiers_regression_target(名前、推定):分類器が例外をスローした場合 #チェックを検出する方法がわから...エラーがスロー機能である回帰ターゲット ' '' を与えませんこれは "回帰ターゲットを与えられたとき"です。 – toine
私は実際に文字通りこれを修正する方法を見つけました。 –