深入理解Pytroch中的hook机制

GiantPandaCV

共 1014字,需浏览 3分钟

 ·

2021-01-12 05:56

【GiantPandaCV导语】Pytorch 中的 hook 机制可以很方便的让用户往计算图中注入控制代码,这样就可以通过自定义各种操作来修改计算图中的张量。


点击小程序观看视频(时长22分)

视频太长不看版:


Pytorch 中的 hook 机制可以很方便的让用户往计算图中注入控制代码(注入的代码也可以删除),这样用户就可以通过自定义各种操作来修改计算图中的张量。



Pytroch 中主要有两种hook,分别是注册在Tensor上的hook和注册在Module上的 hook。


注册在 Tensor 上的 hook,可以在反向回传过程中对梯度作修改,分为两种:

  • 叶子节点上的hook

  • 会在 AccumulateGrad 之前对梯度做一些操作 

  • 中间张量上的hook

    在输出梯度传入 backward 函数计算输入梯度之前,调用注册的hook的函数对梯度做一些操作


注意:

最好不要在hook函数中对梯度做 inplace 修改,因为会直接修改该梯度张量,

如果该op有多个输入,比如 add op,那么在反向阶段,如果其中一个张量上注册的hook函数对梯度做了inplace修改,那么就会有可能影响到另一个输入张量的梯度。



注册在 Module 上的 hook,则可以在前后过程中对张量作修改,主要有三种:

  • 在module的前向被调用之前调用的hook函数

  • 对Module的输入张量做一些操作

  • 在module的前向被调用之后调用的hook函数

  • 对Module的输入和输出张量做一些操作

  • 后向过程会调用的hook

  • 可以打印module输入张量的梯度,但是目前还有bug,建议不要用。

    github上相关的讨论:https://github.com/pytorch/pytorch/issues/598



为了感谢读者的长期支持,今天我们将送出三本由 机械工业出版社 提供的:《分布式人工智能:基于TensorFlow、RTOS与群体智能体系》 。点击下方抽奖助手参与抽奖。没抽到并且对本书有兴趣的也可以使用下方链接进行购买。

《分布式人工智能:基于TensorFlow、RTOS与群体智能体系》抽奖链接

浏览 28
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报