2016-07-18 11 views

答えて

2

__getstate____getstate__を使用できます。内部量の大部分は型配列またはスカラであり、したがってhdf5に適しています。 最後にもう__getstate__が返す関数が関数であるため、hdf5ストロークの場合はpickle.dumpsで文字列に変換することができます。

興味深いことに、ソースコードがKDTreehereで、返された値が__getstate__であることがわかります。

from sklearn.neighbors import KDTree 
import h5py 
import pickle 

""" 
You may find the source code of KDTree from link below 
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/neighbors/binary_tree.pxi 
""" 

__all__ = ["KDTreeH5"] 


class KDTreeH5(KDTree): 
    def dump(self, file): 
     """ 
     file: str or HDF group 
     """ 
     if not isinstance(file, h5py.Group): 
      file = h5py.File(file) 

     state = list(self.__getstate__()) 
     assert len(state) == 12 
     # convert dist_metric to string for hdf5 storage. 
     state[-1] = pickle.dumps(state[-1]) 
     for i, v in enumerate(state): 
      file[str(i)] = v 

    @classmethod 
    def load(cls, file): 
     """ 
     file: str or HDF group 
     """ 
     if not isinstance(file, h5py.Group): 
      file = h5py.File(file, 'r') 

     state = [None] * 12 
     for i in range(12): 
      state[i] = file[str(i)].value 
     # recover dist_metric from string. 
     state[-1] = pickle.loads(state[-1]) 

     obj = cls.__new__(cls) 
     obj.__setstate__(state) 
     return obj 
関連する問題