TVM VTA
1. TVM VTA
1.1. Overview
import os
import tvm
from tvm import te
import vta
import numpy as np
env = vta.get_env()
from tvm import rpc
from tvm.contrib import utils
from vta.testing import simulator
remote = rpc.LocalSession()
# Output channel factor m - total 64 x 16 = 1024 output channels
m = 64
# Batch factor o - total 1 x 1 = 1
o = 1
A = te.placeholder((o, m, env.BATCH, env.BLOCK_OUT), name="A", dtype=env.acc_dtype)
B = te.placeholder((o, m, env.BATCH, env.BLOCK_OUT), name="B", dtype=env.acc_dtype)
# A copy buffer
A_buf = te.compute((o, m, env.BATCH, env.BLOCK_OUT), lambda *i: A(*i), "A_buf")
# B copy buffer
B_buf = te.compute((o, m, env.BATCH, env.BLOCK_OUT), lambda *i: B(*i), "B_buf")
C_buf = te.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: A_buf(*i).astype(env.acc_dtype) + B_buf(*i).astype(env.acc_dtype),
name="C_buf",
)
C = te.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: C_buf(*i).astype(env.inp_dtype),
name="C",
)
s = te.create_schedule(C.op)
# Place the intermediate buffers in VTA's on-chip accumulator SRAM
s[A_buf].set_scope(env.acc_scope)
s[B_buf].set_scope(env.acc_scope)
s[C_buf].set_scope(env.acc_scope)
# Tag the copy loops as DMA transfers (load A/B, store C)
s[A_buf].pragma(s[A_buf].op.axis[0], env.dma_copy)
s[B_buf].pragma(s[B_buf].op.axis[0], env.dma_copy)
s[C].pragma(s[C].op.axis[0], env.dma_copy)
# Tag the addition so it is mapped onto VTA's vector ALU
s[C_buf].pragma(C_buf.op.axis[0], env.alu)
print(vta.lower(s, [A, B, C], simple_mode=True))
my_vadd = vta.build(s, [A, B, C], "ext_dev", "llvm", name="my_vadd")
my_vadd.save("/tmp/my_vadd.o")
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {C: Buffer(C_2: Pointer(int8), int8, [1, 64, 1, 16], []),
B: Buffer(B_2: Pointer(int32), int32, [1, 64, 1, 16], []),
A: Buffer(A_2: Pointer(int32), int32, [1, 64, 1, 16], [])}
buffer_map = {A_1: A, B_1: B, C_1: C} {
attr [IterVar(vta: int32, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 2 {
@tir.call_extern("VTALoadBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), A_2, 0, 64, 1, 64, 0, 0, 0, 0, 0, 3, dtype=int32)
@tir.call_extern("VTALoadBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), B_2, 0, 64, 1, 64, 0, 0, 0, 0, 64, 3, dtype=int32)
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushALUOp" {
@tir.call_extern("VTAUopLoopBegin", 64, 1, 1, 0, dtype=int32)
@tir.vta.uop_push(1, 0, 0, 64, 0, 2, 0, 0, dtype=int32)
@tir.call_extern("VTAUopLoopEnd", dtype=int32)
}
@tir.vta.coproc_dep_push(2, 3, dtype=int32)
}
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 3 {
@tir.vta.coproc_dep_pop(2, 3, dtype=int32)
@tir.call_extern("VTAStoreBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), 0, 4, C_2, 0, 64, 1, 64, dtype=int32)
}
@tir.vta.coproc_sync(, dtype=int32)
}
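Reading the lowered TIR above (the field meanings below are inferred from VTA's instruction encoding and from the uop_push call built in InjectALUIntrin later in this note, so treat them as a best-effort interpretation):

# VTALoadBuffer2D(..., A_2, 0, 64, 1, 64, ..., 0, 3)   -> load 64 vectors of A into SRAM index 0, memory type 3 (ACC)
# VTALoadBuffer2D(..., B_2, 0, 64, 1, 64, ..., 64, 3)  -> load 64 vectors of B into SRAM index 64, memory type 3 (ACC)
# tir.vta.uop_push(1, 0, 0, 64, 0, 2, 0, 0)
#   mode=1 (ALU), reset=0, dst_index=0, src_index=64, wgt_index=0,
#   alu_opcode=2 (ADD), use_imm=0, imm=0
# VTAStoreBuffer2D(..., 0, 4, C_2, 0, 64, 1, 64)       -> store 64 result vectors from SRAM index 0, memory type 4 (OUT), back to C in DRAM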
objdump -dr /tmp/my_vadd.o |grep -A200 my_vadd_compute_\>:
0000000000000560 <my_vadd_compute_>:
560: 41 57 push %r15
562: 41 56 push %r14
564: 53 push %rbx
565: 49 89 ce mov %rcx,%r14
568: 49 89 d7 mov %rdx,%r15
56b: 48 89 fb mov %rdi,%rbx
56e: ba 00 00 00 00 mov $0x0,%edx
573: b9 40 00 00 00 mov $0x40,%ecx
578: 41 b8 01 00 00 00 mov $0x1,%r8d
57e: 41 b9 40 00 00 00 mov $0x40,%r9d
584: 6a 03 pushq $0x3
586: 6a 00 pushq $0x0
588: 6a 00 pushq $0x0
58a: 6a 00 pushq $0x0
58c: 6a 00 pushq $0x0
58e: 6a 00 pushq $0x0
590: e8 00 00 00 00 callq 595 <my_vadd_compute_+0x35>
591: R_X86_64_PLT32 VTALoadBuffer2D-0x4
595: 48 83 c4 30 add $0x30,%rsp
599: 48 89 df mov %rbx,%rdi
59c: 4c 89 fe mov %r15,%rsi
59f: 31 d2 xor %edx,%edx
5a1: b9 40 00 00 00 mov $0x40,%ecx
5a6: 41 b8 01 00 00 00 mov $0x1,%r8d
5ac: 41 b9 40 00 00 00 mov $0x40,%r9d
5b2: 6a 03 pushq $0x3
5b4: 6a 40 pushq $0x40
5b6: 6a 00 pushq $0x0
5b8: 6a 00 pushq $0x0
5ba: 6a 00 pushq $0x0
5bc: 6a 00 pushq $0x0
5be: e8 00 00 00 00 callq 5c3 <my_vadd_compute_+0x63>
5bf: R_X86_64_PLT32 VTALoadBuffer2D-0x4
5c3: 48 83 c4 30 add $0x30,%rsp
5c7: 48 8d 3d 00 00 00 00 lea 0x0(%rip),%rdi # 5ce <my_vadd_compute_+0x6e>
5ca: R_X86_64_PC32 .bss+0x24
5ce: 48 8d 35 6b 00 00 00 lea 0x6b(%rip),%rsi # 640 <my_vadd_compute_+0xe0>
5d5: 31 d2 xor %edx,%edx
5d7: 31 c9 xor %ecx,%ecx
5d9: e8 00 00 00 00 callq 5de <my_vadd_compute_+0x7e>
5da: R_X86_64_PLT32 VTAPushALUOp-0x4
5de: 85 c0 test %eax,%eax
5e0: 75 56 jne 638 <my_vadd_compute_+0xd8>
5e2: 48 89 df mov %rbx,%rdi
5e5: be 02 00 00 00 mov $0x2,%esi
5ea: ba 03 00 00 00 mov $0x3,%edx
5ef: e8 00 00 00 00 callq 5f4 <my_vadd_compute_+0x94>
5f0: R_X86_64_PLT32 VTADepPush-0x4
5f4: 48 89 df mov %rbx,%rdi
5f7: be 02 00 00 00 mov $0x2,%esi
5fc: ba 03 00 00 00 mov $0x3,%edx
601: e8 00 00 00 00 callq 606 <my_vadd_compute_+0xa6>
602: R_X86_64_PLT32 VTADepPop-0x4
606: 48 89 df mov %rbx,%rdi
609: 31 f6 xor %esi,%esi
60b: ba 04 00 00 00 mov $0x4,%edx
610: 4c 89 f1 mov %r14,%rcx
613: 45 31 c0 xor %r8d,%r8d
616: 41 b9 40 00 00 00 mov $0x40,%r9d
61c: 6a 40 pushq $0x40
61e: 6a 01 pushq $0x1
620: e8 00 00 00 00 callq 625 <my_vadd_compute_+0xc5>
621: R_X86_64_PLT32 VTAStoreBuffer2D-0x4
625: 48 83 c4 10 add $0x10,%rsp
629: 48 89 df mov %rbx,%rdi
62c: be 00 00 00 80 mov $0x80000000,%esi
631: e8 00 00 00 00 callq 636 <my_vadd_compute_+0xd6>
632: R_X86_64_PLT32 VTASynchronize-0x4
636: 31 c0 xor %eax,%eax
638: 5b pop %rbx
639: 41 5e pop %r14
63b: 41 5f pop %r15
63d: c3 retq
63e: 90 nop
63f: 90 nop
640: 50 push %rax
641: bf 40 00 00 00 mov $0x40,%edi
646: be 01 00 00 00 mov $0x1,%esi
64b: ba 01 00 00 00 mov $0x1,%edx
650: 31 c9 xor %ecx,%ecx
652: e8 00 00 00 00 callq 657 <my_vadd_compute_+0xf7>
653: R_X86_64_PLT32 VTAUopLoopBegin-0x4
657: bf 01 00 00 00 mov $0x1,%edi
65c: 31 f6 xor %esi,%esi
65e: 31 d2 xor %edx,%edx
660: b9 40 00 00 00 mov $0x40,%ecx
665: 45 31 c0 xor %r8d,%r8d
668: 41 b9 02 00 00 00 mov $0x2,%r9d
66e: 6a 00 pushq $0x0
670: 6a 00 pushq $0x0
672: e8 00 00 00 00 callq 677 <my_vadd_compute_+0x117>
673: R_X86_64_PLT32 VTAUopPush-0x4
677: 48 83 c4 10 add $0x10,%rsp
67b: e8 00 00 00 00 callq 680 <my_vadd_compute_+0x120>
67c: R_X86_64_PLT32 VTAUopLoopEnd-0x4
680: 31 c0 xor %eax,%eax
682: 59 pop %rcx
683: c3 retq
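To actually execute the module, it can be uploaded through the RPC session created above and run on the simulator. The following is a minimal sketch that follows the VTA get-started flow; the temp-dir upload and the file name vadd.o are assumptions, not something produced in the code above:

# Sketch: upload the saved module via the LocalSession and run it on the simulator
temp = utils.tempdir()
my_vadd.save(temp.relpath("vadd.o"))
remote.upload(temp.relpath("vadd.o"))
f = remote.load_module("vadd.o")

ctx = remote.ext_dev(0)
a_np = np.random.randint(-128, 128, size=(o, m, env.BATCH, env.BLOCK_OUT)).astype(A.dtype)
b_np = np.random.randint(-128, 128, size=(o, m, env.BATCH, env.BLOCK_OUT)).astype(B.dtype)
a_nd = tvm.nd.array(a_np, ctx)
b_nd = tvm.nd.array(b_np, ctx)
c_nd = tvm.nd.array(np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(C.dtype), ctx)

f(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.asnumpy(), (a_np + b_np).astype(C.dtype))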
1.2. vta.build
1.2.1. tir.add_lower_pass
vta.build is a thin wrapper around tvm.build that registers several custom passes via tir.add_lower_pass. Driven by the pragmas attached to TIR statements, these passes rewrite the TIR into calls to the VTA runtime. Passes added through tir.add_lower_pass are invoked during LowerWithPassList.
def build(*args, **kwargs):
    pass_ctx = tvm.transform.PassContext.current()
    with build_config():
        return tvm.build(*args, **kwargs)


def build_config(debug_flag=0, **kwargs):
    """Build a build config for VTA.

    Example
    --------
    .. code-block:: python

      # build a vta module.
      with vta.build_config():
          vta_module = tvm.build(s, ...)
    """
    env = get_env()

    @tvm.tir.transform.prim_func_pass(opt_level=0)
    def add_debug(f, *_):
        debug = tvm.tir.call_extern(
            "int32", "VTASetDebugMode", env.dev.command_handle, debug_flag
        )
        return f.with_body(tvm.tir.stmt_seq(debug, f.body))

    pass_list = [
        (0, transform.InjectConv2DTransposeSkip()),
        (1, transform.InjectDMAIntrin()),
        (1, transform.InjectSkipCopy()),
        (1, transform.AnnotateALUCoProcScope()),
        (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
        (1, transform.LiftAllocToScopeBegin()),
        (1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
        (1, transform.InjectCoProcSync()),
        (1, EarlyRewrite()),
    ]
    pass_list.append((2, transform.InjectALUIntrin()))
    pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
    pass_list.append((3, transform.FoldUopLoop()))
    pass_list.append((3, transform.CPUAccessRewrite()))
    config = {"tir.add_lower_pass": pass_list}
    return tvm.transform.PassContext(config=config, **kwargs)
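As a standalone illustration of the tir.add_lower_pass mechanism itself (not VTA-specific; the pass and tensor names below are made up for this example), a custom prim_func_pass registered at phase 1 is applied automatically during lowering:

import tvm
from tvm import te

# A trivial custom pass that only prints the PrimFunc it receives.
@tvm.tir.transform.prim_func_pass(opt_level=0)
def print_pass(f, mod, ctx):
    print(f)
    return f

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# Phase 1 is the same slot VTA uses for passes such as InjectDMAIntrin.
with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, print_pass)]}):
    mod = tvm.lower(s, [A, B])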
1.2.2. InjectALUIntrin
InjectALUIntrin replaces the TIR marked by the `alu` pragma with calls into the VTA runtime, such as VTAUopPush:
def InjectALUIntrin():
    def _ftransform(func, mod, ctx):
        env = get_env()
        idxm = tvm.tir.indexmod
        analyzer = tvm.arith.Analyzer()

        def _do_fold(stmt):
            # ...
            # the `alu` pragma
            if _match_pragma(stmt, "alu"):
                # Get to the innermost loop body
                loop_body = stmt.body
                nest_size = 0
                while isinstance(loop_body, tvm.tir.For):
                    loop_body = loop_body.body
                    nest_size += 1
                # ...
                if isinstance(loop_body.value, tvm.tir.Add):
                    alu_opcode = env.dev.ALU_OPCODE_ADD
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Sub):
                    alu_opcode = env.dev.ALU_OPCODE_SUB
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Mul):
                    alu_opcode = env.dev.ALU_OPCODE_MUL
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Min):
                    alu_opcode = env.dev.ALU_OPCODE_MIN
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Max):
                    alu_opcode = env.dev.ALU_OPCODE_MAX
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                # ...
                # Insert ALU micro-ops
                irb = tvm.tir.ir_builder.create()
                for idx, extent in enumerate(extents):
                    irb.emit(
                        tvm.tir.call_extern(
                            "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx], 0
                        )
                    )
                use_imm = int(use_imm)
                irb.emit(
                    tvm.tir.call_intrin(
                        "int32",
                        "tir.vta.uop_push",
                        1,
                        0,
                        dst_coeff[len(dst_coeff) - 1],
                        src_coeff[len(src_coeff) - 1],
                        0,
                        alu_opcode,
                        use_imm,
                        imm_val,
                    )
                )
                for extent in extents:
                    irb.emit(tvm.tir.call_extern("int32", "VTAUopLoopEnd"))
                return irb.get()
            return stmt

        return func.with_body(
            tvm.tir.stmt_functor.ir_transform(func.body, None, _do_fold, ["tir.AttrStmt"])
        )

    return tvm.tir.transform.prim_func_pass(
        _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin"
    )
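The helper _match_pragma used above is defined elsewhere in vta/transform.py; roughly, it just checks whether an AttrStmt carries the given pragma key (a sketch, not copied verbatim):

def _match_pragma(stmt, key):
    # An AttrStmt produced by s[x].pragma(axis, key) carries either
    # "pragma_<key>" directly, or a generic "pragma_scope" whose value is key.
    return (stmt.attr_key == "pragma_" + key) or (
        stmt.attr_key == "pragma_scope" and stmt.value.value == key
    )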
1.2.3. Heterogeneous Execution
Compiling Relay into a mix of VTA and CPU operators is how CPU/VTA heterogeneous execution is achieved.
relay.build supports VTA as follows:
- `tvm/vta/python/vta/top` contains the TOPI implementations for VTA. They define VTA-specific compute and schedule functions for the `vta` device, and the schedules use pragmas, just like the TE example above, to mark the operations VTA supports (alu, dma_copy, GEMM, etc.); see tvm/vta/python/vta/top/op.py::def schedule_alu_packed(cfg, outs).
- When calling relay.build, the user has to wrap it in vta.build_config so that VTA's own tir.add_lower_pass passes are registered to handle those pragmas:
# env.target_host is "llvm"
# env.target is "ext_dev -keys=vta,cpu -device=vta -model=sim_1x16_i8w8a32_15_15_18_17"
with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
    graph, lib, params = relay.build(
        relay_prog, target=env.target, params=params, target_host=env.target_host
    )
1.2.3.1. graph_pack
Since VTA supports only a handful of operations (integer mul, add, ..., plus dense and conv2d), relay.build has to decide which parts of the Relay IR are executed by VTA; anything VTA does not support must run on the CPU.
The graph_pack function provides two features:
- annotation: it rewrites the Relay IR with relay.annotation.on_device to mark which parts should be executed by VTA (the on_device annotations are eventually handled by relay.build)
- packing: it reshapes the data according to the VTA configuration so that tensorize can be applied in the schedules later on
However, graph_pack still has some problems:
- graph_pack requires shape[1] to be a multiple of cfactor; shape[1] is the C of NCHW for conv2d and the output size for dense. This restriction is severe and directly constrains the network's parameters.
- It assumes conv2d is always in NCHW layout, so TFLite conv2d (NHWC) cannot be executed on VTA.
- Which IR gets processed is controlled by the (start_name, stop_name) and (annot_start_name, annot_end_name) parameters, which are awkward to use.
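To make the first limitation concrete, here is a hedged illustration using the default simulator config (BLOCK_OUT = 16) and the channel counts of the ResNet-18 used in the example below. 3 (RGB input of the first conv) and 1000 (the classifier output) are not multiples of 16, which is presumably why the example starts packing at nn.max_pool2d, stops at nn.global_avg_pool2d, and skips quantizing conv layer 0:

import vta

env = vta.get_env()  # default sim config: BLOCK_OUT = 16
for c in (3, 64, 512, 1000):
    print(c, "packable" if c % env.BLOCK_OUT == 0 else "not packable")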
1.2.3.2. Example
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-08 18:19
from mxnet.gluon.model_zoo import vision
import numpy as np

import tvm
from tvm import te
from tvm import rpc, autotvm, relay
from tvm.contrib import graph_executor, utils, download, graph_runtime
from tvm.contrib.debugger import debug_executor
from tvm.relay import transform

import vta
from vta.top import graph_pack

env = vta.get_env()
target = env.target

dtype_dict = {"data": "float32"}
shape_dict = {"data": (env.BATCH, 3, 224, 224)}

gluon_model = vision.get_model("resnet18_v1", pretrained=True)
mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
        mod = relay.quantize.quantize(mod, params=params)

    relay_prog = graph_pack(
        mod["main"],
        env.BATCH,
        env.BLOCK_OUT,
        env.WGT_WIDTH,
        start_name="nn.max_pool2d",
        stop_name="nn.global_avg_pool2d",
        annot_start_name="nn.conv2d",
        annot_end_name="annotation.stop_fusion",
        device_annot=(env.TARGET == "intelfocl" or env.TARGET == "sim"),
    )

print(relay_prog)

with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
    graph, lib, params = relay.build(
        relay_prog,
        target={"cpu": env.target_vta_cpu, "ext_dev": target},
        params=params,
        target_host=env.target_host,
    )
fn (%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 1000), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%3 = nn.max_pool2d(%2, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%4 = reshape(%3, newshape=[1, 1, 4, 16, 56, 56]) /* ty=Tensor[(1, 1, 4, 16, 56, 56), float32] */;
%5 = transpose(%4, axes=[0, 2, 4, 5, 1, 3]) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
%6 = annotation.stop_fusion(%5) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
%7 = multiply(%6, 16f /* ty=float32 */) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
%8 = round(%7) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
%9 = clip(%8, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
%10 = cast(%9, dtype="int32") /* ty=Tensor[(1, 4, 56, 56, 1, 16), int32] */;
...
add(%631, meta[relay.Constant][41] /* ty=Tensor[(1000), float32] */) /* ty=Tensor[(1, 1000), float32] */
}
In the relay_prog printed above:
Result of packing nn.conv2d:
data_layout becomes NCHW1n16c, where 16c means the 512 channels are split into (512/16) blocks of 16:
%566 = nn.conv2d(%564, %565, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3], data_layout="NCHW1n16c", kernel_layout="OIHW16o16i", out_dtype="int32") /* ty=Tensor[(1, 32, 7, 7, 1, 16), int32] */;
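In numpy terms, the data packing is just the reshape + transpose pair that appears as %4 and %5 in the printed program above; a small sketch for the (1, 64, 56, 56) max_pool2d output:

import numpy as np

x = np.random.rand(1, 64, 56, 56).astype("float32")  # NCHW, with BATCH=1, BLOCK_OUT=16
packed = x.reshape(1, 1, 64 // 16, 16, 56, 56).transpose(0, 2, 4, 5, 1, 3)
print(packed.shape)  # (1, 4, 56, 56, 1, 16) -> layout NCHW1n16c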
Result of the annotation:
nn.relu gets an on_device attribute, indicating that relay.build should use VTA as the device for this nn.relu.
In practice VTA's env.target is `ext_dev -keys=vta,cpu -device=vta -model=sim_1x16_i8w8a32_15_15_18_17`.
Because `keys=vta,cpu` is specified, relay.build falls back to the CPU when looking up strategies. So even though Relay IR that VTA does not support, such as nn.relu, was annotated with on_device(env.target), it is still compiled with the CPU strategy in the end (see tvm/python/tvm/target/generic_func.py::for k in target.keys:).
%617 = nn.relu(%616) /* ty=Tensor[(1, 32, 7, 7, 1, 16), int32] */;
%618 = on_device(%617, meta[relay.attrs.OnDeviceAttrs][299]) /* ty=Tensor[(1, 32, 7, 7, 1, 16), int32] */;
1.3. VTA Driver
vta.build adds custom passes via add_lower_pass that rewrite the TIR, driven by its pragma information, into calls to the VTA runtime. For the code in the Overview, the following functions are called in order:
- VTALoadBuffer2D
- VTALoadBuffer2D
- VTAUopLoopBegin
- VTAUopPush
- VTAUopLoopEnd
- VTAPushALUOp
- VTADepPush
- VTADepPop
- VTAStoreBuffer2D
- VTASynchronize
The functions above are in turn translated into VTA instructions and calls into the VTA driver. Take VTAPushALUOp as an example:
VTAPushALUOp first generates an instruction and appends it to insn_queue_:
void PushALUUop(UopKernel* kernel) {
  VTAAluInsn* insn = insn_queue_.CreateAluInsn();
  insn->opcode = VTA_OPCODE_ALU;
  insn->reset_reg = kernel->reset_out_;
  insn->uop_bgn = kernel->sram_begin_;
  insn->uop_end = kernel->sram_end_;
  insn->alu_opcode = kernel->opcode_;
  insn->use_imm = kernel->use_imm_;
  insn->imm = kernel->imm_val_;
  const std::vector<UopKernel::LoopEntry>& loop = kernel->loop();
  // ...
  insn->iter_out = loop[0].extent;
  insn->dst_factor_out = loop[0].dst_factor;
  insn->src_factor_out = loop[0].src_factor;
  insn->iter_in = loop[1].extent;
  insn->dst_factor_in = loop[1].dst_factor;
  insn->src_factor_in = loop[1].src_factor;
  // ...
}
VTASynchronize copies insn_queue_ from the host to the FPGA and then lets the FPGA execute it:
void Synchronize(uint32_t wait_cycles) {
  insn_queue_.AutoReadBarrier();
  // dram_buffer_ is a local std::vector
  // fpga_buff_ is FPGA memory allocated via VTAMemAlloc
  VTAMemCopyFromHost(fpga_buff_, dram_buffer_.data(), buff_size);
  VTADeviceRun(
      device_,
      insn_queue_.dram_phy_addr(),
      insn_queue_.count(),
      wait_cycles);
}
The VTA driver mainly provides:
- VTADeviceAlloc
- VTAMemAlloc / VTAMemFree
- VTAMemCopyFromHost / VTAMemCopyToHost
- VTADeviceRun
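The following rough sketch shows how these entry points fit together, following the call order seen in Synchronize above. VTAMemGetPhyAddr and VTADeviceFree come from vta/driver.h but are not discussed in this note, and the instruction buffer here is a placeholder, so treat this purely as an illustration of the call order:

#include <cstddef>
#include <cstdint>
#include <vta/driver.h>

// Allocate DMA-able memory, copy a prepared instruction stream from the
// host into it, then let the device execute it.
void RunInstructions(const void* insns, size_t insn_bytes, uint32_t insn_count) {
  VTADeviceHandle device = VTADeviceAlloc();
  void* buf = VTAMemAlloc(insn_bytes, /*cached=*/0);
  VTAMemCopyFromHost(buf, insns, insn_bytes);
  VTADeviceRun(device, VTAMemGetPhyAddr(buf), insn_count, /*wait_cycles=*/0x80000000);
  VTAMemFree(buf);
  VTADeviceFree(device);
}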
1.4. MISC
1.4.1. Differences between the vta and opencl targets
VTA cannot support every operation the way opencl does; its approach is called heterogeneous execution:
- At compile time the target of individual IR nodes has to be changed via on_device, whereas for opencl all IR uses the opencl target. At runtime the device_copy ops generated from on_device frequently copy data back and forth between the CPU and VTA, while opencl only needs to copy the input/output. (opencl also supports heterogeneous execution; see GraphExecutorCodegen.)
- vta.build implements the offload of alu, gemm, etc. by adding a few custom passes via tir.add_lower_pass on top of target.build.llvm, whereas opencl has its own target.build.opencl.
