2017-07-29 28 views
0

私は次のコードをwritainに持っています。torch.Tensorからアイテムを削除

N最大スコアのインデックスをscoresから取得したいと思います。

現在の最大値をscoresから繰り返し削除して最大値を再度取得する必要がありますが、それを行う適切な方法が見つからないようです。

nqs=dataset['question']:size(1); 
scores=torch.Tensor(nqs,noutput); 
qids=torch.LongTensor(nqs); 
for i=1,nqs,batch_size do 
    xlua.progress(i, nqs) 
    r=math.min(i+batch_size-1,nqs); 
    scores[{{i,r},{}}],qids[{{i,r}}]=forward(i,r); 
-- print(scores) 
end 

tmp,pred=torch.max(scores,2); 

答えて

1

あなたが表示されたコード(特にfoorループが)本当にあなたがやりたいたいに関連していないようですので、私は、私が誤解していなかった願っています。とにかく、ここで私はそれをやるでしょう。

sr=scores:view(-1,scores:size(1)*scores:size(2)) 
val,id=sr:sort() 
--val is a row vector with the values stored in increasing order 
--id will be the corresponding index in sr 
--now you can slice val and id from the end to find the N values you want, then you can recover the original index in the scores matrix simply with 
col=(index-1)%scores:size(2)+1 
row=math.ceil(index/scores:size(2)) 

希望します。

+0

"N値を見つけるために、最後からスライスのvalとid"の部分を詳しく調べることはできますか? – ytrewq

+0

私は 'val {{{1}、{val:size(2)-N + 1、val:size(2)}} 'のようなもので、' id'と同じものです。ソートされたテンソルの終わりに。 – Ash

+0

これは重複の問題を解決していないことに注意してください( 'スコア 'に*が*含まれていれば、最大値の2倍、例えば' 100'のようになります)。しかし、それは問題ではないと思いますあなたの質問に記載されています。 – Ash

関連する問題