100行代码使用torch.fx极简量化教程

共 9378字,需浏览 19分钟

 ·

2022-04-18 14:31

↑ 点击蓝字 关注极市平台

作者丨金天@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/498286238
编辑丨极市平台

极市导读

 

本文使用100行代码,极简的教大家入门比较标准的量化步骤,从怎么用、用在哪里、哪里不能用等问题都将涵盖。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

网上很多关于量化的文章,要么就是跑一跑官方残缺的例子,要么就是过旧的API,早已经不潮流。现在比较fashion的方式,是使用 torch.fx来做量化。本文将使用100行代码,极简的教你入门比较标准的量化步骤。这些步骤不是简单的告诉你torch.fx有什么卵用,大家都知道它有什么卵用,只是怎么用,用在哪里,哪里不能用,这些问题需要解答。本文100行代码,麻雀虽小五脏俱全,不管你量化什么模型,一顿套用就是了,出了问题我背锅。

很多古老的文章,还在用手动插入stub来做量化节点,这就好比在21世纪还在飞鸽传书。我们必然会包含一下几个完整的内容:

  • fx怎么插入量化节点,不要吓倒,这就一行代码;
  • 量化的模型怎么保存权重到本地;
  • 怎么把量化后的权重再load回来;
  • 怎么做calibration,做跟不做区别多大;
  • fx到底有没有局限性;

以上问题,本文都将囊括。

量化前期知识

此处省略三万字,具体大家清百度。没啥好讲的。

量化现状

如果你要问我现在最好的量化工具是什么,我的回答是没有。真的,不管是 nni,还是 nvidia的 pytorch_quantization ,还是nncf so on,不是说这些东西不好,而是在做的各位都是垃圾。

这些东西本质上是在做一件事情,至少从量化角度上看是这样的,但是到最后不具备通用性,当你看到 pytorch_quanzation 这个工具保存的模型体积根float32一样的时候,就会开始怀疑人生了,这tm是人干的事儿?这就好比普通人想要中杯,他便要说这是大杯。

轮子不好用,那就只能自己造轮子了。只能说,torch.fxyyds. 用了都说好,谁用谁知道。

100行代码

talk is cheap,我们直接上代码。需要注意的是,torch.fx最好使用最新的stable版本,老版本API或有不同之处,我测试的是 `1.11`。

由于pytorch的自带的 imagnet系列模型,我们没有办法做calibration,我们用小一些的Cifra10, 不需要下载,pytorch自己可以处理,但是这就需要我们自己finetune一下。

先把finetune的代码备好:

这只是用来fintune一个我们准备去量化,并且校准的模型:

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torchvision
from torchvision import transforms
from torchvision.models.resnet import resnet50, resnet18
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.fx.graph_module import ObservedGraphModule
from torch.quantization import (
    get_default_qconfig,
)
from torch import optim
import os
import time


def train_model(model, train_loader, test_loader, device):
    # The training configurations were not carefully selected.
    learning_rate = 1e-2
    num_epochs = 20
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10.
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5
    )
    # optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    for epoch in range(num_epochs):
        # Training
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(
            model=model, test_loader=test_loader, device=device, criterion=criterion
        )
        print(
            "Epoch: {:02d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
                epoch, train_loss, train_accuracy, eval_loss, eval_accuracy
            )
        )
    return model

def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    train_set = torchvision.datasets.CIFAR10(
        root="data", train=True, download=True, transform=train_transform
    )
    # We will use test set for validation and test in this project.
    # Do not use test set for validation in practice!
    test_set = torchvision.datasets.CIFAR10(
        root="data", train=False, download=True, transform=test_transform
    )
    train_sampler = torch.utils.data.RandomSampler(train_set)
    test_sampler = torch.utils.data.SequentialSampler(test_set)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=train_batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=eval_batch_size,
        sampler=test_sampler,
        num_workers=num_workers,
    )
    return train_loader, test_loader

然后训练一波模型:

if __name__ == "__main__":
    train_loader, test_loader = prepare_dataloader()

    # first finetune model on cifar, we don't have imagnet so using cifar as test
    model = resnet18(pretrained=True)
    model.fc = nn.Linear(512, 10)
    if os.path.exists("r18_row.pth"):
        model.load_state_dict(torch.load("r18_row.pth", map_location="cpu"))
    else:
        train_model(model, train_loader, test_loader, torch.device("cuda"))
        print("train finished.")
        torch.save(model.state_dict(), "r18_row.pth")

接下来就是核心代码:

