torch.cat() 和 torch.stack()区别

pytorch玩转深度学习

共 806字,需浏览 2分钟

 ·

2021-03-13 15:22

1 torch.cat()

torch.cat(tensors,dim=0,out=None)→ Tensor

torch.cat()对tensors沿指定维度拼接,但返回的Tensor的维数不会变

>>> import torch
>>> a = torch.rand((2, 3))
>>> b = torch.rand((2, 3))
>>> c = torch.cat((a, b))
>>> a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([4, 3]))

可以看到c和a、b一样都是二维的。

2 torch.stack()

torch.stack(tensors,dim=0,out=None)→ Tensor

torch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维

>>> import torch
>>> a = torch.rand((2, 3))
>>> b = torch.rand((2, 3))
>>> c = torch.stack((a, b))
>>> a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 2, 3]))


浏览 19
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报