实操教程|PyTorch AutoGrad C++层实现
共 9036字,需浏览 19分钟
·
2021-04-13 22:12
极市导读
本文为一篇实操教程,作者介绍了PyTorch AutoGrad C++层实现中各个概念的解释。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
autograd依赖的数据结构
at::Tensor
:shared ptr 指向 TensorImpl
TensorImpl
:对 at::Tensor
的实现
包含一个类型为 [AutogradMetaInterface](c10::AutogradMetaInterface)
的autograd_meta_,在tensor是需要求导的variable时,会被实例化为[AutogradMeta](c10::AutogradMetaInterface)
,里面包含了autograd需要的信息
Variable
: 就是Tensor,为了向前兼容保留的
using Variable = at::Tensor; 概念上有区别, Variable
是需要计算gradient的,Tensor
是不需要计算gradient的Variable
的AutogradMeta
是对[AutogradMetaInterface](c10::AutogradMetaInterface)
的实现,里面包含了一个Variable
,就是该variable的gradient带有version和view 会实例化 AutogradMeta
, autograd需要的关键信息都在这里
AutoGradMeta
: 记录 Variable
的autograd历史信息
包含一个叫grad_的 Variable
, 即AutoGradMeta
对应的var的梯度tensor包含类型为 Node
指针的grad_fn
(var在graph内部时)和grad_accumulator
(var时叶子时), 记录生成grad_的方法包含 output_nr
,标识var对应grad_fn
的输入编号构造函数包含一个类型为 Edge
的gradient_edge,gradient_edge.function
就是grad_fn
, 另外gradient_edge.input_nr
记录着对应grad_fn
的输入编号,会赋值给AutoGradMeta
的output_nr
autograd::Edge
: 指向autograd::Node
的一个输入
包含类型为 Node
指针,表示edge指向的Node包含 input_nr
, 表示edge指向的Node的输入编号
autograd::Node
: 对应AutoGrad Graph中的Op
是所有autograd op的抽象基类,子类重载apply方法
next_edges_
记录出边input_metadata_
记录输入的tensor的metadata实现的子类一般是可求导的函数和他们的梯度计算op
Node in AutoGrad Graph
Variable通过Edge关联Node的输入和输出 多个Edge指向同一个Var时,默认做累加 call operator
最重要的方法,实现计算 next_edge
缝合Node的操作 获取Node的出边,next_edge(index)/next_edges() add_next_edge(),创建
前向计算
PyTorch通过tracing只生成了后向AutoGrad Graph.
代码是生成的,需要编译才能看到对应的生成结果
gen_variable_type.py生成可导版本的op 生成的代码在 pytorch/torch/csrc/autograd/generated/
前向计算时,进行了tracing,记录了后向计算图构建需要的信息 这里以relu为例,代码在 pytorch/torch/csrc/autograd/generated/VariableType_0.cpp
Tensor relu(const Tensor & self) {
auto& self_ = unpack(self, "self", 0);
std::shared_ptr<ReluBackward0> grad_fn;
if (compute_requires_grad( self )) { // 如果输入var需要grad
// ReluBackward0的类型是Node
grad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);
// collect_next_edges(var)返回输入var对应的指向的
// grad_fn(前一个op的backward或者是一个accumulator的)的输入的Edge
// set_next_edges(),在grad_fn中记录这些Edge(这里完成了后向的构图)
grad_fn->set_next_edges(collect_next_edges( self ));
// 记录当前var的一个版本
grad_fn->self_ = SavedVariable(self, false);
}
c10::optional<Storage> self__storage_saved =
self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
c10::intrusive_ptr<TensorImpl> self__impl_saved;
if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
auto tmp = ([&]() {
at::AutoNonVariableTypeMode non_var_type_mode(true);
return at::relu(self_); // 前向计算
})();
auto result = std::move(tmp);
if (self__storage_saved.has_value())
AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
if (grad_fn) {
// grad_fn增加一个输入,记录输出var的metadata作为grad_fn的输入
// 输出var的AutoGradMeta实例化,输出var的AutoGradMeta指向起grad_fn的输入
set_history(flatten_tensor_args( result ), grad_fn);
}
return result;
}
可以看到和 grad_fn
相关的操作trace了一个op的计算,构建了后向计算图.
后向计算
autograd::backward()
:计算output var的梯度值,调用的 run_backward()
autograd::grad()
:计算有output var和到特定input的梯度值,调用的 run_backward()
autograd::run_backward()
对于要求梯度的output var,获取其指向的grad_fn作为roots,是后向图的起点 对于有input var的,获取其指向的grad_fn作为output_edges, 是后向图的终点 调用 autograd::Engine::get_default_engine().execute(...)
执行后向计算
autograd::Engine::execute(...)
创建
GraphTask
,记录了一些配置信息创建
GraphRoot
,是一个Node,把所有的roots作为其输出边,Node的apply()返回的是roots的grad【这里已经得到一个单起点的图】计算依赖
compute_dependencies(...)
从GraphRoot开始,广度遍历,记录所有碰到的grad_fn的指针,并统计grad_fn被遇到的次数,这些信息记录到GraphTask中 GraphTask
初始化:当有input var时,判断后向图中哪些节点是真正需要计算的GraphTask
执行选择CPU or GPU线程执行 以CPU为例,调用的 autograd::Engine::thread_main(...)
autograd::Engine::thread_main(...)
evaluate_function(...)
,输入输出的处理,调度call_function(...)
, 调用对应的Node计算执行后向过程中的生成的中间grad Tensor,如果不释放,可以用于计算高阶导数;(同构的后向图,之前的grad tensor是新的输出,grad_fn变成之前grad_fn的backward,这些新的输出还可以再backward) 具体的执行机制可以支撑单独开一个Topic分析,在这里讨论到后向图完成构建为止.
推荐阅读
2021-04-11
2021-04-08
2021-04-07
# CV技术社群邀请函 #
备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)
即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群
每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~