torch.cat() 和 torch.stack()区别
1 torch.cat()
torch.cat(tensors,dim=0,out=None)→ Tensortorch.cat()对tensors沿指定维度拼接,但返回的Tensor的维数不会变
>>> import torch
>>> a = torch.rand((2, 3))
>>> b = torch.rand((2, 3))
>>> c = torch.cat((a, b))
>>> a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([4, 3]))可以看到c和a、b一样都是二维的。
2 torch.stack()
torch.stack(tensors,dim=0,out=None)→ Tensortorch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维
>>> import torch
>>> a = torch.rand((2, 3))
>>> b = torch.rand((2, 3))
>>> c = torch.stack((a, b))
>>> a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 2, 3]))评论
