PyTorch扩展自定义PyThon/C++(CUDA)算子的若干方法总结
极市平台
共 5828字,需浏览 12分钟
·
2020-09-07 23:44
极市导读
关于PyTorch构建扩展的一些基础操作,官方往往已经出具了完整的教程。本文对这些官方教程的链接进行了整理,以供读者查阅。
第一种情况:使用PyThon扩展PyTorch
torch.nn
)只需要继承torch.nn.Module并实现其forward方法即可。详细的过程请参考官方教程传送门:第二种情况:使用pybind11构建共享库形式的C++和CUDA扩展
forward
函数可以接受任意多的参数并且应该返回一个 variable list或者variable;forward函数需要将[torch::autograd::AutogradContext](https://link.zhihu.com/?target=https%3A//pytorch.org/cppdocs/api/structtorch_1_1autograd_1_1_autograd_context.html%23structtorch_1_1autograd_1_1_autograd_context)
作为自己的第一个参数。Variables可以被使用ctx->save_for_backward保存,而其他数据类型可以使用ctx->saved_data以
pairs的形式保存在一个map中。backward
函数第一个参数同样需要为torch::autograd::AutogradContext,其余的参数是一个variable_list,包含的变量数量与forward
输出的变量数量相等。它应该返回和forward
输入一样多的变量。保存在forward中的Variable变量可以通过ctx->get_saved_variables而其他的数据类型可以通过ctx->saved_data获取。// PyG的C++扩展就选择的是直接继承PyTorch的C++端的torch::autograd类进行扩展
// 下面是PyG的一个ScatterSum算子的扩展示例
// 不用纠结这个算子的具体内容,对扩展的算子的结构有一个大致了解即可
class ScatterSum : public torch::autograd::Function
{ public:
// AutogradContext *ctx指针可以操作
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional
optional_out, torch::optional
dim_size) { dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({index});
// 如果在扩展的C++代码中使用非Aten内建操作修改了tensor的值,需要对其进行脏标记
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
// grad_outs是out参数反传回来的梯度值
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out, dim, index, false);
// 不需要求导的参数需要空Variable占位
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
Tensor
类,在其上定义了数百种操作。这些操作大多数都具有CPU和GPU实现,Tensor
该类将根据其类型向其动态调度。和Torch相比Aten更接近底层和核心逻辑。第三种情况:为TORCHSCRIPT添加C++和CUDA扩展
推荐阅读
评论