TVM Quantization

Table of Contents

1. WAIT TVM Quantization

file:///home/sunway/Gitbox/code/hello_world/hello_tvm/graph_runner/run_model.py

TVM supports quantization in two ways:

  1. quantizing a floating-point model via relay.quantize
  2. running a model that is already quantized, via a Relay dialect called qnn (relay.qnn.xxx)

1.1. relay.quantize

https://tvm.apache.org/docs/tutorials/frontend/deploy_quantized.html

quantize proceeds in three steps:

  1. annotate
  2. calibrate
  3. realize

The pass pipeline that implements these steps looks like this:
calibrate_pass = tvm.transform.module_pass(
    calibrate(dataset), opt_level=1, name="QuantizeCalibrate"
)
quant_passes = [
    partition(),
    annotate(),
    calibrate_pass,
    tvm.relay.transform.InferType(),
    realize(),
]

quantize_seq = tvm.transform.Sequential(quant_passes)
with tvm.transform.PassContext(
    opt_level=3,
    required_pass=["QuantizeAnnotate", "QuantizeCalibrate", "QuantizeRealize"],
):
    with quantize_context():
        mod = quantize_seq(mod)
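
In practice this pipeline is usually driven through relay.quantize.quantize under a qconfig, as in the deploy_quantized tutorial linked above. A minimal sketch, assuming mod, params and a calibrate_dataset() generator already exist:

from tvm import relay

# calibrate_mode / weight_scale select the branches taken in calibrate() below
with relay.quantize.qconfig(calibrate_mode="kl_divergence", weight_scale="max"):
    qmod = relay.quantize.quantize(mod, params, dataset=calibrate_dataset())
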
1.1.0.1. calibrate
def calibrate(dataset=None):
    def wrapped_func(mod, _):
        cfg = quantize.current_qconfig()

        if cfg.calibrate_mode == "kl_divergence":
            input_scale_func = _kl_scale(mod, dataset)
        elif cfg.calibrate_mode == "global_scale":
            input_scale_func = _global_scale
        elif cfg.calibrate_mode == "percentile":
            input_scale_func = _percentile_scale(mod, dataset)

        if cfg.weight_scale == "max":
            weight_scale_func = _max_scale
        elif cfg.weight_scale == "power2":
            weight_scale_func = _power2_scale

        return _set_params(mod, input_scale_func, weight_scale_func)

    return wrapped_func


_set_params:
    # the default nbit values are defined in ~/source/tvm/python/tvm/relay/quantize/quantize.py, e.g.
    #      "nbit_input": 8,
    #      "nbit_weight": 8,
    #      "nbit_activation": 32,
    nbit = cfg.get_nbit_by_kind(kind)
    valid_bit = nbit - attrs.sign
    if kind == quantize.QAnnotateKind.WEIGHT:
        scale = weight_scale_func(expr)
    else:
        scale = input_scale_func(expr)

    # this shows that TVM's quantization is symmetric: in principle the scale computed
    # above should equal np.max(np.abs(orig_data))

    valid_range = 2 ** valid_bit
    const_params[ndom_scale] = _make_const(scale / valid_range)
    const_params[nclip_min] = _make_const(-(valid_range - 1))
    const_params[nclip_max] = _make_const((valid_range - 1))
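
A concrete illustration (toy numbers of my own, not taken from TVM) of how these constants behave for an 8-bit signed weight tensor:

import numpy as np

# hypothetical weight tensor, chosen only for illustration
weights = np.array([0.5, -0.25, 0.1, -0.5])

nbit, sign = 8, 1
valid_bit = nbit - sign          # 7
valid_range = 2 ** valid_bit     # 128

scale = np.max(np.abs(weights))  # weight_scale == "max" -> 0.5
dom_scale = scale / valid_range  # 0.00390625
clip_min, clip_max = -(valid_range - 1), valid_range - 1  # [-127, 127]

q = np.clip(np.round(weights / dom_scale), clip_min, clip_max)
print(q)  # [ 127.  -64.   26. -127.]
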
1.1.0.1.1. percentile
def _percentile_scale(mod, dataset):
    cfg = quantize.current_qconfig()
    chunk_by = cfg.calibrate_chunk_by
    scales = []
    for samples in collect_stats(mod, dataset, chunk_by):
        logging.info("finding threshold with percentile for calibration...")
        with mp.Pool() as pool:
            scales += list(pool.map(_find_scale_by_percentile, samples))

    def func(_):
        # func returns the scale for each layer (effectively the layer's `max value`;
        # the final scale is obtained later as scale/valid_range)
        scale = scales[func.scale_idx]
        func.scale_idx += 1
        return scale

    func.scale_idx = 0

    return func


# use the calibration dataset to collect the output of every layer
def collect_stats(mod, dataset, chunk_by=-1):
    runtime = _get_profile_runtime(mod)
    num_outputs = runtime.get_num_outputs()
    # chunk_by == -1 means "all outputs in a single chunk"
    chunk_by = num_outputs if chunk_by == -1 else chunk_by

    for i in range(0, num_outputs, chunk_by):
        outputs = [[] for i in range(min(chunk_by, num_outputs - i))]
        for batch in dataset:
            runtime.set_input(**batch)
            runtime.run()
            for j in range(i, min(i + chunk_by, num_outputs)):
                outputs[j - i].append(runtime.get_output(j).numpy())
        yield [np.concatenate(output).reshape(-1) for output in outputs]


# this function is essentially get_k_largest with k = percentile * arr.size; it keeps a few
# extremely large values from blowing up the quantization error
# file:~/Gitbox/code/leetcode/go/src/kth_largest_element_in_an_array/kth_largest_element_in_an_array.go
def _find_scale_by_percentile(arr, percentile=0.99999):
    assert isinstance(arr, np.ndarray)
    x = np.abs(arr)
    max_k = int(x.size * percentile)
    return np.partition(x, max_k)[max_k]
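
A quick sanity check (toy data of my own) that this picks roughly the percentile-th largest absolute value and is therefore insensitive to a handful of extreme outliers:

import numpy as np

rng = np.random.default_rng(0)
arr = np.concatenate([rng.uniform(-1, 1, 1_000_000), np.full(10, 1000.0)])

print(np.max(np.abs(arr)))             # 1000.0 -- a plain max would use this as the scale
print(_find_scale_by_percentile(arr))  # ~1.0  -- the 10 outliers sit above the 0.99999 percentile
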
1.1.0.1.2. kl_divergence

KL divergence is defined as \(KL(p,q)=\sum_i p_i \log\frac{p_i}{q_i}\)

percentile uses get_k_largest to pick the saturation threshold.

kl_divergence instead picks the saturation threshold that minimizes the KL divergence between the original distribution and its quantized approximation.

https://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf

Input: abs(FP32) histogram H with 2048 bins: bin[0], …, bin[2047]

for i in range(128, 2048):
    P = bin[:i]
    P[-1] += sum(bin[i:])

    Q = split bin[:i] into 128 levels and expand back to `i` bins

    divergence[i] = KL_divergence(P / sum(P), Q / sum(Q))

The toy script below illustrates the "average within a chunk, then expand back" step and the resulting KL divergence:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-23 22:25
import numpy as np


def kl_divergence(p, q):
    return np.sum(p * np.log(p / q))


def check(P, target_bins):
    # quantize: average P within each chunk, then expand back to P's original length
    Q = np.array([np.average(x) for x in np.split(P, target_bins)])
    Q = np.repeat(Q, len(P) // target_bins)

    return kl_divergence(P / np.sum(P), Q / np.sum(Q))


if __name__ == "__main__":
    print(check(np.array([1, 2, 3, 4, 5, 6, 7, 8]), 2))
    print(check(np.array([1, 3, 5, 7, 2, 4, 6, 8]), 2))
    print(check(np.array([2, 2, 2, 2, 6, 6, 6, 6]), 2))

0.04033873632737679 0.13645806664911772 0.0

The third input is constant within each chunk, so the averaged-and-expanded Q reproduces P exactly and the divergence is 0; the second input mixes small and large values inside each chunk and therefore diverges the most.

1.1.0.1.3. global_scale

global_scale simply assumes that every layer's `max value` is the same fixed number, e.g. 8 or 16, so it needs no calibration dataset.

def _global_scale(sq_call):
    cfg = quantize.current_qconfig()
    return cfg.global_scale
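
This is the default calibration mode in the deploy_quantized tutorial; a minimal usage sketch, assuming mod and params exist:

from tvm import relay

# no calibration dataset needed: every layer's "max value" is assumed to be 8.0
with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0):
    qmod = relay.quantize.quantize(mod, params)
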
1.1.0.2. annotate
1.1.0.3. realize

1.2. relay.qnn

https://tvm.apache.org/docs/tutorials/frontend/deploy_prequantized.html#

For models that are already quantized, TVM provides support through relay.qnn. relay.qnn is not a brand-new Relay IR: its ops are rewritten into plain Relay ops via TVM's Canonicalize mechanism before the regular Relay passes run.

For example, a qnn.mul is expanded by the QnnMulCanonicalize function into several plain Relay IR operations that convert the frontend's quantization scheme into TVM's. QnnMulCanonicalize is invoked through relay::qnn::transform::Legalize.

1.2.1. QnnMulCanonicalize

1.2.1.1. Example

tvm/tests/python/relay/test_op_qnn_mul.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-22 16:20

import tvm
from tvm import te
import numpy as np
from tvm import relay

x_data = np.array((1, 153, 2, 178)).reshape((1, 4))
y_data = np.array((204, 178, 1, 8)).reshape((1, 4))

x_scale = y_scale = z_scale = 0.00784314
x_zero_point = y_zero_point = z_zero_point = 127


def dequant(data, scale, zp):
    return scale * (np.asarray(data) - zp)


def quant(data, scale, zp):
    z = np.around(data / scale + zp)
    q_min = np.iinfo(np.uint8).min
    q_max = np.iinfo(np.uint8).max
    return np.clip(z, q_min, q_max)


def mul_manually():
    return quant(
        dequant(x_data, x_scale, x_zero_point) * dequant(y_data, y_scale, y_zero_point),
        z_scale,
        z_zero_point,
    )


if __name__ == "__main__":
    x = relay.var("x", shape=(1, 4), dtype="uint8")
    y = relay.var("y", shape=(1, 4), dtype="uint8")
    z = relay.qnn.op.mul(
        lhs=x,
        rhs=y,
        lhs_scale=relay.const(x_scale, "float32"),
        lhs_zero_point=relay.const(x_zero_point, "int32"),
        rhs_scale=relay.const(y_scale, "float32"),
        rhs_zero_point=relay.const(y_zero_point, "int32"),
        output_scale=relay.const(z_scale, "float32"),
        output_zero_point=relay.const(z_zero_point, "int32"),
    )

    func = relay.Function([x, y], z)
    mod = tvm.IRModule.from_expr(func)
    print("----------before qnn transform----------")
    print(mod)
    mod = relay.transform.InferType()(mod)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    print("----------after qnn transform----------")
    print(mod)
    func = mod["main"]

    intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm")
    op_res = intrp.evaluate(func)(x_data, y_data)

    print("----------")
    print(op_res.numpy())
    print("----------")
    golden = mul_manually()
    print(golden.astype("uint8"))

----------before qnn transform----------
def @main(%x: Tensor[(1, 4), uint8], %y: Tensor[(1, 4), uint8]) {
  qnn.mul(%x, %y, 0.00784314f, 127, 0.00784314f, 127, 0.00784314f, 127)
}

----------after qnn transform----------
def @main(%x: Tensor[(1, 4), uint8], %y: Tensor[(1, 4), uint8]) -> Tensor[(1, 4), uint8] {
  %0 = cast(%x, dtype="int32") /* ty=Tensor[(1, 4), int32] */;
  %1 = cast(%y, dtype="int32") /* ty=Tensor[(1, 4), int32] */;
  %2 = subtract(%0, 127 /* ty=int32 */) /* ty=Tensor[(1, 4), int32] */;
  %3 = subtract(%1, 127 /* ty=int32 */) /* ty=Tensor[(1, 4), int32] */;
  %4 = multiply(%2, %3) /* ty=Tensor[(1, 4), int32] */;
  %5 = cast(%4, dtype="int32") /* ty=Tensor[(1, 4), int32] */;
  %6 = cast(127 /* ty=int32 */, dtype="int32") /* ty=int32 */;
  %7 = fixed_point_multiply(%5, multiplier=1077952893, shift=-6) /* ty=Tensor[(1, 4), int32] */;
  %8 = add(%6, %7) /* ty=Tensor[(1, 4), int32] */;
  %9 = clip(%8, a_min=0f, a_max=255f) /* ty=Tensor[(1, 4), int32] */;
  cast(%9, dtype="uint8") /* ty=Tensor[(1, 4), uint8] */
}



1.2.1.2. How is `QnnMulCanonicalize` invoked
QNN_REGISTER_BINARY_OP("mul")
    .describe("Elementwise mul with with broadcasting for quantized tensors.")
    .set_support_level(11)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);

Pass Legalize() {
  Array<Pass> pass_seqs;
  pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize"));
  pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize"));
  relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
  return seq;
}

Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is_vm) {
  Array<Pass> pass_seqs;
  Array<runtime::String> entry_functions{"main"};
  // ..
  // Run all dialect legalization passes.
  pass_seqs.push_back(relay::qnn::transform::Legalize());
  // ...
}

1.2.1.3. What `QnnMulCanonicalize` does

relay.qnn.mul cannot be converted directly into relay.mul, because relay.mul only supports symmetric quantization, while the qnn inputs carry explicit zero points.

Taking the Example code above, what QnnMulCanonicalize does is essentially compute \(Z_q\) with the formula below:

\(Z_q=\frac{X_s Y_s}{Z_s}(X_q-127)(Y_q-127)+127\)

def @main(%x: Tensor[(1, 4), uint8], %y: Tensor[(1, 4), uint8]) -> Tensor[(1, 4), uint8] {
  //...
  %2 = subtract(%0, 127 /* ty=int32 */) /* ty=Tensor[(1, 4), int32] */;
  %3 = subtract(%1, 127 /* ty=int32 */) /* ty=Tensor[(1, 4), int32] */;
  %4 = multiply(%2, %3) /* ty=Tensor[(1, 4), int32] */;
  %5 = cast(%4, dtype="int32") /* ty=Tensor[(1, 4), int32] */;
  %6 = cast(127 /* ty=int32 */, dtype="int32") /* ty=int32 */;
  %7 = fixed_point_multiply(%5, multiplier=1077952893, shift=-6) /* ty=Tensor[(1, 4), int32] */;
  %8 = add(%6, %7) /* ty=Tensor[(1, 4), int32] */;
  %9 = clip(%8, a_min=0f, a_max=255f) /* ty=Tensor[(1, 4), int32] */;
  //...
}
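
The constants in %7 encode the floating-point requantization factor \(\frac{X_s Y_s}{Z_s}\) as a Q31 fixed-point multiplier plus a shift. A quick sanity check, written as a standalone sketch (it ignores fixed_point_multiply's exact rounding behaviour):

x_scale = y_scale = z_scale = 0.00784314
expected = x_scale * y_scale / z_scale      # 0.00784314

multiplier, shift = 1077952893, -6          # from the fixed_point_multiply call above
approx = multiplier / 2 ** 31 * 2 ** shift  # ~0.0078431...

print(expected, approx)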

1.3. WAIT Quantization and BYOC

1.3.1. cmsisnn

TVM's cmsisnn backend relies on qnn together with CMSIS-NN's xxx_s8 family of APIs to support quantized models.

Backlinks

Quantization (Quantization > Overview): some frameworks such as TVM adopt more elaborate quantization schemes, because picking (q_min, Q_min), (q_max, Q_max) to define the linear mapping is not the only option; in theory the best quantization minimizes the quantization error, i.e. minimize_divergency(x, DQ(Q(x)))

Quantization (Quantization > TVM Quantization): TVM Quantization

Author: [email protected]
Date: 2021-08-05 Thu 00:00
Last updated: 2023-12-01 Fri 18:28

Creative Commons License