如何使用PyTorch的量化功能?
点击上方“AI算法与图像处理”,选择加"星标"或“置顶”
重磅干货,第一时间送达
来源:paperweekly

背景
- 更少的模型体积,接近 4 倍的减少; 
- 可以更快的计算,由于更少的内存访问和更快的 int8 计算,可以快 2~4 倍。 
- Post Training Dynamic Quantization,模型训练完毕后的动态量化; 
- Post Training Static Quantization,模型训练完毕后的静态量化; 
- QAT(Quantization Aware Training),模型训练中开启量化。 

Tensor的量化
>>> x = torch.rand(2,3, dtype=torch.float32) 
>>> x
tensor([[0.6839, 0.4741, 0.7451],
        [0.9301, 0.1742, 0.6835]])
>>> xq = torch.quantize_per_tensor(x, scale = 0.5, zero_point = 8, dtype=torch.quint8)
tensor([[0.5000, 0.5000, 0.5000],
        [1.0000, 0.0000, 0.5000]], size=(2, 3), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.5, zero_point=8)
>>> xq.int_repr()
tensor([[ 9,  9,  9],
        [10,  8,  9]], dtype=torch.uint8)
xq = round(x / scale + zero_point)
xq = round(zero_point) = zero_point
# xq is a quantized tensor with data represented as quint8
>>> xdq = xq.dequantize()
>>> xdq
tensor([[0.5000, 0.5000, 0.5000],
        [1.0000, 0.0000, 0.5000]])
xdq = (xq - zero_point) * scale
- 量化会有精度损失; 
- 我们这里随便选取的 scale 和 zp 太烂,选择合适的 scale 和 zp 可以有效降低精度损失。不信你把 scale 和 zp 分别换成 scale = 0.0036, zero_point = 0试试。 

Post Training Dynamic Quantization
- Post:也就是训练完成后再量化模型的权重参数; 
- Dynamic:也就是网络在前向推理的时候动态的量化 float32 类型的输入。 
torch.quantization.quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False)
- Linear 
- LSTM 
- LSTMCell 
- RNNCell 
- GRUCell 
- qconfig_spec 指定了一组 qconfig,具体就是哪个 op 对应哪个 qconfig; 
- 每个 qconfig 是 QConfig 类的实例,封装了两个 observer; 
- 这两个 observer 分别是 activation 的 observer 和 weight 的 observer; 
- 但是动态量化使用的是 QConfig 子类 QConfigDynamic 的实例,该实例实际上只封装了 weight 的 observer; 
- activate 就是 post process,就是 op forward 之后的后处理,但在动态量化中不包含; 
- observer 用来根据四元组(min_val,max_val,qmin, qmax)来计算 2 个量化的参数:scale 和 zero_point; 
- qmin、qmax 是算法提前确定好的,min_val 和 max_val 是从输入数据中观察到的,所以起名叫 observer。 
- qconfig_spec 赋值为一个 set,比如:{nn.LSTM, nn.Linear},意思是指定当前模型中的哪些 layer 要被 dynamic quantization; 
- qconfig_spec 赋值为一个 dict,key 为 submodule 的 name 或 type,value 为 QConfigDynamic 实例(其包含了特定的 Observer,比如 MinMaxObserver、MovingAverageMinMaxObserver、PerChannelMinMaxObserver、MovingAveragePerChannelMinMaxObserver、HistogramObserver)。 
qconfig_spec = {
                nn.Linear : default_dynamic_qconfig,
                nn.LSTM : default_dynamic_qconfig,
                nn.GRU : default_dynamic_qconfig,
                nn.LSTMCell : default_dynamic_qconfig,
                nn.RNNCell : default_dynamic_qconfig,
                nn.GRUCell : default_dynamic_qconfig,
            }
