2017-08-21 13 views
4

配列の一部のセグメントに値を割り当てたいとします。 セグメントのインデックスがタプル(start_idx、end_idx)になっています。 セグメントは、互いにオーバーレイするか、またはサブセグメントである可能性があります。オーバーレイされたインデックスのスライスでnumpy配列に値を代入

a = np.zeros(12) 
segments = np.array([(0, 3), (1, 2), (6, 8), (8, 10)]) 
a[segments] = 1 

結果は次のとおりです。

a 
>> array([1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0]) 

がどのように私は、この出力を得るために、すべてのセグメントをマスクすることができます。

a 
>> array([1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0]) 
+0

希望する出力が正しいですか? '[3]'要素が間違っているようですね!それは '0'ではないでしょうか? – MSeifert

+0

@MSeifertセグメントの終了を含むようです。 – Divakar

+0

@Divakarしかし、「10」はどのように処理されるべきですか?通常、少なくとも 'a = np.zeros(10)'では 'IndexError'になります。 – MSeifert

答えて

1

一つのオプションがにありますセグメントを個だけループしてください、および変換範囲実際のインデックスへ

a = np.zeros(10) 
segments = np.array([(0, 3), (1, 2), (6, 8), (8, 10)]) 

a[[i for s in segments for i in range(*s)]] = 1  
a 
# array([ 1., 1., 1., 0., 0., 0., 1., 1., 1., 1.]) 
1

を試してみて、この:

a = np.zeros(10) 
segments = np.array([(0, 3), (1, 2), (6, 8), (8, 10)]) 
a[range(3)+range(1,2)+range(6,8)+range(8,10)] = 1 
print (a) 
+0

手作業で割り当てることはできますか?私はリストジェネレータで範囲を作成することはできますが、それらをフィードする方法は? –

+0

その場合は、 –

+0

をループする必要があります。これは、Python-3.xで 'TypeError:サポートされていないオペランドの型が+:' range 'と' range''にスローされることに注意してください。 – MSeifert

2

ここthis postから借用されているという考えを持つ1つのベクトル化されたアプローチだ -

def segment_arr(segments, L): # L being length of output array 
    s = np.zeros(L,dtype=int) 
    stop = segments[:,1]+1 
    np.add.at(s,segments[:,0],1) 
    np.add.at(s,stop[stop<len(s)],-1) 
    return (s.cumsum()>0).astype(int) 

サンプル実行 -

In [298]: segments = np.array([(0, 3), (1, 2), (6, 8), (8, 10)]) 

In [299]: segment_arr(segments, L=12) 
Out[299]: array([1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0]) 
+1

少なくとも*実際には「与えられた期待される生産量」を生み出す。 :) – MSeifert

1

ちょうど言及するが、些細な解決策:for -loopをで使う10とスライスに割り当てます:

import numpy as np 
a = np.zeros(12) 
segments = np.array([(0, 3), (1, 2), (6, 8), (8, 10)]) 

for seg in segments.tolist(): # the "tolist" is just an optimization here, you *could* omit it. 
    a[seg[0]: seg[1]+1] = 1 # or just "seq[1]" if you want to exclude the end point 
print(a) 
# array([ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 0.]) 
関連する問題