2017-12-18 16 views
0

PyTorchを学習する目的で遊んでいます。行列を単一のベクトルで掛けるにはどうすればよいですか?PyTorchでベクトルを行列に乗算する方法

は、ここで私が試したものです:一方

>>> import torch 
>>> a = torch.rand(4,4) 
>>> a 

0.3162 0.4434 0.9318 0.8752 
0.0129 0.8609 0.6402 0.2396 
0.5720 0.7262 0.7443 0.0425 
0.4561 0.1725 0.4390 0.8770 
[torch.FloatTensor of size 4x4] 

>>> b = torch.rand(4) 
>>> b 

0.1813 
0.7090 
0.0329 
0.7591 
[torch.FloatTensor of size 4] 

>>> a.mm(b) 
Traceback (most recent call last): 
    File "<stdin>", line 1, in <module> 
RuntimeError: invalid argument 2: dimension 1 out of range of 1D tensor at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensor.c:24 
>>> a.mm(b.t()) 
Traceback (most recent call last): 
    File "<stdin>", line 1, in <module> 
RuntimeError: t() expects a 2D tensor, but self is 1D 
>>> b.mm(a) 
Traceback (most recent call last): 
    File "<stdin>", line 1, in <module> 
RuntimeError: matrices expected, got 1D, 2D tensors at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1288 
>>> b.t().mm(a) 
Traceback (most recent call last): 
    File "<stdin>", line 1, in <module> 
RuntimeError: t() expects a 2D tensor, but self is 1D 

を、私は

>>> b = torch.rand(4,2) 

を行うならば、私の最初の試み、a.mm(b)は、正常に動作します。ですから、問題は行列ではなくベクトルを乗算することです。しかし、どうすればいいですか?

答えて

1

あなたは将来のために、あなたもtorch.matmul()が有用見つけることが

torch.mv(a,b) 

注意を探しています。 torch.matmul()は、引数の次元数を推測し、それに応じて、ベクトル、行列ベクトルまたはベクトル行列乗算、行列乗算、または高次のテンソルのためのバッチ行列乗算の間のドット積を実行します。

+1

ありがとうございます!私は、将来の訪問者のために、補足的な情報とともに自己回答を書いた。 – Nathaniel

1

これは、mexmexの正解で役に立つ答えを補う自己回答です。

PyTorchでは、numpyと異なり、1Dテンソルは1xNまたはNx1テンソルと互換性がありません。私は

>>> b = torch.rand((4,1)) 

>>> b = torch.rand(4) 

を交換した場合、私はmmと列ベクトル、および行列の乗算を持つことになります期待通りに動作します。

@mexmexには、マトリックスベクトルの乗算のためのmv関数と、入力の次元に応じて適切な関数をディスパッチする関数があるので、これは必ずしも必要ではありません。