default_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer, weight=default_weight_observer)
default_dynamic_quant_observer = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8)
default_weight_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
class CivilNet(nn.Module):
    def __init__(self):
        super(CivilNet, self).__init__()
        gemfieldin = 1
        gemfieldout = 1
        self.conv = nn.Conv2d(gemfieldin, gemfieldout, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
        self.fc = nn.Linear(3, 2,bias=False)
        self.relu = nn.ReLU(inplace=False)
    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        x = self.relu(x)
        return x
#原始网络
CivilNet(
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): Linear(in_features=3, out_features=2, bias=False)
  (relu): ReLU()
)
#quantize_dynamic 后
CivilNet(
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): DynamicQuantizedLinear(in_features=3, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
)
# Default map for swapping dynamic modules
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS = {
    nn.GRUCell: nnqd.GRUCell,
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
}
new_mod = mapping[type(mod)].from_float(mod)
- 使用 MinMaxObserver 计算模型中 op 权重参数中 tensor 的最大值最小值(这个例子中只有 Linear op),缩小量化时原始值的取值范围,提高量化的精度; 
- 通过上述步骤中得到四元组中的 min_val 和 max_val,再结合算法确定的 qmin, qmax 计算出 scale 和 zp,参考前文“Tensor的量化”小节,计算得到量化后的weight,这个量化过程有torch.quantize_per_tensor 和 torch.quantize_per_channel两种,默认是前者(因为qchema默认是torch.per_tensor_affine); 
- 实例化 nnqd.Linear,然后使用 qlinear.set_weight_bias 将量化后的 weight 和原始的 bias 设置到新的 layer 上。其中最后一步还涉及到 weight 和 bias 的打包,在源代码中是这样的: 
#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::FBGEMM) {
      return PackedLinearWeight::prepack(std::move(weight), std::move(bias));
    }
#endif
#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      return PackedLinearWeightsQnnp::prepack(std::move(weight), std::move(bias));
    }
#endif
    TORCH_CHECK(false,"Didn't find engine for operation quantized::linear_prepack ",toString(ctx.qEngine()));
#input
torch.Tensor([[[[-1,-2,-3],[1,2,3]]]])
#经过卷积后(权重为torch.Tensor([[[[-0.7867]]]]))
torch.Tensor([[[[ 0.7867,  1.5734,  2.3601],[-0.7867, -1.5734, -2.3601]]]])
#经过fc后(权重为torch.Tensor([[ 0.4097, -0.2896, -0.4931], [-0.3738, -0.5541,  0.3243]]) )
torch.Tensor([[[[-1.2972, -0.4004], [1.2972,  0.4004]]]])
#经过relu后
torch.Tensor([[[[0.0000, 0.0000],[1.2972, 0.4004]]]])
#input
torch.Tensor([[[[-1,-2,-3],[1,2,3]]]])
#经过卷积后(权重为torch.Tensor([[[[-0.7867]]]]))
torch.Tensor([[[[ 0.7867,  1.5734,  2.3601],[-0.7867, -1.5734, -2.3601]]]])
#经过fc后(权重为torch.Tensor([[ 0.4085, -0.2912, -0.4911],[-0.3737, -0.5563,  0.3259]], dtype=torch.qint8,scale=0.0043458822183310986,zero_point=0) )
torch.Tensor([[[[-1.3038, -0.3847], [1.2856,  0.3969]]]])
#经过relu后
torch.Tensor([[[[0.0000, 0.0000], [1.2856, 0.3969]]]])
- scale = max_val / (float(qmax - qmin) / 2) = 0.5541 / ((127 + 128) / 2) = 0.004345882... 
- zp = 0 
#ifdef USE_FBGEMM
at::Tensor PackedLinearWeight::apply_dynamic_impl(at::Tensor input, bool reduce_range) {
  ...
  fbgemm::xxxx
  ...
}
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(at::Tensor input) {
  ...
  qnnpack::qnnpackLinearDynamic(xxxx)
  ...
}
#endif // USE_PYTORCH_QNNPACK
Tensor q_input = at::quantize_per_tensor(input_contig, q_params.scale, q_params.zero_point, c10::kQUInt8);
requant_scale = input_scale_fp32 * weight_scale_fp32 / output_scale_fp32
auto output_scale = 1.f
auto inverse_output_scale = 1.f /output_scale;
requant_scales[i] = (weight_scales_data[i] * input_scale) * inverse_output_scale;
#原始的模型,所有的tensor和计算都是浮点型
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                 /
linear_weight_fp32
#动态量化后的模型,Linear和LSTM的权重是int8
previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32
                     /
   linear_weight_int8

