5分钟玩转PyTorch | 张量广播计算的本质是什么?
AI因你而升温,记得加星标哦!
↑ 关注 + 星标 ,每天学Python新技能
后台回复【大礼包】送你Python自学大礼包
PyTorch
中的张量具有和NumPy
相同的广播特性,允许不同形状的张量之间进行计算。
广播的实质特性,其实是低维向量映射到高维之后,相同位置再进行相加。我们重点要学会的就是低维向量如何向高维向量进行映射。
相同形状的张量计算
虽然我们觉得不同形状之间的张量计算才是广播,但其实相同形状的张量计算本质上也是广播。
t1 = torch.arange(3)
t1
# tensor([0, 1, 2])
# 对应位置元素相加
t1 + t1
# tensor([0, 2, 4])
与Python对比
如果两个list
相加,结果是什么?
a = [0, 1, 2]
a + a
# [0, 1, 2, 0, 1, 2]
不同形状的张量计算
广播的特性是不同形状的张量进行计算时,一个或多个张量通过隐式转化成相同形状的两个张量,从而完成计算。
但并非任意两个不同形状的张量都能进行广播,因此我们要掌握广播隐式转化的核心依据。
2.1 标量和任意形状的张量
标量(零维张量)可以和任意形状的张量进行计算,计算过程就是标量和张量的每一个元素进行计算。
# 标量与一维向量
t1 = torch.arange(3)
# tensor([0, 1, 2])
t1 + 1 # 等效于t1 + torch.tensor(1)
# tensor([1, 2, 3])
# 标量与二维向量
t2 = torch.zeros((3, 4))
t2 + 1 # 等效于t2 + torch.tensor(1)
# tensor([[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]])
2.2 相同维度,不同形状张量之间的计算
我们以t2
为例来探讨相同维度、不同形状的张量之间的广播规则。
t2 = torch.zeros(3, 4)
t2
# tensor([[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]])
t21 = torch.ones(1, 4)
t21
# tensor([[1., 1., 1., 1.]])
它们都是二维矩阵,t21
的形状是1×4
,t2
的形状是3×4
,它们在第一个分量上取值不同,但该分量上t21
取值为1,因此可以进行广播计算:
t2 + t21
# tensor([[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]])
而t2和t21的实际计算过程如下:可理解为t21
的一行与t2
的三行分别进行了相加。而底层原理为t21
的形状由1×4
拓展成了t2
的3×4
,然后二者对应位置进行了相加。
t22 = torch.ones(3, 1)
t22
# tensor([[1.],
# [1.],
# [1.]])
t2 + t22
# tensor([[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]])
同理,t22+t2
与t21+t2
结果相同。如果矩阵的两个维度都不相同呢?
t23 = torch.arange(3).reshape(3, 1)
t23
# tensor([[0],
# [1],
# [2]])
t24 = torch.arange(3).reshape(1, 3)
# tensor([[0, 1, 2]])
t23 + t24
# tensor([[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4]])
此时,t23
的形状是3×1,而t24
的形状是1×3
,二者的形状在两个份量上均不同,但都有1存在,因此可以广播:
如果两个张量的维度对应数不同且都不为1,那么就无法广播。
t25 = torch.ones(2, 4)
# t2的shape为3×4
t2 + t25
# RuntimeError
高维张量的广播
高维张量的广播原理与低维张量的广播原理一致:
t3 = torch.zeros(2, 3, 4)
t3
# tensor([[[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]],
# [[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]]])
t31 = torch.ones(2, 3, 1)
t31
# tensor([[[1.],
# [1.],
# [1.]],
# [[1.],
# [1.],
# [1.]]])
t3+t31
# tensor([[[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]],
# [[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]]])
总结
维度相同时,如果对应分量不同,但有一个为1,就可以广播。
不同维度计算中的广播
对于不同维度的张量,我们首先可以将低维的张量升维,然后依据相同维度不同形状的张量广播规则进行广播。
低维向量的升维也非常简单,只需将更高维度方向的形状填充为1即可:
# 创建一个二维向量
t2 = torch.arange(4).reshape(2, 2)
t2
# tensor([[0, 1],
# [2, 3]])
# 创建一个三维向量
t3 = torch.zeros(3, 2, 2)
t3
t2 + t3
# tensor([[[0., 1.],
# [2., 3.]],
# [[0., 1.],
# [2., 3.]],
# [[0., 1.],
# [2., 3.]]])
t3
和t2
的相加,就相当于1×2×2
和3×2×2
的两个张量进行计算,广播规则与低维张量一致。
相信看完本节,你已经充分掌握了广播机制的运算规则:
维度相同时,如果对应分量不同,但有一个为1,就可以广播 维度不同时,只需将低维向量的更高维度方向的形状填充为1即可
推荐阅读
推荐阅读