深入理解Pytroch中的hook机制
共 1014字,需浏览 3分钟
·
2021-01-12 05:56
【GiantPandaCV导语】Pytorch 中的 hook 机制可以很方便的让用户往计算图中注入控制代码,这样就可以通过自定义各种操作来修改计算图中的张量。
点击小程序观看视频(时长22分)
视频太长不看版:
Pytorch 中的 hook 机制可以很方便的让用户往计算图中注入控制代码(注入的代码也可以删除),这样用户就可以通过自定义各种操作来修改计算图中的张量。
Pytroch 中主要有两种hook,分别是注册在Tensor上的hook和注册在Module上的 hook。
注册在 Tensor 上的 hook,可以在反向回传过程中对梯度作修改,分为两种:
叶子节点上的hook
中间张量上的hook
在输出梯度传入 backward 函数计算输入梯度之前,调用注册的hook的函数对梯度做一些操作
会在 AccumulateGrad 之前对梯度做一些操作
注意:
最好不要在hook函数中对梯度做 inplace 修改,因为会直接修改该梯度张量,
如果该op有多个输入,比如 add op,那么在反向阶段,如果其中一个张量上注册的hook函数对梯度做了inplace修改,那么就会有可能影响到另一个输入张量的梯度。
注册在 Module 上的 hook,则可以在前后过程中对张量作修改,主要有三种:
在module的前向被调用之前调用的hook函数
在module的前向被调用之后调用的hook函数
后向过程会调用的hook
对Module的输入张量做一些操作
对Module的输入和输出张量做一些操作
可以打印module输入张量的梯度,但是目前还有bug,建议不要用。
github上相关的讨论:https://github.com/pytorch/pytorch/issues/598
为了感谢读者的长期支持,今天我们将送出三本由 机械工业出版社 提供的:《分布式人工智能:基于TensorFlow、RTOS与群体智能体系》 。点击下方抽奖助手参与抽奖。没抽到并且对本书有兴趣的也可以使用下方链接进行购买。
《分布式人工智能:基于TensorFlow、RTOS与群体智能体系》抽奖链接