Post Training Static Quantization
fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None)
torch.quantization.fuse_modules(gemfield_model, [['conv1', 'bn1', 'relu1']], inplace=True)
class CivilNet(nn.Module):
    def __init__(self):
        super(CivilNet, self).__init__()
        syszuxin = 1
        syszuxout = 1
        self.conv = nn.Conv2d(syszuxin, syszuxout, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
        self.fc = nn.Linear(3, 2,bias=False)
        self.relu = nn.ReLU(inplace=False)
    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        x = self.relu(x)
        return x
CivilNet(
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): Linear(in_features=3, out_features=2, bias=False)
  (relu): ReLU()
)
CivilNet(
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): LinearReLU(
    (0): Linear(in_features=3, out_features=2, bias=False)
    (1): ReLU()
  )
  (relu): Identity()
)
torch.quantization.fuse_modules(a_sequential_module, ['0', '1', '2'], inplace=True)
- Convolution, Batch normalization 
- Convolution, Batch normalization, Relu 
- Convolution, Relu 
- Linear, Relu 
- Batch normalization, Relu 
DEFAULT_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
    (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
    (nn.Linear, nn.ReLU): nni.LinearReLU,
    (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
    (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}
#如果要部署在x86 server上
gemfield_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
#如果要部署在ARM上
gemfield_model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
def get_default_qconfig(backend='fbgemm'):
    if backend == 'fbgemm':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),weight=default_per_channel_weight_observer)
    elif backend == 'qnnpack':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),weight=default_weight_observer)
    else:
        qconfig = default_qconfig
    return qconfig
| 量化的 backend | activation | weight | 
| fbgemm | HistogramObserver | PerChannelMin MaxObserver (default_per_channel _weight_observer) | 
| qnnpack | HistogramObserver | MinMaxObserver (default_weight _observer) | 
| 默认(非 fbgemm和qnnpack) | MinMaxObserver (default_observer) | MinMaxObserver (default_weight _observer) | 
gemfield_model_prepared = torch.quantization.prepare(gemfield_model)
CivilNet(
  (conv): Conv2d(
    1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False
    (activation_post_process): HistogramObserver()
  )
  (fc): Linear(
    in_features=3, out_features=2, bias=False
    (activation_post_process): HistogramObserver()
  )
  (relu): ReLU(
    (activation_post_process): HistogramObserver()
  )
  (quant): QuantStub(
    (activation_post_process): HistogramObserver()
  )
  (dequant): DeQuantStub()
)
#至少观察个几百迭代
for data in data_loader:
    gemfield_model_prepared(data)
gemfield_model_prepared_int8 = torch.quantization.convert(gemfield_model_prepared)
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.ELU: nnq.ELU,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.GroupNorm: nnq.GroupNorm,
    nn.Hardswish: nnq.Hardswish,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.LeakyReLU: nnq.LeakyReLU,
    nn.Linear: nnq.Linear,
    nn.ReLU6: nnq.ReLU6,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
    nniqat.ConvBn1d: nnq.Conv1d,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
} 
- QuantStub 的 scale 和 zp 是怎么来的(静态量化需要插入 QuantStub,后文有说明)? 
- conv activation 的 scale 和 zp 是怎么来的? 
- conv weight 的 scale 和 zp 是怎么来的? 
- fc activation 的 scale 和 zp 是怎么来的? 
- fc weight 的 scale 和 zp 是怎么来的? 
- relu activation 的 scale 和 zp 是怎么来的? 
- relu weight 的...等等,relu 没有 weight。 
if self.dtype == torch.qint8:
    if self.reduce_range:
        qmin, qmax = -64, 63
    else:
        qmin, qmax = -128, 127
else:
    if self.reduce_range:
        qmin, qmax = 0, 127
    else:
        qmin, qmax = 0, 255
#qscheme 是 torch.per_tensor_symmetric 或者torch.per_channel_symmetric时
max_val = torch.max(-min_val, max_val)
scale = max_val / (float(qmax - qmin) / 2)
scale = torch.max(scale, torch.tensor(self.eps, device=device, dtype=scale.dtype))
if self.dtype == torch.quint8:
    zero_point = zero_point.new_full(zero_point.size(), 128)
