PyTorch扩展自定义PyThon/C++(CUDA)算子的若干方法总结

极市导读
关于PyTorch构建扩展的一些基础操作,官方往往已经出具了完整的教程。本文对这些官方教程的链接进行了整理,以供读者查阅。
第一种情况:使用PyThon扩展PyTorch
torch.nn)只需要继承torch.nn.Module并实现其forward方法即可。详细的过程请参考官方教程传送门:第二种情况:使用pybind11构建共享库形式的C++和CUDA扩展
forward函数可以接受任意多的参数并且应该返回一个 variable list或者variable;forward函数需要将[torch::autograd::AutogradContext](https://link.zhihu.com/?target=https%3A//pytorch.org/cppdocs/api/structtorch_1_1autograd_1_1_autograd_context.html%23structtorch_1_1autograd_1_1_autograd_context) 作为自己的第一个参数。Variables可以被使用ctx->save_for_backward保存,而其他数据类型可以使用ctx->saved_data以pairs的形式保存在一个map中。backward函数第一个参数同样需要为torch::autograd::AutogradContext,其余的参数是一个variable_list,包含的变量数量与forward输出的变量数量相等。它应该返回和forward输入一样多的变量。保存在forward中的Variable变量可以通过ctx->get_saved_variables而其他的数据类型可以通过ctx->saved_data获取。// PyG的C++扩展就选择的是直接继承PyTorch的C++端的torch::autograd类进行扩展// 下面是PyG的一个ScatterSum算子的扩展示例// 不用纠结这个算子的具体内容,对扩展的算子的结构有一个大致了解即可class ScatterSum : public torch::autograd::Function{ public:// AutogradContext *ctx指针可以操作static variable_list forward(AutogradContext *ctx, Variable src,Variable index, int64_t dim,torch::optionaloptional_out, torch::optionaldim_size) { dim = dim < 0 ? src.dim() + dim : dim;ctx->saved_data["dim"] = dim;ctx->saved_data["src_shape"] = src.sizes();index = broadcast(index, src, dim);auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");auto out = std::get<0>(result);ctx->save_for_backward({index});// 如果在扩展的C++代码中使用非Aten内建操作修改了tensor的值,需要对其进行脏标记if (optional_out.has_value())ctx->mark_dirty({optional_out.value()});return {out};}// grad_outs是out参数反传回来的梯度值static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {auto grad_out = grad_outs[0];auto saved = ctx->get_saved_variables();auto index = saved[0];auto dim = ctx->saved_data["dim"].toInt();auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());auto grad_in = torch::gather(grad_out, dim, index, false);// 不需要求导的参数需要空Variable占位return {grad_in, Variable(), Variable(), Variable(), Variable()};}};

Tensor类,在其上定义了数百种操作。这些操作大多数都具有CPU和GPU实现,Tensor该类将根据其类型向其动态调度。和Torch相比Aten更接近底层和核心逻辑。第三种情况:为TORCHSCRIPT添加C++和CUDA扩展
推荐阅读

评论
