详尽 | PyTorch动态图解析
点击上方“小白学视觉”,选择加“星标”或“置顶”
重磅干货,第一时间送达
本文转自:深度学习这件小事
// Excerpt from PyTorch (torch/csrc/autograd): creates the Python module
// "torch._C._functions", registers the autogenerated backward Function
// bindings into it, and imports "torch._C". The "......" line marks code
// elided by the article, so this fragment is not compilable as shown.
void THPAutograd_initFunctions()
{
THPObjectPtr module(PyModule_New("torch._C._functions"));
......
generated::initialize_autogenerated_functions();
auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
}
// Excerpt: file-local table keyed by std::type_index holding Python wrapper
// objects — presumably one entry per C++ Function subtype exposed to Python
// (TODO confirm against torch/csrc/autograd sources; declaration is cut off here).
static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types
>>> gemfield = torch.empty([2,2],requires_grad=True)
>>> syszux = gemfield * gemfield
>>> syszux.grad_fn
<ThMulBackward object at 0x7f111621c350>
# Running example used throughout the article: a leaf tensor flows through
# add -> mul -> mul -> mean, producing the backward graph diagrammed below.
gemfield = torch.ones(2, 2, requires_grad=True)  # leaf Variable (requires grad)
syszux = gemfield + 2                            # grad_fn = AddBackward0
civilnet = syszux * syszux * 3                   # two MulBackward0 nodes
gemfieldout = civilnet.mean()                    # grad_fn = MeanBackward0 (scalar)
# Scalar output, so backward() needs no explicit gradient argument.
gemfieldout.backward()
#Variable实例
gemfield --> grad_fn_ (Function实例)= None
--> grad_accumulator_ (Function实例)= AccumulateGrad实例0x55ca7f304500
--> output_nr_ = 0
#Function实例, 0x55ca7f872e90
AddBackward0实例 --> sequence_nr_ (uint64_t) = 0
--> next_edges_ (edge_list) --> std::vector<Edge> = [(AccumulateGrad实例, 0),(0, 0)]
--> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2], cpu)]
--> alpha (Scalar) = 1
--> apply() --> 使用 AddBackward0 的apply
#Variable实例
syszux --> grad_fn_ (Function实例)= AddBackward0实例0x55ca7f872e90
--> output_nr_ = 0
#Function实例, 0x55ca7ebba2a0
MulBackward0 --> sequence_nr_ (uint64_t) = 1
--> next_edges_ (edge_list) = [(AddBackward0实例0x55ca7f872e90,0),(AddBackward0实例0x55ca7f872e90,0)]
--> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2], cpu)]
--> alpha (Scalar) = 1
--> apply() --> 使用 MulBackward0 的apply
# #Variable实例,syszux * syszux得到的tmp
tmp --> grad_fn_ (Function实例)= MulBackward0实例0x55ca7ebba2a0
--> output_nr_ = 0
#Function实例,0x55ca7fada2f0
MulBackward0 --> sequence_nr_ (uint64_t) = 2 (每个线程内自增)
--> next_edges_ (edge_list) = [(MulBackward0实例0x55ca7ebba2a0,0),(0,0)]
--> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2], cpu)]
--> self_ (SavedVariable) = tmp的浅拷贝
--> other_ (SavedVariable) = 3的浅拷贝
--> apply() --> 使用 MulBackward0 的apply
#Variable实例
civilnet --> grad_fn_ (Function实例)= MulBackward0实例0x55ca7fada2f0
#Function实例,0x55ca7eb358b0
MeanBackward0 --> sequence_nr_ (uint64_t) = 3 (每个线程内自增)
--> next_edges_ (edge_list) = [(MulBackward0实例0x55ca7fada2f0,0)]
--> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [], cpu)]
--> self_sizes (std::vector<int64_t>) = (2, 2)
--> self_numel = 4
--> apply() --> 使用 MeanBackward0 的apply
#Variable实例
gemfieldout --> grad_fn_ (Function实例)= MeanBackward0实例0x55ca7eb358b0
--> output_nr_ = 0
// Type aliases used throughout the autograd engine.
using edge_list = std::vector<Edge>;
using variable_list = std::vector<Variable>;
// Excerpt: base class of every node in the backward graph ("..." marks
// members elided by the article).
struct TORCH_API Function {
...
// Given gradients w.r.t. this node's outputs, produce gradients to pass
// along next_edges_; each generated subclass overrides this.
virtual variable_list apply(variable_list&& inputs) = 0;
...
// Monotonically increasing id assigned per thread at construction
// (see the "每个线程内自增" annotations in the diagram above).
const uint64_t sequence_nr_;
// Edges to the Functions executed after this one during backward.
edge_list next_edges_;
PyObject* pyobj_ = nullptr; // weak reference
std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;
// Hooks run around apply() — per their names, before and after.
std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
// Per-input (type, shape, device) records captured at forward time;
// see struct InputMetadata below.
at::SmallVector<InputMetadata, 2> input_metadata_;
};
// Call operator shown detached from its enclosing Function class in this
// excerpt: makes a Function node invocable, forwarding to virtual apply().
variable_list operator()(variable_list&& inputs) {
return apply(std::move(inputs));
}
// Excerpt: the (type, shape, device) triple recorded for each forward
// input, matching the input_metadata_ entries in the diagram above.
struct InputMetadata {
...
const at::Type* type_ = nullptr;   // non-owning; e.g. CPUFloatType
at::DimVector shape_;              // e.g. [2, 2]
at::Device device_ = at::kCPU;     // defaults to CPU
};
// Excerpt: an edge of the backward graph — the target Function plus
// which of that Function's inputs this edge feeds (input_nr), matching
// the (Function实例, n) pairs shown in next_edges_ above.
struct Edge {
...
std::shared_ptr<Function> function;
uint32_t input_nr;
};
CopySlices : public Function
DelayedError : public Function
Error : public Function
Gather : public Function
GraphRoot : public Function
Scatter : public Function
AccumulateGrad : public Function
AliasBackward : public Function
AsStridedBackward : public Function
CopyBackwards : public Function
DiagonalBackward : public Function
ExpandBackward : public Function
IndicesBackward0 : public Function
IndicesBackward1 : public Function
PermuteBackward : public Function
SelectBackward : public Function
SliceBackward : public Function
SqueezeBackward0 : public Function
SqueezeBackward1 : public Function
TBackward : public Function
TransposeBackward0 : public Function
UnbindBackward : public Function
UnfoldBackward : public Function
UnsqueezeBackward0 : public Function
ValuesBackward0 : public Function
ValuesBackward1 : public Function
ViewBackward : public Function
PyFunction : public Function
// Excerpt: terminal node of the backward graph. Per the diagram above,
// a leaf Variable (e.g. gemfield) owns one as its grad_accumulator_;
// apply() accumulates incoming gradients into `variable`.
struct AccumulateGrad : public Function {
explicit AccumulateGrad(Variable variable_);
variable_list apply(variable_list&& grads) override;
Variable variable;
};
// Excerpt: synthetic root node used to kick off a backward pass. The
// `functions` edges are handed to the Function base (becoming its
// next_edges_), and apply() ignores its arguments, simply returning the
// stored initial gradients `outputs`.
struct GraphRoot : public Function {
GraphRoot(edge_list functions, variable_list inputs)
: Function(std::move(functions)),
outputs(std::move(inputs)) {}
variable_list apply(variable_list&& inputs) override {
return outputs;
}
variable_list outputs;
};
// Excerpt: intermediate base class for the generated *Backward nodes
// listed below. Inherits all Function constructors and pins
// is_traceable() to true (`final` forbids further overriding).
struct TraceableFunction : public Function {
using Function::Function;
bool is_traceable() final {
return true;
}
};
AbsBackward : public TraceableFunction
AcosBackward : public TraceableFunction
AdaptiveAvgPool2DBackwardBackward : public TraceableFunction
AdaptiveAvgPool2DBackward : public TraceableFunction
AdaptiveAvgPool3DBackwardBackward : public TraceableFunction
AdaptiveAvgPool3DBackward : public TraceableFunction
AdaptiveMaxPool2DBackwardBackward : public TraceableFunction
AdaptiveMaxPool2DBackward : public TraceableFunction
AdaptiveMaxPool3DBackwardBackward : public TraceableFunction
AdaptiveMaxPool3DBackward : public TraceableFunction
AddBackward0 : public TraceableFunction
AddBackward1 : public TraceableFunction
AddbmmBackward : public TraceableFunction
AddcdivBackward : public TraceableFunction
AddcmulBackward : public TraceableFunction
AddmmBackward : public TraceableFunction
AddmvBackward : public TraceableFunction
AddrBackward : public TraceableFunction
......
SoftmaxBackwardDataBackward : public TraceableFunction
SoftmaxBackward : public TraceableFunction
......
UpsampleBicubic2DBackwardBackward : public TraceableFunction
UpsampleBicubic2DBackward : public TraceableFunction
UpsampleBilinear2DBackwardBackward : public TraceableFunction
UpsampleBilinear2DBackward : public TraceableFunction
UpsampleLinear1DBackwardBackward : public TraceableFunction
UpsampleLinear1DBackward : public TraceableFunction
UpsampleNearest1DBackwardBackward : public TraceableFunction
UpsampleNearest1DBackward : public TraceableFunction
UpsampleNearest2DBackwardBackward : public TraceableFunction
UpsampleNearest2DBackward : public TraceableFunction
UpsampleNearest3DBackwardBackward : public TraceableFunction
UpsampleNearest3DBackward : public TraceableFunction
UpsampleTrilinear3DBackwardBackward : public TraceableFunction
UpsampleTrilinear3DBackward : public TraceableFunction
......
// Excerpt: generated backward node for addition. `alpha` is the scaling
// factor saved at forward time (the diagram above shows alpha = 1 for
// `gemfield + 2`); apply() is generated from derivative definitions.
struct AddBackward0 : public TraceableFunction {
using TraceableFunction::TraceableFunction;
variable_list apply(variable_list&& grads) override;
Scalar alpha;
};
# Same example as earlier, repeated before tracing the backward() call path:
# leaf -> AddBackward0 -> MulBackward0 -> MulBackward0 -> MeanBackward0.
gemfield = torch.ones(2, 2, requires_grad=True)
syszux = gemfield + 2
civilnet = syszux * syszux * 3
gemfieldout = civilnet.mean()
# Scalar output, so no explicit gradient argument is required.
gemfieldout.backward()
// Excerpt: the autograd execution engine ("..." in execute()'s signature
// marks parameters elided by the article; not compilable as shown).
struct Engine {
// Queue of (node, accumulated-gradient-inputs) tasks ready to run.
using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;
// Pending-dependency count per graph node.
using dependencies_type = std::unordered_map<Function*, int>;
// Entry point: runs backward from `roots` with initial gradients `inputs`.
virtual variable_list execute(const edge_list& roots,const variable_list& inputs,...const edge_list& outputs = {});
void queue_callback(std::function<void()> callback);
protected:
void compute_dependencies(Function* root, GraphTask& task);
void evaluate_function(FunctionTask& task);
void start_threads();
virtual void thread_init(int device);
virtual void thread_main(GraphTask *graph_task);
// Multiple ready queues — presumably one per worker/device, since
// thread_init takes a device index (TODO confirm in engine.cpp).
std::vector<std::shared_ptr<ReadyQueue>> ready_queues;
};
struct PythonEngine : public Engine
#torch/tensor.py,self is gemfieldout
def backward(self, gradient=None, retain_graph=None, create_graph=False)
|
V
#torch.autograd.backward(self, gradient, retain_graph, create_graph)
#torch/autograd/__init__.py
def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)
|
V
Variable._execution_engine.run_backward(tensors, grad_tensors, retain_graph, create_graph,allow_unreachable=True)
#转化为Variable._execution_engine.run_backward((gemfieldout,), (tensor(1.),), False, False,True)
|
V
#torch/csrc/autograd/python_engine.cpp
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
|
V
#torch/csrc/autograd/python_engine.cpp
variable_list PythonEngine::execute(const edge_list& roots, const variable_list& inputs, bool keep_graph, bool create_graph, const edge_list& outputs)
|
V
#torch/csrc/autograd/engine.cpp
总结
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
评论