Implementing EMA in detectron2
Popular recent detectors such as YOLOv5 and YOLOX ship with many training tricks, e.g. data augmentation (MixUp, Mosaic), and EMA is another commonly used one. EMA stands for Exponential Moving Average; an early implementation appeared in TensorFlow as tf.train.ExponentialMovingAverage. In short, during training we maintain an exponential moving average of the model parameters, and the averaged parameters often perform slightly better than the raw parameters obtained at the end of training. In a sense EMA resembles model ensembling, but it adds no overhead at test time; during training it only costs one extra copy of the parameters in GPU memory plus a tiny amount of compute per step to update the average.
The implementation is equally simple: alongside the model parameters params we keep one extra copy ema_params, and after every training step each parameter is averaged into it:

ema_params = decay * ema_params + (1 - decay) * params

Here decay is a hyperparameter close to 1, typically something like 0.999. EMA is very generic and can be applied to the training of almost any model.
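To make the update concrete, here is a minimal, framework-agnostic sketch of the same idea in PyTorch; model, decay and ema_params are placeholders, not part of the implementation discussed below:

import torch

# One EMA copy of the weights, refreshed after every optimizer step.
decay = 0.999
ema_params = {name: p.detach().clone() for name, p in model.named_parameters()}

@torch.no_grad()
def ema_update(model):
    for name, p in model.named_parameters():
        # ema = decay * ema + (1 - decay) * param
        ema_params[name].mul_(decay).add_(p.detach(), alpha=1.0 - decay)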
SenseTime's open-source mmdet framework has already reproduced YOLOX, including an EMA implementation. Facebook AI's detectron2 does not ship EMA yet, but its mobile-oriented sibling D2Go does, and the two codebases are largely interchangeable with only minor differences. This post walks through porting D2Go's EMA into detectron2, which boils down to three parts: attaching EMA parameters to the model, updating them during training, and using them at test time.
EMA needs one extra copy of the model state, i.e. the EMA parameters. We define an EMAState class to hold them; its state dict stores the EMA values. The get_model_state_iterator method iterates over the model state, namely the trainable parameters plus the buffers; BatchNorm's running statistics (running_mean and running_var) are buffers, and including them in the EMA usually works slightly better.
import copy
import itertools
import torch


class EMAState(object):
    """Container for the EMA copy of a model's parameters and buffers."""

    def __init__(self):
        self.state = {}

    @classmethod
    def FromModel(cls, model: torch.nn.Module, device: str = ""):
        ret = cls()
        ret.save_from(model, device)
        return ret

    def save_from(self, model: torch.nn.Module, device: str = ""):
        """Save model state from `model` to this object"""
        for name, val in self.get_model_state_iterator(model):
            val = val.detach().clone()
            self.state[name] = val.to(device) if device else val

    def apply_to(self, model: torch.nn.Module):
        """Apply state to `model` from this object"""
        with torch.no_grad():
            for name, val in self.get_model_state_iterator(model):
                assert (
                    name in self.state
                ), f"Name {name} does not exist, available names are {self.state.keys()}"
                val.copy_(self.state[name])

    def get_ema_model(self, model):
        # Return a deep copy of `model` with the EMA weights applied.
        ret = copy.deepcopy(model)
        self.apply_to(ret)
        return ret

    @property
    def device(self):
        if not self.has_inited():
            return None
        return next(iter(self.state.values())).device

    def to(self, device):
        for name in self.state:
            self.state[name] = self.state[name].to(device)
        return self

    def has_inited(self):
        return self.state

    def clear(self):
        self.state.clear()
        return self

    def get_model_state_iterator(self, model):
        # Iterate over trainable parameters AND buffers (e.g. BN running stats).
        param_iter = model.named_parameters()
        buffer_iter = model.named_buffers()
        return itertools.chain(param_iter, buffer_iter)

    def state_dict(self):
        return self.state

    def load_state_dict(self, state_dict, strict: bool = True):
        self.clear()
        for x, y in state_dict.items():
            self.state[x] = y
        return torch.nn.modules.module._IncompatibleKeys(
            missing_keys=[], unexpected_keys=[]
        )

    def __repr__(self):
        ret = f"EMAState(state=[{','.join(self.state.keys())}])"
        return ret
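As a quick sanity check of the API, a hedged usage sketch (here model stands for any torch.nn.Module):

# Snapshot the current parameters and buffers into an EMAState.
ema_state = EMAState.FromModel(model)
# ... train and update ema_state elsewhere ...
# Get a deep-copied model whose weights are replaced by the EMA values.
ema_model = ema_state.get_ema_model(model)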
In the detectron2 Trainer, the EMA state is created right where the model is built; afterwards the model carries an extra ema_state attribute, an instance of EMAState:
import logging

logger = logging.getLogger(__name__)


def may_build_model_ema(cfg, model):
    if not cfg.MODEL_EMA.ENABLED:
        return
    model = _remove_ddp(model)  # unwrap DistributedDataParallel if needed
    assert not hasattr(
        model, "ema_state"
    ), "Name `ema_state` is reserved for model ema."
    model.ema_state = EMAState()  # attach the EMA state to the model
    logger.info("Using Model EMA.")


class Trainer(DefaultTrainer):
    # override build_model and create the EMA state inside it
    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        # add model EMA if enabled
        model_ema.may_build_model_ema(cfg, model)
        return model
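One detail not shown above: the cfg.MODEL_EMA.* keys are not part of detectron2's default config and have to be registered before use. A sketch of such a helper, modeled after D2Go's add_model_ema_configs (the default values here are assumptions):

from detectron2.config import CfgNode as CN

def add_model_ema_configs(cfg):
    # Register the MODEL_EMA node consumed by the EMA code in this post.
    cfg.MODEL_EMA = CN()
    cfg.MODEL_EMA.ENABLED = False
    cfg.MODEL_EMA.DECAY = 0.999
    # Leave empty to keep the EMA weights on the same device as the model.
    cfg.MODEL_EMA.DEVICE = ""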
That covers attaching the EMA state; the EMA parameters also have to be saved along with checkpoints, which detectron2's DetectionCheckpointer makes easy. It accepts extra checkpointable objects at construction time, and whenever save or load is called those objects are saved and loaded alongside the model weights. A checkpointable object only needs to implement state_dict() and load_state_dict(); the EMAState class above already provides both, so its EMA parameters are checkpointed for free. The relevant part of the Trainer looks like this:
class Trainer(DefaultTrainer):
    def __init__(self, cfg):
        # ... the rest of DefaultTrainer.__init__ (model/optimizer/data loader setup) is unchanged
        # add model EMA
        kwargs = {
            'trainer': weakref.proxy(self),
        }
        kwargs.update(model_ema.may_get_ema_checkpointer(cfg, model))  # register ema_state as a checkpointable
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            **kwargs,
        )
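Two helpers used above are not listed in the snippets: _remove_ddp, which unwraps a (Distributed)DataParallel wrapper, and may_get_ema_checkpointer, which exposes ema_state as an extra checkpointable. Sketches of what they might look like, following the D2Go versions:

from torch.nn.parallel import DataParallel, DistributedDataParallel

def _remove_ddp(model):
    # Work on the underlying module when the model is wrapped for (distributed) data parallelism.
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        return model.module
    return model

def may_get_ema_checkpointer(cfg, model):
    if not cfg.MODEL_EMA.ENABLED:
        return {}
    model = _remove_ddp(model)
    # The key is the name under which the EMA weights are stored in the checkpoint.
    return {"ema_state": model.ema_state}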
With the first part done (attaching the EMA parameters to the model), the second piece is updating them during training. We first define an EMAUpdater whose update method performs one EMA step:
class EMAUpdater(object):
    """Model Exponential Moving Average
    Keep a moving average of everything in the model state_dict (parameters and
    buffers). This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    Note: It's very important to set EMA for ALL network parameters (instead of
    parameters that require gradient), including batch-norm moving average mean
    and variance. This leads to significant improvement in accuracy.
    For example, for EfficientNetB3, with default setting (no mixup, lr exponential
    decay) without bn_sync, the EMA accuracy with EMA on params that requires
    gradient is 79.87%, while the corresponding accuracy with EMA on all params
    is 80.61%.

    Also, bn sync should be switched on for EMA.
    """

    def __init__(self, state: EMAState, decay: float = 0.999, device: str = ""):
        self.decay = decay
        self.device = device
        self.state = state

    def init_state(self, model):
        self.state.clear()
        self.state.save_from(model, self.device)

    def update(self, model):
        with torch.no_grad():
            for name, val in self.state.get_model_state_iterator(model):
                ema_val = self.state.state[name]
                if self.device:
                    val = val.to(self.device)
                # exponential moving average update
                ema_val.copy_(ema_val * self.decay + val * (1.0 - self.decay))
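For intuition, outside of detectron2's hook machinery the updater could be driven directly from a hand-written loop; a rough sketch in which model, optimizer, data_loader and compute_loss are placeholders:

# Hypothetical plain training loop using EMAState / EMAUpdater directly.
ema_state = EMAState()
ema_updater = EMAUpdater(ema_state, decay=0.999)
ema_updater.init_state(model)          # snapshot the initial weights

for images, targets in data_loader:    # placeholder loader
    loss = compute_loss(model, images, targets)  # placeholder loss function
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_updater.update(model)          # one EMA step after each optimizer step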
To trigger the update during training we use detectron2's hook mechanism: an EMAHook whose after_step calls the updater once per iteration:
class EMAHook(HookBase):
    def __init__(self, cfg, model):
        model = _remove_ddp(model)
        assert cfg.MODEL_EMA.ENABLED
        assert hasattr(
            model, "ema_state"
        ), "Call `may_build_model_ema` first to initialize the model ema"
        self.model = model
        self.ema = self.model.ema_state
        self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
        self.ema_updater = EMAUpdater(
            self.model.ema_state, decay=cfg.MODEL_EMA.DECAY, device=self.device
        )

    def before_train(self):
        if self.ema.has_inited():
            self.ema.to(self.device)
        else:
            self.ema_updater.init_state(self.model)

    def after_train(self):
        pass

    def before_step(self):
        pass

    def after_step(self):
        if not self.model.training:  # skip the EMA update if the model is in eval mode
            return
        self.ema_updater.update(self.model)
Then EMAHook is added to the trainer's hook list:
def build_hooks(self):
    """
    Build a list of default hooks, including timing, evaluation,
    checkpointing, lr scheduling, precise BN, writing events.

    Returns:
        list[HookBase]:
    """
    cfg = self.cfg.clone()
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

    ret = [
        hooks.IterationTimer(),
        model_ema.EMAHook(self.cfg, self.model) if cfg.MODEL_EMA.ENABLED else None,  # add EMA hook
        hooks.LRScheduler(),
        hooks.PreciseBN(
            # Run at the same freq as (but before) evaluation.
            cfg.TEST.EVAL_PERIOD,
            self.model,
            # Build a new data loader to not affect training
            self.build_train_loader(cfg),
            cfg.TEST.PRECISE_BN.NUM_ITER,
        )
        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
        else None,
    ]
    # ... the remaining default hooks (checkpointing, evaluation, writers)
    # and the final `return ret` are identical to DefaultTrainer.build_hooks.
The last piece is using the EMA parameters at test time. The approach: every time test is run, snapshot the current model weights, swap in the EMA weights, run the evaluation, and then restore the snapshot. A Python context manager handles this neatly:
from contextlib import contextmanager


@contextmanager
def apply_model_ema_and_restore(model, state=None):
    """Apply the EMA weights stored in `model` to the model, and restore the
    original weights after the context exits.
    """
    model = _remove_ddp(model)
    if state is None:
        state = get_model_ema_state(model)  # fetch the attached model.ema_state
    old_state = EMAState.FromModel(model, state.device)  # snapshot the current weights
    state.apply_to(model)  # overwrite the model with the EMA weights
    yield old_state
    old_state.apply_to(model)  # restore the original weights
Wrapping the test call with this context manager gives exactly the behavior we want:
@classmethod
def do_test(cls, cfg, model, evaluators=None):
    # model with ema weights
    logger = logging.getLogger("detectron2")
    if cfg.MODEL_EMA.ENABLED:
        logger.info("Run evaluation with EMA.")
        with model_ema.apply_model_ema_and_restore(model):
            results = cls.test(cfg, model, evaluators=evaluators)
    else:
        results = cls.test(cfg, model, evaluators=evaluators)
    return results
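Since ema_state is checkpointed together with the model, the EMA weights can also be restored later for standalone evaluation. A hedged sketch; the checkpoint path is a placeholder and may_get_ema_checkpointer is the helper assumed earlier:

# Hypothetical: evaluate a finished run with its EMA weights.
model = Trainer.build_model(cfg)  # also attaches model.ema_state
DetectionCheckpointer(
    model, **model_ema.may_get_ema_checkpointer(cfg, model)
).load("output/model_final.pth")  # placeholder checkpoint path
model.ema_state.apply_to(model)   # overwrite the live weights with the EMA values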
The complete code is on GitHub; feel free to try it and star it (https://github.com/xiaohu2015/detectron2_ema). In a quick test with RetinaNet_R_50_FPN_1x, EMA gives a small improvement over the baseline (37.23 vs 37.18 AP), while YOLOv5 reportedly gains 1~2 points from EMA. YOLOv5's EMA adds one extra trick: early in training the decay is ramped up from a small value to the target value, because the weights change quickly at first and the EMA should track them more aggressively. Its implementation:
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
This should be straightforward to add to the detectron2 EMA code above (mmdet's EMA hook already supports it); I will update the repo when I get to it.
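A possible way to graft this warm-up onto the EMAUpdater above; the counter attribute and the constant 2000 are assumptions carried over from YOLOv5:

import math

class EMAUpdaterWithRamp(EMAUpdater):
    """EMAUpdater whose decay ramps up from ~0 to the target value."""

    def __init__(self, state: EMAState, decay: float = 0.999, device: str = ""):
        super().__init__(state, decay=decay, device=device)
        self.target_decay = decay
        self.num_updates = 0

    def update(self, model):
        self.num_updates += 1
        # A small decay early in training lets the EMA track the fast-changing weights.
        self.decay = self.target_decay * (1 - math.exp(-self.num_updates / 2000))
        super().update(model)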
References
fvcore, D2Go, YOLOv5