2

私はtflearnモデルのハイパーパラメータでグリッド検索を実行します。 tflearn.DNNが作成したモデルがsklearnのGridSearchCVの期待と互換性がないようです:sklearnのGridSearchCVでtflearnを実行できません

from sklearn.grid_search import GridSearchCV 
import tflearn 
import tflearn.datasets.mnist as mnist 
import numpy as np 

X, Y, testX, testY = mnist.load_data(one_hot=True) 

encoder = tflearn.input_data(shape=[None, 784]) 
encoder = tflearn.fully_connected(encoder, 256) 
encoder = tflearn.fully_connected(encoder, 64) 

# Building the decoder 
decoder = tflearn.fully_connected(encoder, 256) 
decoder = tflearn.fully_connected(decoder, 784) 

# Regression, with mean square error 
net = tflearn.regression(decoder, optimizer='adam', learning_rate=0.01, 
         loss='mean_square', metric=None) 

model = tflearn.DNN(net, tensorboard_verbose=0) 

grid_hyperparams = {'optimizer': ['adam', 'sgd', 'rmsprop'], 'learning_rate': np.logspace(-4, -1, 4)} 
grid = GridSearchCV(model, param_grid=grid_hyperparams, scoring='mean_squared_error', cv=2) 
grid.fit(X, X) 

私はエラーを取得する:

TypeError         Traceback (most recent call last) 
<ipython-input-3-fd63245cd0a3> in <module>() 
    22 grid_hyperparams = {'optimizer': ['adam', 'sgd', 'rmsprop'], 'learning_rate': np.logspace(-4, -1, 4)} 
    23 grid = GridSearchCV(model, param_grid=grid_hyperparams, scoring='mean_squared_error', cv=2) 
---> 24 grid.fit(X, X) 
    25 
    26 

/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/grid_search.py in fit(self, X, y) 
    802 
    803   """ 
--> 804   return self._fit(X, y, ParameterGrid(self.param_grid)) 
    805 
    806 

/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/grid_search.py in _fit(self, X, y, parameter_iterable) 
    539           n_candidates * len(cv))) 
    540 
--> 541   base_estimator = clone(self.estimator) 
    542 
    543   pre_dispatch = self.pre_dispatch 

/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/base.py in clone(estimator, safe) 
    45        "it does not seem to be a scikit-learn estimator " 
    46        "as it does not implement a 'get_params' methods." 
---> 47        % (repr(estimator), type(estimator))) 
    48  klass = estimator.__class__ 
    49  new_object_params = estimator.get_params(deep=False) 

TypeError: Cannot clone object '<tflearn.models.dnn.DNN object at 0x7fead09948d0>' (type <class 'tflearn.models.dnn.DNN'>): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' methods. 

私はGridSearchCVに適したオブジェクトを取得する可能性がどのように任意のアイデアを?

答えて

0

私はtflearnに関する経験はありませんが、私はPythonとSklearnの基本的な背景をいくつか持っています。 StackOverflowのスクリーンショットのエラーから判断すると、tflearn ** models **にはscikit-learn estimatorと同じメソッドまたは属性がありません。これは、うまくいけば、シキット学習の推定量ではないので、理解できる。

Sklearnのグリッド検索CVは、scikit-learn見積もりと同じメソッドおよび属性(たとえば、fit()およびpredict()メソッドなど)を持つオブジェクトでのみ機能します。 sklearnのグリッド検索を使用する予定がある場合は、sklearn見積もりの​​代わりにドロップとして機能させるために、独自のラッパーをtflearnモデルに書き込む必要があります。つまり、同じクラスの独自のクラスを作成する必要があります他のscikit-learn見積もりと同様に、実際にこれらのメソッドを実装するためにtflearnライブラリを使用します。

これを行うには、基本的なscikit-learn見積もり(よく知っていることが望ましい)のコードを理解し、fit()、predict()、get_params()などのメソッドが実際にオブジェクトに対して行うことを確認し、その内部。次に、tflearnライブラリを使用して独自のクラスを作成します。

簡単なGoogle検索では、このリポジトリが「テンソルフローフレームワークのための薄いサイクリングスタイルのラッパー」DSLituiev/tflearn(https://github.com/DSLituiev/tflearn)であることがわかります。これがGrid Searchの代替品として機能するかどうかはわかりませんが、一見価値があります。

関連する問題