2017-10-18 15 views
1

map()multiprocessing.Pool()からnumpy.ndarray -subclassのインスタンスのリストに使用すると、独自のクラスの新しい属性が削除されます。マルチ処理.Pool.map()はサブクラス化されたndarrayの属性を削除します

numpy docs subclassing exampleに基づいて、次の最小限の例では、問題を再現:

from multiprocessing import Pool 
import numpy as np 


class MyArray(np.ndarray): 

    def __new__(cls, input_array, info=None): 
     obj = np.asarray(input_array).view(cls) 
     obj.info = info 
     return obj 

    def __array_finalize__(self, obj): 
     if obj is None: return 
     self.info = getattr(obj, 'info', None) 

def sum_worker(x): 
    return sum(x) , x.info 

if __name__ == '__main__': 
    arr_list = [MyArray(np.random.rand(3), info=f'foo_{i}') for i in range(10)] 
    with Pool() as p: 
     p.map(sum_worker, arr_list) 

属性infoは細かい

arr_list = [MyArray(np.random.rand(3), info=f'foo_{i}') for i in range(10)] 
list(map(sum_worker, arr_list2)) 

目的を作品map()組み込みを使用して

AttributeError: 'MyArray' object has no attribute 'info' 

を落としています方法__array_finalize__()の目的は

arr = MyArray([1,2,3], info='foo') 
subarr = arr[:2] 
print(subarr.info) 

をスライスした後、属性を保持していることである。しかしマルチプロセッシングは、別々のプロセスから/にこのデータをシリアル化するためにpickleを使用しているのでPool.map()のために、この方法は何とか...

答えて

2

が動作していません本質的にはthis questionの複製です。その質問から受け入れたソリューションを適応さ

は、あなたの例は次のようになります。

from multiprocessing import Pool 
import numpy as np 

class MyArray(np.ndarray): 

    def __new__(cls, input_array, info=None): 
     obj = np.asarray(input_array).view(cls) 
     obj.info = info 
     return obj 

    def __array_finalize__(self, obj): 
     if obj is None: return 
     self.info = getattr(obj, 'info', None) 

    def __reduce__(self): 
     pickled_state = super(MyArray, self).__reduce__() 
     new_state = pickled_state[2] + (self.info,) 
     return (pickled_state[0], pickled_state[1], new_state) 

    def __setstate__(self, state): 
     self.info = state[-1] 
     super(MyArray, self).__setstate__(state[0:-1]) 

def sum_worker(x): 
    return sum(x) , x.info 

if __name__ == '__main__': 
    arr_list = [MyArray(np.random.rand(3), info=f'foo_{i}') for i in range(10)] 
    with Pool() as p: 
     p.map(sum_worker, arr_list) 

注意、二答えは哀愁がdill代わりのpickleを使用しているので、あなたの非適応元のコードでpathos.multiprocessingを使用することができるかもしれない示唆しています。私はそれをテストしたが、これは動作しませんでした。

関連する問題