开辟新视野之高层训练框架 PyTorch-Ignite
1 前言
https://github.com/open-mmlab/mmcv
本文介绍另一个高层封装训练框架 Ignite, 其官方介绍是:PyTorch-Ignite 是一个可帮助在 PyTorch 中灵活透明地训练和评估神经网络的高级库。可以发现 Ignite 对标的是 MMCV 和 Pytorch-Lighting,但是相比 Pytorch-Lighting 更加简单。本文对 Ignite 进行整体性分析希望大家能够开辟新视野,换个姿势了解其他训练框架的封装方式,而不要拘泥于某一种固定的开发模式,阻碍自身成长。
由于 Ignite 内容比较多,本文分析会有侧重于整体分析,无法顾全所有内容。如果想了解的非常透彻,建议和我交流或者留言。
Github 地址:
https://github.com/pytorch/ignite
官方地址:
https://pytorch-ignite.ai/
Docs 地址:
Ignite Your Networks! — PyTorch-Ignite v0.4.7 Documentation
Guides 地址:
https://pytorch-ignite.ai/how-to-guides/
本文代码比较多,手机端阅读体验不佳,建议采用电脑端查看或者后续移步知乎社区,知乎 ID: 深度眸
2 Ignite 特性分析
当前分析版本是 V0.4.7。
2.1 核心特性
其核心特性是:
比纯 PyTorch 更少的代码,同时确保最大程度的控制和简单性
库方法和没有程序控制反转 - 在需要的地方和时间使用 Ignite
用于指标、实验管理器和其他组件的可扩展 API
这里解释下控制反转 (Inversion of control) 。IoC 是一种设计思想,用于解决对象与对象实例化耦合问题,在 Spring 等大型应用程序框架中有着非常多的应用。控制反转是指把传统模式中需要自己通过实例化构造函数,或者通过工厂模式实例化的任务交给容器来避免强耦合。这种做法其实非常常见,和我们常说的依赖抽象而不是依赖实体非常类似。举个最简单的例子:
# 沙丁鱼
class SardineFish:
pass
# 石斑鱼
class GrouperFish:
pass
# 吃饭
class Dining:
def __init__(self):
self.fish=SardineFish()
def eat(self):
return self.fish
dining=Dining()
dining.eat()
今天想吃沙丁鱼,因此直接在 Dining 类中实例化了 SardineFish 类,这是一种非常强的耦合关系,一旦我明天想吃石斑鱼就麻烦了。控制反转可以解决上述问题,将鱼类具体实例化交给容器实例化,Dining 内部只是被动的获取类对象即可,不负责创建实例,而是交给容器类。自己需要主动实例化对象变为被动获取,依赖对象控制权被反转,不需要再考虑如何实例化其他依赖的类。
class SardineFish:
def name(self):
return 'SardineFish'
class GrouperFish:
def name(self):
return 'GrouperFish'
class Container:
def __init__(self):
self.fish_dict={}
def bind(self,fish):
self.fish_dict[fish.name]=fish
def get(self,name):
return self.fish_dict[name]
class Dining:
def __init__(self,container):
self.container= container
def eat(self,name):
return self.container.get(name)
container= Container()
container.bind(SardineFish())
container.bind(GrouperFish())
dining= Dining(container)
dining.eat('GrouperFish')
Ignite 没有程序控制反转,是因为他都是基于方法或者函数进行扩展开发,不存在对象和对象自己的实例化耦合问题。
作者指出其功能上主要特点是:
极其简单的训练引擎和事件系统
开箱即用的指标,可轻松评估模型
用于组成训练 pipeline、保存以及记录参数和指标的内置处理程序
事件是什么?可以简单理解为一个动作 action,例如保存权重就是一个事件, 通过丰富的事件系统可以实现灵活的无侵入的扩展功能。
需要指出的是,Ignite 文档非常多也非常全面,包括 Getting Started、Documentation、Additional Materials、Examples、Tutorials 和 Projects using Ignite 等等,如果你非常有兴趣,可以阅读相关 project,有非常多的实例。
2.2 主要特性
2.2.1 简化训练和验证循环
不再需要为 epoch 和 iterations 手动设置 for/while 循环,用户初始化的实例化引擎会自动处理和运行。这算是作为一个高层训练框架包装器的最基本要求了吧。
2.2.2 强大的 Event 事件和 Handler 处理器
Ignite 处理程序很酷的地方在于它们提供了无与伦比的灵活性(例如与回调相比)。处理程序可以是任何函数:例如 lambda、简单函数、类方法等。因此我们不需要从接口继承并覆盖其抽象方法,这可能会不必要地增加您的代码及其复杂性。
作为一个框架,最需要考虑的是扩展性,MMCV 和 Pytorch-Lighting 都提出了自己的扩展方式,ignite 扩展方式非常简洁,不需要继承并覆写某些抽象方法,而是可以传入任意函数。这也是 ignite 不同于其他两个框架的特点,后面会重点介绍。
(1) 随时执行任意数量的你想要的扩展功能
# 1 注入自定义的事件处理器
trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))
# attach handler with args, kwargs
mydata = [1, 2, 3, 4]
logger = ...
def on_training_ended(data):
print(f"Training is ended. mydata={data}")
# User can use variables from another scope
logger.info("Training is ended")
# 2 注入自定义的事件处理器
trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)# call any number of functions on a single event
# 3 注入自定义的事件处理器
trainer.add_event_handler(Events.COMPLETED, lambda engine: print(engine.state.times))
# 4 注入自定义的事件处理器
def log_something(engine):
print(engine.state.output)
上面例子写了几种 Ignite 支持的扩展开发方式,如果你已经熟悉了 MMCV 的 Hook 开发模式,那么上面例子含义非常容易理解。
(2) 内置事件过滤器
# run the validation every 5 epochs
def run_validation():
# run validation
# change some training variable once on 20th epoch
def change_training_variable():
# ...
# Trigger handler with customly defined frequency
def log_gradients():
# ...
事件过滤器是指基于过滤规则运行指定事件,例如每隔 20 个 epoch 验证一次,跳过前 n 次迭代等等。
(3) 一个事件并集操作共享多个 action
def run_validation():
# ...
这是一种非常好的特性。
(4) 支持标准事件外的自定义事件
from ignite.engine import EventEnum
class BackpropEvents(EventEnum):
BACKWARD_STARTED = 'backward_started'
BACKWARD_COMPLETED = 'backward_completed'
OPTIM_STEP_COMPLETED = 'optim_step_completed'
def update(engine, batch):
# ...
loss = criterion(y_pred, y)
engine.fire_event(BackpropEvents.BACKWARD_STARTED)
loss.backward()
engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)
optimizer.step()
engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED)
# ...
trainer = Engine(update)
trainer.register_events(*BackpropEvents)
def function_before_backprop(engine):
# ...
2.2.3 开箱即用的评估指标
假设 Ignite 内置的事件无法满足我的需求则可以自定义事件,如上所示用户自定义了反向传播相关的事件,然后可以通过 register_events 注册从而生效。
目前已经支持了非常多评估指标,例如 Precision, Recall, Accuracy, Confusion Matrix, IoU 等等,当然用户也可以组合或者自定义新的评估指标
precision = Precision(average=False)
recall = Recall(average=False)
F1_per_class = (precision * recall * 2 / (precision + recall))
F1_mean = F1_per_class.mean() # torch mean method
F1_mean.attach(engine, "F1")
3 从一个典型而简单的例子说起
如果直接讲 Ignite 整体设计原则,可能很多人依然觉得难以理解,故先以一个非常简单的分类任务例子说明常用用法,从这个用法中可以说明 Ignite 的大部分特性和设计巧妙之处。
完整代码来自:https://pytorch-ignite.ai/tutorials/beginner/01-getting-started/#complete-code
3.1 初始化必备对象实例
model = Net().to(device)
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader(
MNIST(download=True, root=".", transform=data_transform, train=True), batch_size=128, shuffle=True
)
val_loader = DataLoader(
MNIST(download=True, root=".", transform=data_transform, train=False), batch_size=256, shuffle=False
)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()
初始化模型、train loader、val loader、optimizer 和 loss 计算类等。
3.2 初始化训练引擎
trainer = create_supervised_trainer(model, optimizer, criterion, device)
create_supervised_trainer 只是只是一个简单的帮助函数,内部实际上是初始化了一个训练 Engine,核心代码为:
def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
model.train()
y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
loss = loss / gradient_accumulation_steps
loss.backward()
if engine.state.iteration % gradient_accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
return output_transform(x, y, y_pred, loss)
trainer = Engine(_update)
训练引擎在每一 step 时候都会调用 _update 函数进行推理、loss 计算和参数优化
3.3 定义评估流程
1 定义评估指标
val_metrics = {
"accuracy": Accuracy(),
"loss": Loss(criterion)
}
# 2. 实例化两个新的 engine
# 一个负责训练过程中的评估,一个负责验证过程中的评估
train_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
log_interval = 100
# 3 自定义 handler,并通过 on 装饰器注入到 engine
def log_training_loss(engine):
print(f"Epoch[{engine.state.epoch}], Iter[{engine.state.iteration}] Loss: {engine.state.output:.2f}")
def log_training_results(trainer):
train_evaluator.run(train_loader)
metrics = train_evaluator.state.metrics
print(f"Training Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}")
def log_validation_results(trainer):
val_evaluator.run(val_loader)
metrics = val_evaluator.state.metrics
print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}")
def score_function(engine):
return engine.state.metrics["accuracy"]
# 4 模型保存 handler
model_checkpoint = ModelCheckpoint(
"checkpoint",
n_saved=2,
filename_prefix="best",
score_function=score_function,
score_name="accuracy",
global_step_transform=global_step_from_engine(trainer),
)
# 将模型保存 handler 注入到验证评估引擎中
val_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})
不同于我们常规的理解,其 Engine 不是一个,而是 3 个,每个 Engine 负责一个流程,分别是训练 Engine、训练中评估的 Engine 和验证中评估 Engine。训练 Engine 是用于控制训练的循环过程,train_evaluator 是对 train dataloader 进行评估,评估的指标就是前面定义的 val_metrics,val_evaluator 是对 val dataloader 进行评估。三个 engine 相互独立,但是实际上是通过 train engine 组织起来的。
3.4 插入 logger handler
tb_logger = TensorboardLogger(log_dir="tb-logger")
tb_logger.attach_output_handler(
trainer,
event_name=Events.ITERATION_COMPLETED(every=100),
tag="training",
output_transform=lambda loss: {"batch_loss": loss},
)
for tag, evaluator in [("training", train_evaluator), ("validation", val_evaluator)]:
tb_logger.attach_output_handler(
evaluator,
event_name=Events.EPOCH_COMPLETED,
tag=tag,
metric_names="all",
global_step_transform=global_step_from_engine(trainer),
)
tensorboard 非常重要,可以将 tb_logger 插入到任意一个或者多个 Engine 中,例如上面代码是插入到了每个 Engine,插入过程是通过 attach_output_handler 实现的,而 event_name 表示触发时机。
3.5 开启训练
trainer.run(train_loader, max_epochs=5)
tb_logger.close()
开启训练后,在特定时候会触发注入的事件 Handler,例如在每个 epoch 完成后,会进行训练集的评估和验证集的评估,并将所有评估指标保存到 Tensorboard 中。
通过上述的完整例子,大家应该有了第一直观感受,其将函数作为一等公民这一宗旨发挥到了最大化,除了 Engine 类外,其他功能都可以通过函数形式注册进去,实现丰富的扩展功能。
4 Ignite 整体分析
Ignite 主要要理解 Engine、State、Event 和 Handler 这 4 个概念,核心代码位于 ignite/engine/engine.py、ignite/engine/events.py,其关系如下所示:
Engine 负责一个完整的循环流程,可以是一个训练流程,也可以是一个验证流程,整个流程的状态都是通过 State 对象统一维护,而 Events 管理了所有支持的触发事件,如果有自定义事件可以通过 register_events 接口实现,事件触发后具体的任务执行是通过 Handler 对象负责,Handler 可以认为是 Hook 的升级版本,其更加灵活好用,各种扩展功能都可以 Handler 实现,例如 logger、checkpoint、metric 等等。Engine 运行流程本质就是 for 循环,然后在特定点位触发事件,执行 Handler 任务。
4.1 Engine
Engine 是运行流程的核心,但是又非常简单,其核心流程如下:
while epoch < max_epochs:
# run an epoch on data
data_iter = iter(data)
while True:
try:
batch = next(data_iter)
output = process_function(batch)
iter_counter += 1
except StopIteration:
data_iter = iter(data)
if iter_counter == epoch_length:
break
可以看出就是一个典型的 for 循环,process_function 负责处理一个 epoch 的数据。一个典型的例子是:
def train_step(trainer, batch):
model.train()
optimizer.zero_grad()
y = prepare_batch(batch)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
return loss.item()
trainer = Engine(train_step) # process_function
max_epochs=100)
用户自定义 train_step 函数,返回啥无所谓,都会直接存储到 trainer.state.output中,后续自己可以针对性处理,这也体现去其灵活的地方了。
def update(engine, batch):
x, y = batch
y_pred = model(inputs)
loss = loss_fn(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return {'loss': loss.item(),
'y_pred': y_pred,
'y': y}
trainer = Engine(update)
def print_loss(engine):
epoch = engine.state.epoch
loss = engine.state.output['loss']
print (f'Epoch {epoch}: train_loss = {loss}')
accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)
同时允许用户手动设置 max_epoch 和 epoch_length,epoch_length 的应用场景可以是 debug 阶段,数据集过大,可以设置 epoch_length 来取其中一小部分,也可用于 dataset 是无限长的场景。
trainer.run(data, max_epochs=100, epoch_length=200)
如果是 GAN 这种复杂场景,支持也非常容易:
model_1 = ...
model_2 = ...
# ...
optimizer_1 = ...
optimizer_2 = ...
# ...
criterion_1 = ...
criterion_2 = ...
# ...
def train_step(trainer, batch):
data_1 = batch["data_1"]
data_2 = batch["data_2"]
# ...
model_1.train()
optimizer_1.zero_grad()
loss_1 = forward_pass(data_1, model_1, criterion_1)
loss_1.backward()
optimizer_1.step()
# ...
model_2.train()
optimizer_2.zero_grad()
loss_2 = forward_pass(data_2, model_2, criterion_2)
loss_2.backward()
optimizer_2.step()
# ...
# User can return any type of structure.
return {
"loss_1": loss_1,
"loss_2": loss_2,
# ...
}
trainer = Engine(train_step)
trainer.run(data, max_epochs=100)
如果对 MMCV 比较了解,你可以认为一个 Engine 就是对应一个 Runner。
4.2 State
State 对象比较好理解,专门用于存储训练中所需的所有状态,实现训练过程和训练状态分离,便于管理。默认情况下有如下状态:
def __init__(self, **kwargs: Any) -> None:
self.iteration = 0
self.epoch = 0
self.epoch_length = None # type: Optional[int]
self.max_epochs = None # type: Optional[int]
self.max_iters = None # type: Optional[int]
self.output = None # type: Optional[int]
self.batch = None # type: Optional[int]
self.metrics = {} # type: Dict[str, Any]
self.dataloader = None # type: Optional[Union[DataLoader, Iterable[Any]]]
self.seed = None # type: Optional[int]
self.times = {
Events.EPOCH_COMPLETED.name: None,
Events.COMPLETED.name: None,
} # type: Dict[str, Optional[float]]
for k, v in kwargs.items():
setattr(self, k, v)
如果你自定义的事件中也有状态要保存,也可以通过 event_to_attr 实现。self.output 保存的就是 update 函数返回值。
4.3 Event
Event 和 Handler 是 Ignite 的核心,要掌握这个框架就必须理解这两个对象。Event 用于记录事件的触发时机,例如每个 epoch 后,每隔 2 个 epoch等等,Handler 是在事件触发后具体的执行器。
因为事件也分成很多类型,故作者也进行了区分:
Events,这个是最基本的事件记录器,典型的是 STARTED、EPOCH_STARTED、ITERATION_STARTED、ITERATION_COMPLETED 等
EventsList,这个是并集操作事件记录器,用于将多个事件堆叠
CallableEventWithFilter,这个是基类,用于提供基于过滤规则触发的事件
其触发的核心伪代码为:
fire_event(Events.STARTED)
while epoch < max_epochs:
fire_event(Events.EPOCH_STARTED)
# run once on data
for batch in data:
fire_event(Events.ITERATION_STARTED)
output = process_function(batch)
fire_event(Events.ITERATION_COMPLETED)
fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)
虽然有三种类型,但是对用户而言不要操心,因为内部会基于事件类型来确定应该用哪个。
@engine.on(Events.EPOCH_COMPLETED)
events = Events.STARTED | Events.COMPLETED
@engine.on(events)
@engine.on(Events.ITERATION_COMPLETED(every=log_interval))
4.4 Handler
任何一个处理函数都必然要和对应的 Event 事件绑定,不然不知道何时触发。当然用户也可以自定义事件。
Handler 对应一个具体处理事件的 API,其可以是一个函数,可以是一个类方法,可以通过 on 装饰器或者 add_evevnt_hander 注册到引擎中,也可以自身的通过 attach 接口接入 Engine 中。Handler非常类似 Hook,但是这里更加宽泛,不像 Hook 必须是指定的方法名和固定的输入参数,其可以随意设置。
下面是一个最简单的 Handler 函数 log_training_loss
# 每隔 log_interval 且是迭代完成后触发,打印训练 loss
def log_training_loss(engine):
print(f"Epoch[{engine.state.epoch}], Iter[{engine.state.iteration}] Loss: {engine.state.output:.2f}")
这个 Handler 通过装饰器 on 函数注入到 trainer 中。
# 保存模型
val_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint)
这个 Handler 是通过 Engine 自身的 add_event_handler 函数注入到 Engine 中。
# tb_logger 通过 attach 方法注入到 trainer 中
tb_logger = TensorboardLogger(log_dir="tb-logger")
tb_logger.attach_output_handler(
trainer,
event_name=Events.ITERATION_COMPLETED(every=100),
tag="training",
output_transform=lambda loss: {"batch_loss": loss},
)
engine.add_event_handler(name, log_handler, self, name) 方法的。这个 Handler 是通过类本身的 attach 方法注入到 Engine 中,实际上 attach 方法内部也是调用了
除了上面这些比较简单的 Handler,作者还实现了各种各样的 Handler,都可以通过 on、add_event_handler 或者 attach 方式注入到 Engine 中,这种设计解耦性很好,容易维护和扩展。
下面是一个典型的带有 EarlyStopping 和 reduce_lr_plateau handler 的训练流程图:
5 总结
整体来说,Ignite 属于一个小而精的框架,代码量非常少,实现优雅,学习起来非常顺畅。总之其优点可以归纳为:
代码简洁优雅,容易理解
设计了一套独特的扩展模式,解耦合性非常好
但是从目前来看缺点也比较明显:
Engine 做的事情过少,导致很大一部分功能都需要自定义 Handler 实现,自身要维护的代码比较多
整体功能过弱,暂时还无法和 Pytorch-Iighting 这种量级的功能相比,社区活跃度也相差很大,后面可以对 Pytorch-Iighting 这种大型复杂的框架进行整体性分析
写这篇文章的目的正如题目所言,主要是开阔下大家的视野,在设计自身代码的时候可以参考人家开源的高质量代码。当然如果你是重头写一个训练任务,那么尝试 Ignite 也是一种不错的选择。由于内容较多,时间匆忙,如果有不对的地方,可以联系我。
最后,如果你觉得本文对你有帮助,请给 MMCV 点赞
https://github.com/open-mmlab/mmcv
如果有任何疑问,可以直接知乎联系,知乎账号:深度眸
推荐阅读
超实用半监督目标检测 Soft Teacher 及 MMDetection 最强代码实践