torch.cat() 和 torch.stack()区别
1 torch.cat()
torch.cat
(tensors,dim=0,out=None)→ Tensor
torch.cat()对tensors沿指定维度拼接,但返回的Tensor的维数不会变
>>> import torch
>>> a = torch.rand((2, 3))
>>> b = torch.rand((2, 3))
>>> c = torch.cat((a, b))
>>> a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([4, 3]))
可以看到c和a、b一样都是二维的。
2 torch.stack()
torch.stack
(tensors,dim=0,out=None)→ Tensor
torch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维
>>> import torch
>>> a = torch.rand((2, 3))
>>> b = torch.rand((2, 3))
>>> c = torch.stack((a, b))
>>> a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 2, 3]))
评论