def quant_fx(model):
    model.eval()
    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {
        "": qconfig,
        # 'object_type': []
    }
    model_to_quantize = copy.deepcopy(model)
    prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
    print("prepared model: ", prepared_model)

    quantized_model = convert_fx(prepared_model)
    print("quantized model: ", quantized_model)
    torch.save(model.state_dict(), "r18.pth")
    torch.save(quantized_model.state_dict(), "r18_quant.pth")

懂了吗?很快阿,啪一下,一个int8的量化模型就生成了。

没错,其实都不用100行,15行就够了。torch.fx 就是这么的牛逼!

我们做一个evaluation,来验证一下,在不校准的情况下,精度如何:

def evaluate_model(model, test_loader, device=torch.device("cpu"), criterion=None):
    t0 = time.time()
    model.eval()
    model.to(device)
    running_loss = 0
    running_corrects = 0
    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)
    t1 = time.time()
    print(f"eval loss: {eval_loss}, eval acc: {eval_accuracy}, cost: {t1 - t0}")
    return eval_loss, eval_accuracy

这是evaluation的结果:

eval loss: 0.0, eval acc: 0.8476999998092651, cost: 2.8914074897766113
eval loss: 0.0, eval acc: 0.15240000188350677, cost: 1.240293264389038

可以看到,精度下降严重。此时需要进行一下校准,我直接放校准函数:

def calib_quant_model(model, calib_dataloader):
    assert isinstance(
        model, ObservedGraphModule
    ), "model must be a perpared fx ObservedGraphModule."
    model.eval()
    with torch.inference_mode():
        for inputs, labels in calib_dataloader:
            model(inputs)
    print("calib done.")

that's all. 就这么简单。

如果你有其他非分类模型,也可以直接把dataloader丢进来。请注意,这里的标签并没有用到。只需要统计数据的分布即可。

非常简单。

最后我们再次eval一下:

def quant_calib_and_eval(model):
    # test only on CPU
    model.to(torch.device("cpu"))
    model.eval()

    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {
        "": qconfig,
        # 'object_type': []
    }

    model2 = copy.deepcopy(model)
    model_prepared = prepare_fx(model2, qconfig_dict)
    model_int8 = convert_fx(model_prepared)
    model_int8.load_state_dict(torch.load("r18_quant.pth"))
    model_int8.eval()

    a = torch.randn([1, 3, 224, 224])
    o1 = model(a)
    o2 = model_int8(a)

    diff = torch.allclose(o1, o2, 1e-4)
    print(diff)
    print(o1.shape, o2.shape)
    print(o1, o2)
    get_output_from_logits(o1)
    get_output_from_logits(o2)

    train_loader, test_loader = prepare_dataloader()
    evaluate_model(model, test_loader)
    evaluate_model(model_int8, test_loader)

    # calib quant model
    model2 = copy.deepcopy(model)
    model_prepared = prepare_fx(model2, qconfig_dict)
    model_int8 = convert_fx(model_prepared)
    torch.save(model_int8.state_dict(), "r18.pth")
    model_int8.eval()

    model_prepared = prepare_fx(model2, qconfig_dict)
    calib_quant_model(model_prepared, test_loader)
    model_int8 = convert_fx(model_prepared)
    torch.save(model_int8.state_dict(), "r18_quant_calib.pth")
    evaluate_model(model_int8, test_loader)

得到结果:

eval loss: 0.0, eval acc: 0.8476999998092651, cost: 2.8914074897766113
eval loss: 0.0, eval acc: 0.15240000188350677, cost: 1.240293264389038
calib done.
eval loss: 0.0, eval acc: 0.8442999720573425, cost: 1.2966759204864502

精度瞬间恢复了。速度快了超过一半。

总结

ok,我们用几十行代码就完成这个量化过程。并且使用校准,恢复了精度。由此可见fx的强大之处。

抛出一个问题,欢迎留言区解答:

  • torch.fx量化的模型,如果export 到onnx并使用其他前推引擎推理。



公众号后台回复“CVPR 2022”获取论文打包合集下载~

△点击卡片关注极市平台,获取最新CV干货
极市干货
数据集资源汇总:10个开源工业检测数据集汇总21个深度学习开源数据集分类汇总
算法trick目标检测比赛中的tricks集锦从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述:一文弄懂各种loss function工业图像异常检测最新研究总结(2019-2020)


CV技术社群邀请函 #

△长按添加极市小助手
添加极市小助手微信(ID : cvmart4)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~


觉得有用麻烦给个在看啦~  


浏览 39
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报