#qscheme 是 torch.per_tensor_affine时
scale = (max_val - min_val) / float(qmax - qmin)
scale = torch.max(scale, torch.tensor(self.eps, device=device, dtype=scale.dtype))
zero_point = qmin - torch.round(min_val / scale)
zero_point = torch.max(zero_point, torch.tensor(qmin, device=device, dtype=zero_point.dtype))
zero_point = torch.min(zero_point, torch.tensor(qmax, device=device, dtype=zero_point.dtype))
- scale = 0.7898 / ((127 + 128)/2 ) = 0.0062 
- zp = 0 
- scale = (2.9971 + 3) / (127 - 0) = 0.0472 
- zp = 0 - round(-3 /0.0472) = 64 
#原始的CivilNet网络:
CivilNet(
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): Linear(in_features=3, out_features=2, bias=False)
  (relu): ReLU()
)
#静态量化后的CivilNet网络:
CivilNet(
  (conv): QuantizedConv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=0.0077941399067640305, zero_point=0, bias=False)
  (fc): QuantizedLinear(in_features=3, out_features=2, scale=0.002811126410961151, zero_point=14, qscheme=torch.per_channel_affine)
  (relu): QuantizedReLU()
)
import torch
import torch.nn as nn
class CivilNet(nn.Module):
    def __init__(self):
        super(CivilNet, self).__init__()
        in_planes = 1
        out_planes = 1
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
        self.fc = nn.Linear(3, 2,bias=False)
        self.relu = nn.ReLU(inplace=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.fc(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x
def forward(self, X):
    return torch.quantize_per_tensor(X, float(self.scale), int(self.zero_point), self.dtype)
def forward(self, Xq):
    return Xq.dequantize()
c = CivilNet()
t = torch.Tensor([[[[-1,-2,-3],[1,2,3]]]])
c(t)
#input
torch.Tensor([[[[-1,-2,-3],[1,2,3]]]])
#经过卷积后(权重为torch.Tensor([[[[-0.7867]]]]))
torch.Tensor([[[[ 0.7867,  1.5734,  2.3601],[-0.7867, -1.5734, -2.3601]]]])
#经过fc后(权重为torch.Tensor([[ 0.4097, -0.2896, -0.4931], [-0.3738, -0.5541,  0.3243]]) )
torch.Tensor([[[[-1.2972, -0.4004], [1.2972,  0.4004]]]])
#经过relu后
torch.Tensor([[[[0.0000, 0.0000],[1.2972, 0.4004]]]])
#input
torch.Tensor([[[[-1,-2,-3],[1,2,3]]]])
#QuantStub后 (scale=tensor([0.0472]), zero_point=tensor([64]))
tensor([[[[-0.9916, -1.9833, -3.0221],[ 0.9916,  1.9833,  3.0221]]]],
       dtype=torch.quint8, scale=0.04722102731466293, zero_point=64)
#经过卷积后(权重为torch.Tensor([[[[-0.7898]]]], dtype=torch.qint8, scale=0.0062, zero_point=0))
#conv activation(输入)的scale为0.03714831545948982,zp为64
torch.Tensor([[[[ 0.7801,  1.5602,  2.3775],[-0.7801, -1.5602, -2.3775]]]], scale=0.03714831545948982, zero_point=64)
#经过fc后(权重为torch.Tensor([[ 0.4100, -0.2901, -0.4951],[-0.3737, -0.5562,  0.3259]], dtype=torch.qint8, scale=tensor([0.0039, 0.0043]),zero_point=tensor([0, 0])) )
#fc activation(输入)的scale为0.020418135449290276, zp为64
torch.Tensor([[[[-1.3068, -0.3879],[ 1.3068,  0.3879]]]], dtype=torch.quint8, scale=0.020418135449290276, zero_point=64)
#经过relu后
torch.Tensor([[[[0.0000, 0.0000],[1.3068, 0.3879]]]], dtype=torch.quint8, scale=0.020418135449290276, zero_point=64)
#经过DeQuantStub后
torch.Tensor([[[[0.0000, 0.0000],[1.3068, 0.3879]]]])
import torch
import torch.nn.quantized as nnq
#输入
>>> x = torch.Tensor([[[[-1,-2,-3],[1,2,3]]]])
>>> x
tensor([[[[-1., -2., -3.],
          [ 1.,  2.,  3.]]]])
#经过QuantStub
>>> xq = torch.quantize_per_tensor(x, scale = 0.0472, zero_point = 64, dtype=torch.quint8)
>>> xq
tensor([[[[-0.9912, -1.9824, -3.0208],
          [ 0.9912,  1.9824,  3.0208]]]], size=(1, 1, 2, 3),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=0.0472, zero_point=64)
>>> xq.int_repr()
tensor([[[[ 43,  22,   0],
          [ 85, 106, 128]]]], dtype=torch.uint8)
>>> c = nnq.Conv2d(1,1,1)
>>> weight = torch.Tensor([[[[-0.7898]]]])
>>> qweight = torch.quantize_per_channel(weight, scales=torch.Tensor([0.0062]).to(torch.double), zero_points = torch.Tensor([0]).to(torch.int64), axis=0, dtype=torch.qint8)
>>> c.set_weight_bias(qweight, None)
>>> c.scale = 0.03714831545948982
>>> c.zero_point = 64
>>> x = c(xq)
>>> x
tensor([[[[ 0.7801,  1.5602,  2.3775],
          [-0.7801, -1.5602, -2.3775]]]], size=(1, 1, 2, 3),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=0.03714831545948982, zero_point=64)
def forward(self, input):
    return ops.quantized.conv2d(input, self._packed_params, self.scale, self.zero_point)
def forward(self, x):
    return torch.ops.quantized.linear(x, self._packed_params._packed_params, self.scale, self.zero_point)
#原始的模型,所有的tensor和计算都是浮点型
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                    /
    linear_weight_fp32
#静态量化的模型,权重和输入都是int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                    /
  linear_weight_int8
- 静态量化的 float 输入必经 QuantStub 变为 int,此后到输出之前都是 int; 
- 动态量化的 float 输入是经动态计算的 scale 和 zp 量化为 int,op 输出时转换回 float。 

QAT(Quantization Aware Training)
cnet = CivilNet()
cnet.train()
cnet.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
#activation的observer的参数
FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,quant_min=0,quant_max=255,reduce_range=True)
#权重的observer的参数
FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                               quant_min=-128,
                                                               quant_max=127,
                                                               dtype=torch.qint8,
                                                               qscheme=torch.per_channel_symmetric,
                                                               reduce_range=False,
                                                               ch_axis=0)

