2017-06-30 3 views
1

PyTorchのtorch.transpose関数は、2D入力のみを転置します。ドキュメントはhereです。PyTorchでのN次元の変換なし

一方、Tensorflowのtf.transpose関数では、任意の次元のテンソルのNを転置することができます。

誰かがPyTorchにN次元の転置機能がない/できない理由を説明できますか?これは、PyTorch対TensorflowのDefine-Then-runパラダイムにおける計算グラフ構築の動的な性質によるものですか?

答えて

3

これは、単純にピッコロで異なると呼ばれています。 torch.Tensor.permuteは、TensorFlowのtf.transposeのように、pytorchの次元を入れ替えることができます。あなたはNCHWにNHWCから4D画像テンソルを変換したい方法の例として、

(テストしていないので、含まれている可能性があるバグ):

>>> img_nhwc = torch.randn(10, 480, 640, 3) 
>>> img_nhwc.size() 
torch.Size([10, 480, 640, 3]) 
>>> img_nchw = img_nhwc.permute(0, 3, 1, 2) 
>>> img_nchw.size() 
torch.Size([10, 3, 480, 640]) 
関連する問題