2017-06-30 14 views
1

私は、与えられたscikit-learn見積もりの​​ためにtrainableのパラメータを検出し、可能であればカテゴリ変数の許容値(および連続するものについては合理的な間隔)を見つける必要があるプロジェクトの作業を開始しました。パラメータグリッドでどの値が許容されているかを検出する方法はありますか? (sklearn)

estimator.get_params()を使用してパラメータを使用して辞書をフェッチしてから、estimator.set_params(**{'var1':val1, 'var2':val2})などを使用して値を設定することができます。

たとえば、KNN分類器では、 {'metric': 'minkowski', 'algorithm': 'auto', 'n_neighbors': 10, 'n_jobs': 1, 'p': 2, 'metric_params': None, 'weights': 'uniform', 'leaf_size': 30}というパラメータがあります。

ここでは、カテゴリ(strタイプ)、連続型(float)、離散型(int)などの推論する値の型を使用できます。おそらく関連する問題の1つは、デフォルトがNoneTypeに設定されているパラメータですが、正当な理由でこれらのパラメータには触れないかもしれません。

ここで、チャレンジは、例えば、使用するためのパラメータグリッドを推測し、定義するようになる。 RandomizedSearchCV。離散変数および連続変数の場合、問題は扱いやすい。 tryの組み合わせ - scipy.statsモジュールと一緒にexceptブロックは、間隔を制限する可能性がデフォルト値の周りの「周辺」にある(しかし、例えばn_jobsいくつかのクレイジーな値に設定されていないと同時に慎重であることに - ことハードコーディングされているか、明示的に後で設定する必要があるかもしれません)。似たような経験があり、あなたの袖の上にいくつかのヒント/トリックを持っているなら、私はそれらについて聞いてみたいと思います。

しかし実際の問題は次のとおりです。 algorithm実際には許容値は{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’} ??

私は問題に探し始めている、と我々はいくつかの非許容値に設定しようとすると、おそらく我々は、我々が得るエラーメッセージを解析することができますか?私は良いアイデアを探してくださいここにいます、(...私は私がしなければならない場合、それはかなり洗練ようです)私は手動でこれを行うことを避けるためにしたいと

ありがとうございました!

+0

自分自身への注意:これは非常に難しい/解決できない問題である可能性があります。私はAPIとソースコードを覗いてみた。オートスケルトンはこれを解決します。手動の(ハードコード化された)ソリューションが今のところ行く方法だと思われます。 – Magnus

+0

興味深い問題があります。[署名とデフォルトパラメータの解析](https://stackoverflow.com/questions/2677185/how-can-i-read-a-functions-signature-including-default-argument-values)を除いて、私は[this](https://stackoverflow.com/questions/713138/getting-the-docstring-from-a-function)のようなscikit-learnのドキュメントストリングを解析します。しようとするもう一つのことは、文字列化された関数を解析することです。見積もり者の '__init__'だが、そこでは何のチェックも行われていないので、見るべきかもしれない全体的な階層があるので、それは面倒なものだ。 – mkaran

+0

こんにちは!興味深いテーマを見つけてうれしいです。はい、それは/私が検討している/考えている(ドキュメントを解析する)オプションの1つです。しかし、私が心配しているのは、ドキュメントストリングの作成方法が一貫していることです。私は少し時間をかけてパーサを実装し、ドキュメントストリングの束をテストします。 – Magnus

答えて

0

私は、彼らがsklearnに各推定のために書かれている方法を何セットコンベンションWRTがないので、私が見ていた特定の例に対する解決策は、しかし、それは他のdoc-の文字列にも一般化しません発見しました。

したがって、私は自分の「解決策」を投稿して、他の人が引き継ぎ、おそらくそれを改善できるようにします。次のコードを参照してください:

import re 
from pprint import pprint 
from sklearn.neighbors import KNeighborsClassifier 
knn = KNeighborsClassifier() 
doc = knn.__doc__ # Get the doc string 
#from sklearn.svm import SVC 
#svc = SVC() 
#doc = svc.__doc__ 
pattern = "([a-zA-Z_]+\s:\s)|(-\s*)'([a-zA-Z_]+)'" # Define search pattern 
re.compile(pattern) 
matches = re.findall(pattern, doc) 

clf_params = {} 
previous_param = '' 
for param, _, value in matches: 
    if ":" in param and param[-4]!="_": # 'Hack-y' 
     if param not in clf_params.keys(): 
      clf_params[param] = list() 
      previous_param = param 
     else: 
      if len(value)>0: 
       clf_params[previous_param].append(value) 
pprint(clf_params) 

このコードが印刷さ

正しい
{'algorithm : ': ['ball_tree', 'kd_tree', 'brute', 'auto'], 
'leaf_size : ': [], 
'metric : ': [], 
'metric_params : ': [], 
'n_jobs : ': [], 
'n_neighbors : ': [], 
'p : ': [], 
'weights : ': ['uniform', 'distance']} 

。私たちはSVC().__doc__に対して同じ手順を繰り返す場合

しかし、我々はそれが失敗したことがわかります。

私は誰かが、これはやや重宝願っています。

関連する問題