TVM VTA

Table of Contents

1. TVM VTA

1.1. Overview

import os
import tvm
from tvm import te
import vta
import numpy as np

# Vector-add example: load two tensors into VTA, add them with the ALU,
# store the narrowed result, print the lowered TIR and save the object file.
env = vta.get_env()

from tvm import rpc
from tvm.contrib import utils
from vta.testing import simulator

# NOTE(review): `remote` (and the `os`, `utils`, `simulator` imports) are not
# used anywhere in this snippet.
remote = rpc.LocalSession()

# Output channel factor m - total 64 x 16 = 1024 output channels
m = 64
# Batch factor o - total 1 x 1 = 1
o = 1
# Inputs in the blocked (o, m, BATCH, BLOCK_OUT) shape, stored in the
# accumulator dtype (int32 in the lowered output printed below).
A = te.placeholder((o, m, env.BATCH, env.BLOCK_OUT), name="A", dtype=env.acc_dtype)
B = te.placeholder((o, m, env.BATCH, env.BLOCK_OUT), name="B", dtype=env.acc_dtype)

# A copy buffer (staged on-chip by the scope/pragma set in the schedule below)
A_buf = te.compute((o, m, env.BATCH, env.BLOCK_OUT), lambda *i: A(*i), "A_buf")
# B copy buffer
B_buf = te.compute((o, m, env.BATCH, env.BLOCK_OUT), lambda *i: B(*i), "B_buf")

# Element-wise add of the two staged buffers. A and B are already declared
# with env.acc_dtype, so these astype() casts are no-ops.
C_buf = te.compute(
    (o, m, env.BATCH, env.BLOCK_OUT),
    lambda *i: A_buf(*i).astype(env.acc_dtype) + B_buf(*i).astype(env.acc_dtype),
    name="C_buf",
)

# Narrow the result to env.inp_dtype (int8 in the lowered output below).
C = te.compute(
    (o, m, env.BATCH, env.BLOCK_OUT),
    lambda *i: C_buf(*i).astype(env.inp_dtype),
    name="C",
)

s = te.create_schedule(C.op)

# Place all intermediate stages in VTA's accumulator scope.
s[A_buf].set_scope(env.acc_scope)
s[B_buf].set_scope(env.acc_scope)
s[C_buf].set_scope(env.acc_scope)

# Mark the copy stages with the DMA pragma; in the lowered TIR below they
# become VTALoadBuffer2D (loads of A, B) and VTAStoreBuffer2D (store of C).
s[A_buf].pragma(s[A_buf].op.axis[0], env.dma_copy)
s[B_buf].pragma(s[B_buf].op.axis[0], env.dma_copy)
s[C].pragma(s[C].op.axis[0], env.dma_copy)

# Mark the add with the ALU pragma; it is lowered to VTA micro-op calls
# (VTAUopLoopBegin / uop_push / VTAUopLoopEnd inside "VTAPushALUOp").
s[C_buf].pragma(C_buf.op.axis[0], env.alu)

# Print the lowered TIR (reproduced below).
print(vta.lower(s, [A, B, C], simple_mode=True))

# Build with "ext_dev" as target and "llvm" as host target, then save the
# object file (disassembled below with objdump).
my_vadd = vta.build(s, [A, B, C], "ext_dev", "llvm", name="my_vadd")
my_vadd.save("/tmp/my_vadd.o")
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(int8), int8, [1, 64, 1, 16], []),
             B: Buffer(B_2: Pointer(int32), int32, [1, 64, 1, 16], []),
             A: Buffer(A_2: Pointer(int32), int32, [1, 64, 1, 16], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  attr [IterVar(vta: int32, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 2 {
    @tir.call_extern("VTALoadBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), A_2, 0, 64, 1, 64, 0, 0, 0, 0, 0, 3, dtype=int32)
    @tir.call_extern("VTALoadBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), B_2, 0, 64, 1, 64, 0, 0, 0, 0, 64, 3, dtype=int32)
    attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushALUOp" {
      @tir.call_extern("VTAUopLoopBegin", 64, 1, 1, 0, dtype=int32)
      @tir.vta.uop_push(1, 0, 0, 64, 0, 2, 0, 0, dtype=int32)
      @tir.call_extern("VTAUopLoopEnd", dtype=int32)
    }
    @tir.vta.coproc_dep_push(2, 3, dtype=int32)
  }
  attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 3 {
    @tir.vta.coproc_dep_pop(2, 3, dtype=int32)
    @tir.call_extern("VTAStoreBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), 0, 4, C_2, 0, 64, 1, 64, dtype=int32)
  }
  @tir.vta.coproc_sync(, dtype=int32)
}
objdump -dr /tmp/my_vadd.o |grep -A200 my_vadd_compute_\>:
0000000000000560 <my_vadd_compute_>:
 560:	41 57                	push   %r15
 562:	41 56                	push   %r14
 564:	53                   	push   %rbx
 565:	49 89 ce             	mov    %rcx,%r14
 568:	49 89 d7             	mov    %rdx,%r15
 56b:	48 89 fb             	mov    %rdi,%rbx
 56e:	ba 00 00 00 00       	mov    $0x0,%edx
 573:	b9 40 00 00 00       	mov    $0x40,%ecx
 578:	41 b8 01 00 00 00    	mov    $0x1,%r8d
 57e:	41 b9 40 00 00 00    	mov    $0x40,%r9d
 584:	6a 03                	pushq  $0x3
 586:	6a 00                	pushq  $0x0
 588:	6a 00                	pushq  $0x0
 58a:	6a 00                	pushq  $0x0
 58c:	6a 00                	pushq  $0x0
 58e:	6a 00                	pushq  $0x0
 590:	e8 00 00 00 00       	callq  595 <my_vadd_compute_+0x35>
            591: R_X86_64_PLT32	VTALoadBuffer2D-0x4
 595:	48 83 c4 30          	add    $0x30,%rsp
 599:	48 89 df             	mov    %rbx,%rdi
 59c:	4c 89 fe             	mov    %r15,%rsi
 59f:	31 d2                	xor    %edx,%edx
 5a1:	b9 40 00 00 00       	mov    $0x40,%ecx
 5a6:	41 b8 01 00 00 00    	mov    $0x1,%r8d
 5ac:	41 b9 40 00 00 00    	mov    $0x40,%r9d
 5b2:	6a 03                	pushq  $0x3
 5b4:	6a 40                	pushq  $0x40
 5b6:	6a 00                	pushq  $0x0
 5b8:	6a 00                	pushq  $0x0
 5ba:	6a 00                	pushq  $0x0
 5bc:	6a 00                	pushq  $0x0
 5be:	e8 00 00 00 00       	callq  5c3 <my_vadd_compute_+0x63>
            5bf: R_X86_64_PLT32	VTALoadBuffer2D-0x4
 5c3:	48 83 c4 30          	add    $0x30,%rsp
 5c7:	48 8d 3d 00 00 00 00 	lea    0x0(%rip),%rdi        # 5ce <my_vadd_compute_+0x6e>
            5ca: R_X86_64_PC32	.bss+0x24
 5ce:	48 8d 35 6b 00 00 00 	lea    0x6b(%rip),%rsi        # 640 <my_vadd_compute_+0xe0>
 5d5:	31 d2                	xor    %edx,%edx
 5d7:	31 c9                	xor    %ecx,%ecx
 5d9:	e8 00 00 00 00       	callq  5de <my_vadd_compute_+0x7e>
            5da: R_X86_64_PLT32	VTAPushALUOp-0x4
 5de:	85 c0                	test   %eax,%eax
 5e0:	75 56                	jne    638 <my_vadd_compute_+0xd8>
 5e2:	48 89 df             	mov    %rbx,%rdi
 5e5:	be 02 00 00 00       	mov    $0x2,%esi
 5ea:	ba 03 00 00 00       	mov    $0x3,%edx
 5ef:	e8 00 00 00 00       	callq  5f4 <my_vadd_compute_+0x94>
            5f0: R_X86_64_PLT32	VTADepPush-0x4
 5f4:	48 89 df             	mov    %rbx,%rdi
 5f7:	be 02 00 00 00       	mov    $0x2,%esi
 5fc:	ba 03 00 00 00       	mov    $0x3,%edx
 601:	e8 00 00 00 00       	callq  606 <my_vadd_compute_+0xa6>
            602: R_X86_64_PLT32	VTADepPop-0x4
 606:	48 89 df             	mov    %rbx,%rdi
 609:	31 f6                	xor    %esi,%esi
 60b:	ba 04 00 00 00       	mov    $0x4,%edx
 610:	4c 89 f1             	mov    %r14,%rcx
 613:	45 31 c0             	xor    %r8d,%r8d
 616:	41 b9 40 00 00 00    	mov    $0x40,%r9d
 61c:	6a 40                	pushq  $0x40
 61e:	6a 01                	pushq  $0x1
 620:	e8 00 00 00 00       	callq  625 <my_vadd_compute_+0xc5>
            621: R_X86_64_PLT32	VTAStoreBuffer2D-0x4
 625:	48 83 c4 10          	add    $0x10,%rsp
 629:	48 89 df             	mov    %rbx,%rdi
 62c:	be 00 00 00 80       	mov    $0x80000000,%esi
 631:	e8 00 00 00 00       	callq  636 <my_vadd_compute_+0xd6>
            632: R_X86_64_PLT32	VTASynchronize-0x4
 636:	31 c0                	xor    %eax,%eax
 638:	5b                   	pop    %rbx
 639:	41 5e                	pop    %r14
 63b:	41 5f                	pop    %r15
 63d:	c3                   	retq   
 63e:	90                   	nop
 63f:	90                   	nop
 640:	50                   	push   %rax
 641:	bf 40 00 00 00       	mov    $0x40,%edi
 646:	be 01 00 00 00       	mov    $0x1,%esi
 64b:	ba 01 00 00 00       	mov    $0x1,%edx
 650:	31 c9                	xor    %ecx,%ecx
 652:	e8 00 00 00 00       	callq  657 <my_vadd_compute_+0xf7>
            653: R_X86_64_PLT32	VTAUopLoopBegin-0x4
 657:	bf 01 00 00 00       	mov    $0x1,%edi
 65c:	31 f6                	xor    %esi,%esi
 65e:	31 d2                	xor    %edx,%edx
 660:	b9 40 00 00 00       	mov    $0x40,%ecx
 665:	45 31 c0             	xor    %r8d,%r8d
 668:	41 b9 02 00 00 00    	mov    $0x2,%r9d
 66e:	6a 00                	pushq  $0x0
 670:	6a 00                	pushq  $0x0
 672:	e8 00 00 00 00       	callq  677 <my_vadd_compute_+0x117>
            673: R_X86_64_PLT32	VTAUopPush-0x4
 677:	48 83 c4 10          	add    $0x10,%rsp
 67b:	e8 00 00 00 00       	callq  680 <my_vadd_compute_+0x120>
            67c: R_X86_64_PLT32	VTAUopLoopEnd-0x4
 680:	31 c0                	xor    %eax,%eax
 682:	59                   	pop    %rcx
 683:	c3                   	retq   

1.2. vta.build

1.2.1. tir.add_lower_pass

vta.build 在 tvm.build 的基础上通过 tir.add_lower_pass 添加了一些自定义 pass: 这些 pass 根据 tir stmt 上的 pragma, 把相应的 tir 改写为对 vta runtime 的调用. 通过 tir.add_lower_pass 添加的 pass 会在 LowerWithPassList 时被调用.

def build(*args, **kwargs):
    """Thin wrapper around ``tvm.build`` that applies VTA's lowering passes.

    If the current ``PassContext`` does not already register
    ``tir.add_lower_pass`` (i.e. the caller is not inside a
    ``vta.build_config()`` scope), a default ``build_config()`` is entered so
    the VTA passes run. Otherwise the caller's context is used as-is. Its
    lower-pass list, ``opt_level`` and ``disabled_pass`` settings are kept
    rather than being replaced by a fresh default context.

    All positional and keyword arguments are forwarded unchanged to
    ``tvm.build``, and its result is returned.
    """
    pass_ctx = tvm.transform.PassContext.current()
    # Only install the default VTA config when the caller has not provided
    # one; entering a new PassContext would shadow the caller's settings.
    if not pass_ctx.config.get("tir.add_lower_pass"):
        with build_config():
            return tvm.build(*args, **kwargs)
    return tvm.build(*args, **kwargs)


def build_config(debug_flag=0, **kwargs):
    """Build a build config for VTA.

    The returned ``PassContext`` registers VTA's custom TIR lowering passes
    under the ``tir.add_lower_pass`` config key. ``tvm.build`` or
    ``relay.build`` executed inside it will run those passes.

    Parameters
    ----------
    debug_flag : int, optional
        Value passed to the runtime call ``VTASetDebugMode``. When non-zero,
        an extra phase-1 lowering pass prepends that call to every PrimFunc
        body. With the default ``0`` no debug call is inserted.
    **kwargs
        Forwarded to ``tvm.transform.PassContext`` (e.g. ``opt_level``,
        ``disabled_pass``).

    Returns
    -------
    tvm.transform.PassContext
        The pass context carrying the VTA lower-pass list.

    Example
    -------
    .. code-block:: python

      # build a vta module.
      with vta.build_config():
          vta_module = tvm.build(s, ...)
    """
    env = get_env()

    @tvm.tir.transform.prim_func_pass(opt_level=0)
    def add_debug(f, *_):
        # Prepend VTASetDebugMode(command_handle, debug_flag) to the body.
        debug = tvm.tir.call_extern(
            "int32", "VTASetDebugMode", env.dev.command_handle, debug_flag
        )

        return f.with_body(tvm.tir.stmt_seq(debug, f.body))

    # (phase, pass) pairs consumed through the "tir.add_lower_pass" config.
    pass_list = [
        (0, transform.InjectConv2DTransposeSkip()),
        (1, transform.InjectDMAIntrin()),
        (1, transform.InjectSkipCopy()),
        (1, transform.AnnotateALUCoProcScope()),
        (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
        (1, transform.LiftAllocToScopeBegin()),
        (1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
        (1, transform.InjectCoProcSync()),
        (1, EarlyRewrite()),
    ]
    pass_list.append((2, transform.InjectALUIntrin()))
    pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
    pass_list.append((3, transform.FoldUopLoop()))
    pass_list.append((3, transform.CPUAccessRewrite()))
    # Register the debug pass only when a debug mode was requested; without
    # this, debug_flag would have no effect at all.
    if debug_flag:
        pass_list.append((1, add_debug))
    config = {"tir.add_lower_pass": pass_list}

    return tvm.transform.PassContext(config=config, **kwargs)

1.2.2. InjectALUIntrin

InjectALUIntrin 可以把 `alu` pragma 指示的 tir 替换成对 vta runtime 中 VTAUopPush 等的调用

def InjectALUIntrin():
    """Return a PrimFunc pass that lowers ``alu``-pragma loop nests to VTA micro-ops.

    Every ``AttrStmt`` matching the ``alu`` pragma is replaced by:
    ``VTAUopLoopBegin`` calls (one per entry of ``extents``), a single
    ``tir.vta.uop_push`` intrinsic carrying the ALU opcode, and matching
    ``VTAUopLoopEnd`` calls. All other statements are returned unchanged.

    NOTE(review): this is an abridged excerpt. The ``# ...`` markers elide
    the code that defines ``extents``, ``dst_coeff``, ``src_coeff``,
    ``use_imm`` and ``imm_val``, and presumably the handling of values that
    match none of the opcode branches below. Confirm against the full VTA
    source before relying on details.

    Returns
    -------
    A ``tvm.tir.transform.prim_func_pass`` named ``tir.vta.InjectALUIntrin``
    created with ``opt_level=0``.
    """

    def _ftransform(func, mod, ctx):
        # Rewrite one PrimFunc by applying _do_fold to its AttrStmt nodes.
        env = get_env()
        idxm = tvm.tir.indexmod  # presumably used by the elided index analysis
        analyzer = tvm.arith.Analyzer()

        def _do_fold(stmt):
            # Post-order callback for ir_transform: return a replacement
            # statement for an `alu` pragma, otherwise `stmt` unchanged.
            # ...
            # the `alu` pragma
            if _match_pragma(stmt, "alu"):
                # Get to the innermost loop body
                loop_body = stmt.body
                nest_size = 0
                # Strip each enclosing For loop, counting the nest depth.
                while isinstance(loop_body, tvm.tir.For):
                    loop_body = loop_body.body
                    nest_size += 1
                # ...
                # Map the binary op computed by the innermost body onto the
                # matching VTA ALU opcode; lhs/rhs are its two operands.
                # NOTE(review): no `else` branch is visible here, so an
                # unsupported op would leave alu_opcode unbound unless the
                # elided code guards it. Confirm.
                if isinstance(loop_body.value, tvm.tir.Add):
                    alu_opcode = env.dev.ALU_OPCODE_ADD
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Sub):
                    alu_opcode = env.dev.ALU_OPCODE_SUB
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Mul):
                    alu_opcode = env.dev.ALU_OPCODE_MUL
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Min):
                    alu_opcode = env.dev.ALU_OPCODE_MIN
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Max):
                    alu_opcode = env.dev.ALU_OPCODE_MAX
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                # ...
                # Insert ALU micro-ops
                irb = tvm.tir.ir_builder.create()
                # Open one micro-op loop per extent, passing that level's
                # destination and source coefficients. The last argument is
                # fixed at 0 (presumably the weight factor; confirm).
                for idx, extent in enumerate(extents):
                    irb.emit(
                        tvm.tir.call_extern(
                            "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx], 0
                        )
                    )
                use_imm = int(use_imm)
                # Emit a single micro-op for the innermost iteration.
                # NOTE(review): argument order is presumably (mode=1,
                # reset=0, dst index, src index, wgt index=0, opcode,
                # use_imm, imm). Confirm against VTAUopPush.
                irb.emit(
                    tvm.tir.call_intrin(
                        "int32",
                        "tir.vta.uop_push",
                        1,
                        0,
                        dst_coeff[len(dst_coeff) - 1],
                        src_coeff[len(src_coeff) - 1],
                        0,
                        alu_opcode,
                        use_imm,
                        imm_val,
                    )
                )
                # Close every loop opened above.
                for extent in extents:
                    irb.emit(tvm.tir.call_extern("int32", "VTAUopLoopEnd"))
                return irb.get()
            return stmt

        # Only AttrStmt nodes are handed to _do_fold (as the post-order hook).
        return func.with_body(
            tvm.tir.stmt_functor.ir_transform(func.body, None, _do_fold, ["tir.AttrStmt"])
        )

    return tvm.tir.transform.prim_func_pass(
        _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin"
    )

1.2.3. Heterogeneous Execution

通过把 relay 编译成 vta 和 cpu 混合的操作, 可以达到 cpu/vta 异构执行的目的

relay.build 可以支持 vta:

  1. `tvm/vta/python/vta/top` 中包含了针对 vta 的 topi 实现, 它为 `vta` 这个 device 定义了自己的 compute 和 schedule, 其中 schedule 会像上面 te 的例子一样使用 pragma 标记 vta 支持的操作 (alu, dma_copy, GEMM 等)

    tvm/vta/python/vta/top/op.py::def schedule_alu_packed(cfg, outs):

  2. 用户调用 relay.build 时需要通过 vta.build_config 引入 vta 自己的 tir.add_lower_pass 来处理那些 pragma

    # env.target_host is llvm
    # env.target is ext_dev -keys=vta,cpu -device=vta -model=sim_1x16_i8w8a32_15_15_18_17
    # vta.build_config registers VTA's TIR lowering passes (tir.add_lower_pass)
    # so relay.build lowers the VTA pragmas; opt_level and disabled_pass are
    # forwarded to the underlying tvm.transform.PassContext.
    with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
        graph, lib, params = relay.build(
            relay_prog, target=env.target, params=params, target_host=env.target_host
        )
    
1.2.3.1. graph_pack

由于 vta 支持的操作非常少 (只支持整型的 mul,add,…, 以及 dense 和 conv2d), 所以 relay.build 时需要确定哪些 relay IR 需要由 vta 执行, vta 不支持的需要跑在 cpu 上.

graph_pack 函数提供了两个功能:

  1. annotation, 通过 relay.annotation.on_device 修改 relay IR, 标记出哪些 relay IR 需要由 vta 执行 (on_device annotation 最终在 relay.build 时被处理)
  2. graph_pack, 根据 vta 的配置对数据进行 reshape, 以便后续 schedule 时能应用 tensorize

但现在 graph_pack 的功能还有些问题:

  1. graph_pack 要求 shape[1] 必须为 cfactor 的倍数, shape[1] 对 conv2d 来说是 NCHW 中的 C, 对 dense 来说是 output size, 这个限制非常大, 直接限制了网络的参数
  2. 它认为 conv2d 一定是 NCHW 格式, 所以 tflite 的 conv2d (NHWC 格式) 无法用 vta 来执行
  3. 通过 (start_name, stop_name), (annot_start_name, annot_end_name) 参数来决定哪些 ir 需要处理, 但不太好用
1.2.3.2. Example
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-08 18:19
from mxnet.gluon.model_zoo import vision
import numpy as np

import tvm
from tvm import te
from tvm import rpc, autotvm, relay
from tvm.contrib import graph_executor, utils, download, graph_runtime
from tvm.contrib.debugger import debug_executor
from tvm.relay import transform

import vta
from vta.top import graph_pack

# Compile a quantized ResNet-18 for heterogeneous CPU + VTA execution.
env = vta.get_env()
target = env.target

# Input name -> dtype / shape for the MXNet frontend (batch = env.BATCH).
dtype_dict = {"data": "float32"}
shape_dict = {"data": (env.BATCH, 3, 224, 224)}

# Pretrained ResNet-18 v1 from the Gluon model zoo, converted to Relay.
gluon_model = vision.get_model("resnet18_v1", pretrained=True)
mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)

# Also record the parameter shapes / dtypes.
# NOTE(review): shape_dict and dtype_dict are not read again in this snippet.
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

with tvm.transform.PassContext(opt_level=3):
    # Quantize with a global scale of 8.0, skipping conv layer 0 (the first
    # conv2d stays float32 in the printed program below).
    with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
        mod = relay.quantize.quantize(mod, params=params)
    # Pack the graph into VTA's blocked layout between start_name and
    # stop_name. on_device annotations (annot_start_name..annot_end_name)
    # are added only for the "intelfocl" and "sim" targets.
    relay_prog = graph_pack(
        mod["main"],
        env.BATCH,
        env.BLOCK_OUT,
        env.WGT_WIDTH,
        start_name="nn.max_pool2d",
        stop_name="nn.global_avg_pool2d",
        annot_start_name="nn.conv2d",
        annot_end_name="annotation.stop_fusion",
        device_annot=(env.TARGET == "intelfocl" or env.TARGET == "sim"),
    )

# Print the packed Relay program (sample output below).
print(relay_prog)

# Build under VTA's config: ops on "cpu" use env.target_vta_cpu, ops
# annotated for "ext_dev" use the VTA target.
with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
    graph, lib, params = relay.build(
        relay_prog,
        target={"cpu": env.target_vta_cpu, "ext_dev": target},
        params=params,
        target_host=env.target_host,
    )
fn (%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 1000), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = nn.max_pool2d(%2, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %4 = reshape(%3, newshape=[1, 1, 4, 16, 56, 56]) /* ty=Tensor[(1, 1, 4, 16, 56, 56), float32] */;
  %5 = transpose(%4, axes=[0, 2, 4, 5, 1, 3]) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
  %6 = annotation.stop_fusion(%5) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
  %7 = multiply(%6, 16f /* ty=float32 */) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
  %8 = round(%7) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
  %9 = clip(%8, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 4, 56, 56, 1, 16), float32] */;
  %10 = cast(%9, dtype="int32") /* ty=Tensor[(1, 4, 56, 56, 1, 16), int32] */;
  ... 
  add(%631, meta[relay.Constant][41] /* ty=Tensor[(1000), float32] */) /* ty=Tensor[(1, 1000), float32] */
}

在上面的输出的 relay_prog 中,

  • nn.conv2d pack 的结果

    data_layout 变为 NCHW1n16c, 其中 1n 和 16c 分别表示 batch 维按 1、channel 维按 16 分块: 512 个 channels 被拆成 (512/16)=32 组, 每组 16 个, 对应输出 shape (1, 32, 7, 7, 1, 16)

    %566 = nn.conv2d(%564, %565, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3], data_layout="NCHW1n16c", kernel_layout="OIHW16o16i", out_dtype="int32") /* ty=Tensor[(1, 32, 7, 7, 1, 16), int32] */;
    
  • annotation 的结果

    nn.relu 添加了一个 on_device 的属性, 表示 relay build 时对这个 nn.relu 使用 vta 作为 device.

    实际上 vta 的 env.target 的值为 ext_dev -keys=vta,cpu -device=vta -model=sim_1x16_i8w8a32_15_15_18_17

    因为指定了 `keys=vta,cpu`, relay.build 查找 strategy 时会把 cpu 作为 fallback, 所以即使前面 annotate 时把 vta 不支持的 relay IR (比如 nn.relu) 标记为 on_device(env.target), 编译时最终还是会用 cpu 的 strategy 去编译 (tvm/python/tvm/target/generic_func.py::for k in target.keys:)

    %617 = nn.relu(%616) /* ty=Tensor[(1, 32, 7, 7, 1, 16), int32] */;
    %618 = on_device(%617, meta[relay.attrs.OnDeviceAttrs][299]) /* ty=Tensor[(1, 32, 7, 7, 1, 16), int32] */;
    

1.3. VTA Driver

vta.build 时通过 add_lower_pass 添加自定义 pass, 根据 tir 的 pragma 信息把 tir 修改成对 vta runtime 相关函数的调用, 以 Overview 中的代码为例, 会依次调用以下函数:

  • VTALoadBuffer2D
  • VTALoadBuffer2D
  • VTAUopLoopBegin
  • VTAUopPush
  • VTAUopLoopEnd
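  • VTAPushALUOp (注: 上面的 VTAUopLoopBegin / VTAUopPush / VTAUopLoopEnd 位于一个单独的函数中 (objdump 中 0x640 处), 该函数的地址作为参数传给 VTAPushALUOp, 由 VTAPushALUOp 在内部回调; 因此实际执行时是先进入 VTAPushALUOp)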
  • VTADepPush
  • VTADepPop
  • VTAStoreBuffer2D
  • VTASynchronize

上面的函数会再转换成 vta instruction 及对 vta driver 的调用, 以 VTAPushALUOp 为例:

  1. VTAPushALUOp 先生成 instruction 放在 insn_queue_ 中

    // Append one ALU instruction, described by a recorded micro-op kernel,
    // to insn_queue_. (Abridged excerpt: `// ...` marks elided code.)
    void PushALUUop(UopKernel* kernel) {
        VTAAluInsn* insn = insn_queue_.CreateAluInsn();
        insn->opcode = VTA_OPCODE_ALU;
        insn->reset_reg = kernel->reset_out_;
        // Range of this kernel's micro-ops in the on-chip micro-op buffer.
        insn->uop_bgn = kernel->sram_begin_;
        insn->uop_end = kernel->sram_end_;
        insn->alu_opcode = kernel->opcode_;
        insn->use_imm = kernel->use_imm_;
        insn->imm = kernel->imm_val_;
        const std::vector<UopKernel::LoopEntry>& loop = kernel->loop();
        // ...
        // The instruction encodes a two-level loop: loop[0] fills the outer
        // (iter_out) fields and loop[1] the inner (iter_in) fields.
        // NOTE(review): the elided code presumably handles kernels with
        // fewer than two loop levels. Confirm.
        insn->iter_out = loop[0].extent;
        insn->dst_factor_out = loop[0].dst_factor;
        insn->src_factor_out = loop[0].src_factor;
        insn->iter_in = loop[1].extent;
        insn->dst_factor_in = loop[1].dst_factor;
        insn->src_factor_in = loop[1].src_factor;
        // ...
    }
    
  2. VTASynchronize 会把 insn_queue_ 从本地复制到 fpga, 然后再让 fpga 执行

    // Copy the host-side instruction queue to device memory, then run it.
    // (Abridged excerpt.)
    void Synchronize(uint32_t wait_cycles) {
        insn_queue_.AutoReadBarrier();
            // NOTE(review): the more-indented lines appear to be an inlined
            // excerpt of what AutoReadBarrier does internally. Confirm.
            // dram_buffer_ is a local (host-side) std::vector
            // fpga_buff_ is memory on the FPGA allocated via VTAMemAlloc
            VTAMemCopyFromHost(fpga_buff_, dram_buffer_.data(), buff_size);

        // Start the device on insn_queue_.count() instructions located at the
        // queue's physical address; wait_cycles is passed through.
        VTADeviceRun(
            device_, insn_queue_.dram_phy_addr(), insn_queue_.count(), wait_cycles);
    }
    

VTA driver 提供的功能主要有:

  1. VTADeviceAlloc
  2. VTAMemAlloc / VTAMemFree
  3. VTAMemCopyFromHost / VTAMemCopyToHost
  4. VTADeviceRun

1.4. MISC

1.4.1. vta 与 opencl target 的区别

vta 无法像 opencl 一样支持所有的操作, 它的方式称为 heterogeneous execution:

  1. 编译时需要通过 on_device 修改各个 IR 的 target (而 opencl 的所有 IR 都使用 opencl target); 运行时还需要通过 on_device 生成的 device_copy 频繁地在 cpu 与 vta 之间拷贝数据, 而 opencl 只需要针对 input/output 进行 copy

    opencl 也支持 heterogeneous execution: GraphExecutorCodegen

  2. vta build 是在 target.build.llvm 基础上通过 tir.add_lower_pass 加入几个自定义 pass 实现的针对 alu, gemm 等的 offload, 而 opencl 有它自己的 target.build.opencl

Author: [email protected]
Date: 2021-09-08 Wed 00:00
Last updated: 2023-12-01 Fri 18:28

知识共享许可协议