5分钟玩转PyTorch | 详解张量的分割与合并

Python绿色通道

共 3781字,需浏览 8分钟

 ·

2021-11-28 13:08

AI因你而升温,记得加星标哦!

↑ 关注 + 星标 ,每天学Python新技能

后台回复【大礼包】送你Python自学大礼包


在使用PyTorch时,对张量的分割与合并是不可避免的操作,本节就带大家深刻理解张量的分割与合并。

在开始之前,我们先对张量的维度进行深入理解:

t2 = torch.zeros((34))
# tensor([[0., 0., 0., 0.],
#         [0., 0., 0., 0.],
#         [0., 0., 0., 0.]])
        
t2.shape
# torch.Size([3, 4])

重点理解

我们可以把shape的返回结果看成一个序列,代表着各张量维度的信息,第一个数字3代表行,即向量数,第二个数字4代表列,即每个向量中的标量数。

深入理解:t2是由3个一维张量组成,并且每个一维张量都包含四个元素。

张量的分割

chunk(tensor, chunks, dim)

chunk函数能够按照某个维度(dim)对张量进行均匀切分(chunks),并且返回结果是原张量的视图。

# 创建一个4×3的矩阵
t2 = torch.arange(12).reshape(43)
t2
# tensor([[ 0,  1,  2],
#        [ 3,  4,  5],
#        [ 6,  7,  8],
#        [ 9, 10, 11]])

张量可均分时

在第0个维度(shape的第一个数字,代表向量维度)上将t2进行4等分:

# 在矩阵中,第一个维度是行,理解为shape的第一个数
tc = torch.chunk(t2, chunks = 4, dim = 0)
tc
# (tensor([[0, 1, 2]]),
#  tensor([[3, 4, 5]]),
#  tensor([[6, 7, 8]]),
#  tensor([[ 9, 10, 11]]))

根据结果可见:

  1. 返回结果是一个元组,不可变
tc[0] = torch.tensor([[111]])
# TypeError: 'tuple' object does not support item assignment
  1. 元组中的每个值依然是一个二维张量
tc[0]
# tensor([[0, 1, 2]])
  1. 返回的张量tc的一个视图,不是新成了一个对象
# 我们将原张量t2中的数值进行更改
t2[0] = torch.tensor([666])
# 再打印分块后tc的结果
tc
# (tensor([[6, 6, 6]]),
#  tensor([[3, 4, 5]]),
#  tensor([[6, 7, 8]]),
#  tensor([[ 9, 10, 11]]))

若还不懂视图概念,点击这里进行学习


张量不可均分时

若原张量不能均分时,chunk不会报错,会返回次一级均分结果。

# 创建一个4×3的矩阵
t2 = torch.arange(12).reshape(43)
t2
# tensor([[ 0,  1,  2],
#        [ 3,  4,  5],
#        [ 6,  7,  8],
#        [ 9, 10, 11]])

将4行分为3等份,不可分,就会返回2等分的结果:

tc = torch.chunk(t2, chunks = 3, dim = 0)
tc
# (tensor([[0, 2, 2],
#          [3, 4, 5]]), 
#  tensor([[ 6,  7,  8],
#          [ 9, 10, 11]]))

将4行分为5等份,不可分,就会返回4等分的结果:

tc = torch.chunk(t2, chunks = 5, dim = 0)
# (tensor([[0, 2, 2]]),
#  tensor([[3, 4, 5]]),
#  tensor([[6, 7, 8]]),
#  tensor([[ 9, 10, 11]]))

split函数

split既能进行均分,也能进行自定义切分。需要注意的是split的返回结果也是视图。

# 第二个参数只输入一个数值时表示均分
# 第三个参数表示切分的维度
torch.split(t2, 2, dim = 0)
# (tensor([[0, 1, 2],
#          [3, 4, 5]]), 
#  tensor([[ 6,  7,  8],
#          [ 9, 10, 11]]))

chunk函数不同的是,split第二个参数可以输入一个序列,表示按照序列数值等分:

torch.split(t2, [1,3], dim = 0)
# (tensor([[0, 1, 2]]), 
#  tensor([[ 3,  4,  5],
#          [ 6,  7,  8],
#          [ 9, 10, 11]]))

当第二个参数输入一个序列时,序列的各数值的和必须等于对应维度下形状分量的取值,即shape对应的维度。

例如上述代码中,是按照第一个维度进行切分,而t2总共有4行,因此序列的求和必须等于4,也就是1+3=4,而序列中每个分量的取值,则代表切块大小。

torch.split(t2, [112], 0)
# (tensor([[0, 1, 2]]), 
#  tensor([[3, 4, 5]]), 
#  tensor([[ 6,  7,  8],
#         [ 9, 10, 11]]))

将张量第一个维度(行维度)分为1:1:2。

张量的合并

张量的合并操作类似与列表的追加元素,可以进行拼接、也可以堆叠。

这里一定要将dim参数与shape返回的结果相对应理解。

cat拼接函数

a = torch.zeros(23)
a
# tensor([[0., 0., 0.],
#         [0., 0., 0.]])

b = torch.ones(23)
b
# tensor([[1., 1., 1.],
#         [1., 1., 1.]])

因为在张量ab中,shape的第一个位置是代表向量维度,所以当dim取0时,就是将向量进行合并,向量中的标量数不变:

torch.cat([a, b], dim = 0)
# tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [1., 1., 1.],
#         [1., 1., 1.]])

dim取1时,shape的第二个位置是代表列,即标量数,就是在列上(标量维度)进行拼接,行数(向量数)不变:

torch.cat([a, b], dim = 1)
# tensor([[0., 0., 0., 1., 1., 1.],
##        [0., 0., 0., 1., 1., 1.]])

dimshape结合理解,是不是清晰明了了?

维度有疑惑的同学,点击这里进行学习

stack堆叠函数

和拼接不同,堆叠不是将元素拆分重装,而是将各参与堆叠的对象分装到一个更高维度的张量里。

a = torch.zeros(23)
a
# tensor([[0., 0., 0.],
#         [0., 0., 0.]])

b = torch.ones(23)
b
# tensor([[1., 1., 1.],
#         [1., 1., 1.]])

堆叠之后,生成一个三维张量:

torch.stack([a, b], dim = 0)
# tensor([[[0., 0., 0.],
#          [0., 0., 0.]],
#         [[1., 1., 1.],
#          [1., 1., 1.]]])

torch.stack([a, b], dim = 0).shape
# torch.Size([2, 2, 3])

此例中,就是将两个维度为1×2×3的张量堆叠为一个2×2×3的张量。

cat的区别

拼接之后维度不变,堆叠之后维度升高。拼接是把一个个元素单独提取出来之后再放到二维张量里,而堆叠则是直接将两个二维向量封装到一个三维张量中。因此,堆叠的要求更高,参与堆叠的张量必须形状完全相同。

python对比

a = [1, 2]
b = [3, 4]

cat拼接操作与listextend相似,不会改变维度,只会在已有框架内增加元素:

a.extend(b)
a
# [1, 2, 3, 4]

stack堆叠操作与listappend相似,会改变维度:

a = [1, 2]
b = [3, 4]
a.append(b)
a
# [1, 2, [3, 4]]



推荐阅读

  1. 终于,Python 也可以写前端了

  2. 算力羊毛!2000核时计算资源免费领取!

  3. 您已关注公众号满1年, 诚邀您免费加入网易数据分析培训营!



浏览 194
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报