- Xmin、Xmax 是当前运行中正在求解和最终求解的最小值、最大值; 
- X 是当前输入的 tensor; 
- c 是一个常数,PyTorch 中默认为 0.01,也就是最新一次的极值由上一次贡献 99%,当前的 tensor 贡献 1%。 
- prepare_qa t 要把 qconfig 安插到每个 op 上,qconfig 的内容本身就不同,参考五部曲中的第一步; 
- prepare_qat 中需要多做一步转换子 module 的工作,需要 inplace 的把模型中的一些子 module 替换了,替换的逻辑就是从 DEFAULT_QAT_MODULE_MAPPINGS 的 key 替换为 value,这个字典的定义如下: 
# Default map for swapping float module to qat modules
DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    nn.Conv2d: nnqat.Conv2d,
    nn.Linear: nnqat.Linear,
    # Intrinsic modules:
    nni.ConvBn1d: nniqat.ConvBn1d,
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU
}
CivilNet(
  (conv): QATConv2d(
    1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
  )
  (fc): QATLinear(
    in_features=3, out_features=2, bias=False
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
  )
  (relu): ReLU(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
  )
  (quant): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
  )
  (dequant): DeQuantStub()
)
def forward(self, input):
   return self.activation_post_process(self._conv_forward(input, self.weight_fake_quant(self.weight)))
def forward(self, input):
    return self.activation_post_process(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
#conv2d
weight=functools.partial(<class 'torch.quantization.fake_quantize.FakeQuantize'>, 
           observer=<class 'torch.quantization.observer.MovingAveragePerChannelMinMaxObserver'>, 
           quant_min=-128, quant_max=127, dtype=torch.qint8, 
           qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0))
activation=functools.partial(<class 'torch.quantization.fake_quantize.FakeQuantize'>, 
            observer=<class 'torch.quantization.observer.MovingAverageMinMaxObserver'>, 
            quant_min=0, quant_max=255, reduce_range=True)
def forward(self, X):
        if self.observer_enabled[0] == 1:
            #使用移动平均算法计算scale和zp
        if self.fake_quant_enabled[0] == 1:
            X = torch.fake_quantize_per_channel_or_tensor_affine(X...)
        return X
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = {
    ......
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
} 
# 原始的模型,所有的tensor和计算都是浮点
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                      /
    linear_weight_fp32
# 训练过程中,fake_quants发挥作用
previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
                           /
   linear_weight_fp32 -- fq
# 量化后的模型进行推理,权重和输入都是int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                     /
   linear_weight_int8

总结
https://github.com/DeepVAC/deepvac
参考文献
个人微信(如果没有备注不拉群!) 请注明:地区+学校/企业+研究方向+昵称 
下载1:何恺明顶会分享
在「AI算法与图像处理」公众号后台回复:何恺明,即可下载。总共有6份PDF,涉及 ResNet、Mask RCNN等经典工作的总结分析
下载2:终身受益的编程指南:Google编程风格指南
在「AI算法与图像处理」公众号后台回复:c++,即可下载。历经十年考验,最权威的编程规范!
下载3 CVPR2020 在「AI算法与图像处理」公众号后台回复:CVPR2020,即可下载1467篇CVPR 2020论文 
觉得不错就点亮在看吧


