5分钟玩转PyTorch | 详解张量的分割与合并
AI因你而升温,记得加星标哦!
↑ 关注 + 星标 ,每天学Python新技能
后台回复【大礼包】送你Python自学大礼包
在使用PyTorch
时,对张量的分割与合并是不可避免的操作,本节就带大家深刻理解张量的分割与合并。
在开始之前,我们先对张量的维度进行深入理解:
t2 = torch.zeros((3, 4))
# tensor([[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]])
t2.shape
# torch.Size([3, 4])
重点理解
我们可以把shape的返回结果看成一个序列,代表着各张量维度的信息,第一个数字3
代表行,即向量数,第二个数字4
代表列,即每个向量中的标量数。
深入理解:t2
是由3个一维张量组成,并且每个一维张量都包含四个元素。
张量的分割
chunk(tensor, chunks, dim)
chunk
函数能够按照某个维度(dim
)对张量进行均匀切分(chunks
),并且返回结果是原张量的视图。
# 创建一个4×3的矩阵
t2 = torch.arange(12).reshape(4, 3)
t2
# tensor([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]])
张量可均分时
在第0个维度(shape
的第一个数字,代表向量维度)上将t2
进行4等分:
# 在矩阵中,第一个维度是行,理解为shape的第一个数
tc = torch.chunk(t2, chunks = 4, dim = 0)
tc
# (tensor([[0, 1, 2]]),
# tensor([[3, 4, 5]]),
# tensor([[6, 7, 8]]),
# tensor([[ 9, 10, 11]]))
根据结果可见:
返回结果是一个元组,不可变
tc[0] = torch.tensor([[1, 1, 1]])
# TypeError: 'tuple' object does not support item assignment
元组中的每个值依然是一个二维张量
tc[0]
# tensor([[0, 1, 2]])
返回的张量 tc
的一个视图,不是新成了一个对象
# 我们将原张量t2中的数值进行更改
t2[0] = torch.tensor([6, 6, 6])
# 再打印分块后tc的结果
tc
# (tensor([[6, 6, 6]]),
# tensor([[3, 4, 5]]),
# tensor([[6, 7, 8]]),
# tensor([[ 9, 10, 11]]))
张量不可均分时
若原张量不能均分时,chunk不会报错,会返回次一级均分结果。
# 创建一个4×3的矩阵
t2 = torch.arange(12).reshape(4, 3)
t2
# tensor([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]])
将4行分为3等份,不可分,就会返回2等分的结果:
tc = torch.chunk(t2, chunks = 3, dim = 0)
tc
# (tensor([[0, 2, 2],
# [3, 4, 5]]),
# tensor([[ 6, 7, 8],
# [ 9, 10, 11]]))
将4行分为5等份,不可分,就会返回4等分的结果:
tc = torch.chunk(t2, chunks = 5, dim = 0)
# (tensor([[0, 2, 2]]),
# tensor([[3, 4, 5]]),
# tensor([[6, 7, 8]]),
# tensor([[ 9, 10, 11]]))
split
函数
split
既能进行均分,也能进行自定义切分。需要注意的是split
的返回结果也是视图。
# 第二个参数只输入一个数值时表示均分
# 第三个参数表示切分的维度
torch.split(t2, 2, dim = 0)
# (tensor([[0, 1, 2],
# [3, 4, 5]]),
# tensor([[ 6, 7, 8],
# [ 9, 10, 11]]))
与chunk
函数不同的是,split
第二个参数可以输入一个序列,表示按照序列数值等分:
torch.split(t2, [1,3], dim = 0)
# (tensor([[0, 1, 2]]),
# tensor([[ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]]))
当第二个参数输入一个序列时,序列的各数值的和必须等于对应维度下形状分量的取值,即shape
对应的维度。
例如上述代码中,是按照第一个维度进行切分,而t2
总共有4行,因此序列的求和必须等于4,也就是1+3=4,而序列中每个分量的取值,则代表切块大小。
torch.split(t2, [1, 1, 2], 0)
# (tensor([[0, 1, 2]]),
# tensor([[3, 4, 5]]),
# tensor([[ 6, 7, 8],
# [ 9, 10, 11]]))
将张量第一个维度(行维度)分为1:1:2。
张量的合并
张量的合并操作类似与列表的追加元素,可以进行拼接、也可以堆叠。
这里一定要将dim
参数与shape
返回的结果相对应理解。
cat
拼接函数
a = torch.zeros(2, 3)
a
# tensor([[0., 0., 0.],
# [0., 0., 0.]])
b = torch.ones(2, 3)
b
# tensor([[1., 1., 1.],
# [1., 1., 1.]])
因为在张量a
与b
中,shape
的第一个位置是代表向量维度,所以当dim
取0时,就是将向量进行合并,向量中的标量数不变:
torch.cat([a, b], dim = 0)
# tensor([[0., 0., 0.],
# [0., 0., 0.],
# [1., 1., 1.],
# [1., 1., 1.]])
当dim
取1时,shape
的第二个位置是代表列,即标量数,就是在列上(标量维度)进行拼接,行数(向量数)不变:
torch.cat([a, b], dim = 1)
# tensor([[0., 0., 0., 1., 1., 1.],
## [0., 0., 0., 1., 1., 1.]])
将dim
与shape
结合理解,是不是清晰明了了?
stack
堆叠函数
和拼接不同,堆叠不是将元素拆分重装,而是将各参与堆叠的对象分装到一个更高维度的张量里。
a = torch.zeros(2, 3)
a
# tensor([[0., 0., 0.],
# [0., 0., 0.]])
b = torch.ones(2, 3)
b
# tensor([[1., 1., 1.],
# [1., 1., 1.]])
堆叠之后,生成一个三维张量:
torch.stack([a, b], dim = 0)
# tensor([[[0., 0., 0.],
# [0., 0., 0.]],
# [[1., 1., 1.],
# [1., 1., 1.]]])
torch.stack([a, b], dim = 0).shape
# torch.Size([2, 2, 3])
此例中,就是将两个维度为1×2×3的张量堆叠为一个2×2×3的张量。
与cat
的区别
拼接之后维度不变,堆叠之后维度升高。拼接是把一个个元素单独提取出来之后再放到二维张量里,而堆叠则是直接将两个二维向量封装到一个三维张量中。因此,堆叠的要求更高,参与堆叠的张量必须形状完全相同。
与python
对比
a = [1, 2]
b = [3, 4]
cat
拼接操作与list
的extend
相似,不会改变维度,只会在已有框架内增加元素:
a.extend(b)
a
# [1, 2, 3, 4]
stack
堆叠操作与list
的append
相似,会改变维度:
a = [1, 2]
b = [3, 4]
a.append(b)
a
# [1, 2, [3, 4]]
推荐阅读
推荐阅读