TVM Build
1. TVM Build
1.1. Relay Optimizations
1.1.1. Overview
with tvm.transform.PassContext(
    opt_level=3, disabled_pass=[], required_pass=[], config={}
):
    graph, lib, params = relay.build(mod, target="c", params=None)
# build_module.py@relay:
relay.build():
    bld_mod = BuildModule()
    bld_mod.build(mod=ir_mod, target=target, params=params,
                  executor=executor, mod_name=mod_name)
        self._build(mod, target, target_host, executor, mod_name)
// build_module.cc@relay:
void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host,
           const String executor, const String mod_name) {
  targets_ = targets;
  target_host_ = target_host;
  executor_ = executor;
  BuildRelay(mod, params_, mod_name);
}

BuildRelay:
  relay_module = Optimize(relay_module, targets_, params);
1.1.2. optimize
All Relay-level optimization passes live under tvm::relay::transform.
IRModule Optimize(IRModule relay_module, const TargetsMap& targets,
                  const std::unordered_map<std::string, runtime::NDArray>& params) {
  Array<Pass> pass_seqs = GetPassPrefix(targets, false);
  transform::Pass seq = transform::Sequential(pass_seqs);
  relay_module = seq(relay_module);

  // Handle heterogeneous compilation.
  //
  // RunDeviceAnnotationPass only runs when more than one target is
  // specified, i.e. to process on_device annotations. For example,
  // building for VTA requires
  // target={"cpu": env.target_vta_cpu, "ext_dev": env.target}.
  transform::PassContext pass_ctx = PassContext::Current();
  if (targets_.size() > 1) {
    Optional<Integer> opt_fallback_dev = pass_ctx->GetConfig(
        "relay.fallback_device_type", Integer(static_cast<int>(kDLCPU)));
    auto fallback_dev = opt_fallback_dev.value();
    relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value);
  }
  // ...
  relay_module = transform::FuseOps()(relay_module);
  // ...
  relay_module = transform::Inline()(relay_module);
  relay_module = transform::InferType()(relay_module);
  relay_module = transform::LabelOps()(relay_module);
  return relay_module;
}
Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is_vm) {
  Array<Pass> pass_seqs;
  Array<runtime::String> entry_functions{"main"};
  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
  pass_seqs.push_back(transform::ToBasicBlockNormalForm());
  pass_seqs.push_back(relay::qnn::transform::Legalize());
  pass_seqs.push_back(transform::SimplifyInference());
  pass_seqs.push_back(transform::DynamicToStatic());
  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
  pass_seqs.push_back(transform::SimplifyExpr());
  pass_seqs.push_back(transform::CombineParallelConv2D(3));
  pass_seqs.push_back(transform::CombineParallelDense(3));
  pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
  pass_seqs.push_back(transform::FoldConstant());
  pass_seqs.push_back(transform::FoldScaleAxis());
  pass_seqs.push_back(transform::CanonicalizeCast());
  pass_seqs.push_back(transform::CanonicalizeOps());
  pass_seqs.push_back(transform::FastMath());
  pass_seqs.push_back(transform::FoldConstant());
  return pass_seqs;
}
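These passes are also exposed to Python under relay.transform. A minimal sketch (toy module; pass choice is illustrative) applying one of them directly:

import tvm
from tvm import relay

c = relay.const(1.0) + relay.const(2.0)
mod = tvm.IRModule.from_expr(relay.Function([], c))
mod = relay.transform.FoldConstant()(mod)
print(mod)   # the body of main is now the folded constant 3.0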
1.1.3. PassContext
PassContext controls which passes are enabled, covering both the Relay and the TIR passes. Sometimes a pass needs to be turned off; for example, SimplifyInference normally rewrites batchnorm into mul/add, but some targets may have a more efficient implementation of batchnorm itself.
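A minimal sketch of the disabled_pass knob (toy module; pass names follow their registered names). Inside the context, Sequential consults the PassContext, so the disabled pass is skipped:

import tvm
from tvm import relay

c = relay.const(1.0) + relay.const(2.0)
mod = tvm.IRModule.from_expr(relay.Function([], c))

seq = tvm.transform.Sequential([relay.transform.FoldConstant()])
with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldConstant"]):
    print(seq(mod))   # the add survives: FoldConstant was skipped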
1.1.3.1. opt_level
1.1.3.2. disabled_pass
1.1.3.3. required_pass
1.1.3.4. config
TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
1.1.4. Relay Transform
1.2. Relay IR -> TE -> TIR
1.2.1. Relay IR -> OpRegistry
A Relay IR op such as relay.op.sort() actually returns a Relay Expr: tvm::relay::Call(Op::Get("sort")).
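This can be checked from Python (a quick illustrative sketch):

import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
call = relay.op.sort(x)
print(type(call))                              # tvm.relay.expr.Call
print(call.op.same_as(tvm.ir.Op.get("sort")))  # True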
1.2.2. OpRegistry -> TE
1.2.2.1. BuildRelay
BuildRelay:
  executor_codegen_ = MakeExecutorCodegen(executor_);
  executor_codegen_->Init(nullptr, targets_);
  executor_codegen_->Codegen(func, mod_name);
    auto lowered_module = tec::LowerTE()
  executor_codegen_->UpdateOutput(&ret_);
  ret_.params = executor_codegen_->GetParams();
  auto lowered_funcs = executor_codegen_->GetIRModule();
1.2.2.2. ExecutorCodegen
1.2.2.2.1. GraphExecutorCodegen
1.2.2.2.2. AOTExecutorCodegen
1.2.2.3. LowerTE
LowerTE:
  // VisitExpr here walks Relay IR.
  LowerTensorExpr.VisitExpr()
    LowerInternal()
      auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) {
        auto mangled = mangle_fn(name);
        return GetUniqueName(mangled, &name_map_);
      });
        return ScheduleBuilder(target).Create(source_func, renamer);
          this->VisitExpr(prim_func->body);
            // VisitExpr_(xxx)
            // This calls back into Python to look up the strategy.
            static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
            LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
            // !!! outputs is the TE corresponding to the Relay expr
            outputs = lowered_out->outputs;
            // !!! impl includes the schedule
            impl = lowered_out->implementation;

  // external functions, see [[file:tvm_byoc_codegen.org::*编译 external functions][compiling external functions]]
  lowered_module.external_mods = compiler->LowerExternalFunctions();
1.2.3. TE -> TIR
TE is TIR that has not yet been scheduled: what te.compute returns is already composed of TIR nodes.

Take the following code as an example:
M = 10
A = te.placeholder((M,), name="A")
B = te.placeholder((M,), name="B")
import ipdb; ipdb.set_trace()
C = te.compute((M,), lambda x: A[x] + B[x])
After te.compute returns, C.op is a ComputeOp composed of TIR.

With some logging added, the expression types inside the ComputeOp print as:
<Array>
  <tir.Add>
    <tir.ProducerLoad>A[{<tir.Var>x}]
    <tir.ProducerLoad>B[{<tir.Var>x}]
When te.compute runs, it invokes the lambda directly to produce the Op. The process relies on Python operator overloading, for example (verified by the sketch after this list):

- A is of type Tensor, so A[x] calls Tensor.__getitem__ and produces a tir::ProducerLoad
- A[x] + B[x] calls ExprOp.__add__ and produces a tir::Add
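A small sketch verifying this chain on the example above:

import tvm
from tvm import te, tir

M = 10
A = te.placeholder((M,), name="A")
B = te.placeholder((M,), name="B")
C = te.compute((M,), lambda x: A[x] + B[x])

body = C.op.body[0]                          # the expression the lambda produced
assert isinstance(body, tir.Add)             # from ExprOp.__add__
assert isinstance(body.a, tir.ProducerLoad)  # from Tensor.__getitem__
print(type(body), type(body.a), type(body.b))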
1.2.3.1. Schedule
For how scheduling works, see: Schedule
LowerInternal:
  // LowerSchedule translates TE into TIR according to the schedule.
  // tvm.build and tvm.lower also call this function to translate a
  // te.schedule into TIR.
  cfunc->funcs->Update(
      tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds));
Take a simple split as an example:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import te

M = 10
A = te.placeholder((M,), name="A")
B = te.placeholder((M,), name="B")
C = te.compute((M,), lambda x: A[x] + B[x])

s = te.create_schedule(C.op)
s[C].split(C.op.axis[0], factor=5)
f = tvm.build(s, [A, B, C], target="c", name="hello")
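Before diving into the C++ path, the split's effect can be observed by printing the lowered TIR (output paraphrased in the comment):

print(tvm.lower(s, [A, B, C], simple_mode=True))
# for (x.outer, 0, 2) {
#   for (x.inner, 0, 5) {
#     compute[x.outer*5 + x.inner] = A[x.outer*5 + x.inner] + B[x.outer*5 + x.inner]
#   }
# }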
1.2.3.1.1. ScheduleToModule
LowerSchedule:
  IRModule mod = ScheduleToModule(std::move(sch), args, name, binds);
  Array<transform::Pass> pass_list = CreatePassList(simple_mode, true);
  return LowerWithPassList(mod, pass_list);

ScheduleToModule:
  // InferBound uses the split factor to derive the itervars and their
  // loop ranges. Printing the bounds here:
  //
  //   for (auto& p : bounds) {
  //     LOG_INFO << p.first << p.second;
  //   }
  //
  // yields:
  //
  //   bound.cc:262: {<tir.IterVar>iter_var(x.outer, )}{<Range>range(min={<IntImm>0}, ext={<IntImm>2})}
  //   bound.cc:262: {<tir.IterVar>iter_var(x, {<Range>range(min={<IntImm>0}, ext={<IntImm>10})})}{<Range>range(min={<IntImm>0}, ext={<IntImm>10})}
  //   bound.cc:262: {<tir.IterVar>iter_var(x.inner, )}{<Range>range(min={<IntImm>0}, ext={<IntImm>5})}
  Map<tir::IterVar, Range> bounds = te::InferBound(sch);
  tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false);
    body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
      Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
        MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
          // Initially the body contains a single stmt, i.e. the stmt that
          // compute generated from lambda(x): A[x] + B[x]:
          //
          //   {<tir.ProducerStore>compute[{<tir.Var>x}] ={<tir.Add>({<tir.ProducerLoad>A[{<tir.Var>x}]} + {<tir.ProducerLoad>B[{<tir.Var>x}]})}
          for (size_t i = 0; i < self->body.size(); ++i) {
            provides.emplace_back(MakeProvide(self, stage->op.output(i)));
          }
          // MergeNest then generates the surrounding loop statements from
          // bounds (i.e. dom_map). The stmt MergeNest returns is:
          //
          //   {<tir.For>for ({<tir.Var>x.outer}, {<IntImm>0}, {<IntImm>2}) {
          //     {<tir.AttrStmt>attr [{<tir.IterVar>iter_var(x.outer, )}] loop_scope = {<tir.Var>x.outer}
          //       {<tir.For>for ({<tir.Var>x.inner}, {<IntImm>0}, {<IntImm>5}) {
          //         {<tir.AttrStmt>attr [{<tir.IterVar>iter_var(x.inner, )}] loop_scope = {<tir.Var>x.inner}
          //           {<tir.ProducerStore>compute[{<tir.Var>x}] ={<tir.Add>({<tir.ProducerLoad>A[{<tir.Var>x}]} + {<tir.ProducerLoad>B[{<tir.Var>x}]})}}}
          //       }}}}
          //   }
          provide = MergeNest(n.main_nest, provide);
1.2.3.2. TIR Optimizations
1.2.3.2.1. LowerWithPassList
After ScheduleToModule produces the TIR, LowerWithPassList applies further transformations and optimizations to it. Each transformation is a pass; the TIR passes live under tvm::tir::transform.

Among them, tir.add_lower_pass holds user-supplied passes: through the add_lower_pass mechanism, upper layers can customize the generated TIR. VTA uses add_lower_pass to inject custom passes that turn a TIR op (e.g. tir.add) into calls into the VTA runtime (e.g. VTAPushALUOp).
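A minimal sketch of the mechanism (the pass body is illustrative; real users such as VTA rewrite the TIR instead of printing it):

import tvm
from tvm import te

@tvm.tir.transform.prim_func_pass(opt_level=0)
def dump_tir(f, mod, ctx):
    print(f)       # inspect the PrimFunc between the phase-1 passes
    return f

M = 10
A = te.placeholder((M,), name="A")
C = te.compute((M,), lambda i: A[i] + 1.0)
s = te.create_schedule(C.op)

# entries are [phase_number, pass], consumed by CreatePassList below
with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, dump_tir)]}):
    tvm.lower(s, [A, C])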
Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
  transform::PassContext pass_ctx = transform::PassContext::Current();

  // Get any user-added passes
  Array<Array<ObjectRef>> add_lower_pass =
      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
          .value();

  Array<transform::Pass> user_lower_phase0 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase1 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase2 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase3 = Array<transform::Pass>();

  // The phase passes are of the form
  // [[phase_number, pass], [phase_number, pass]... ]
  for (Array<ObjectRef> phase_pass : add_lower_pass) {
    const IntImmNode* phase_num = phase_pass[0].as<IntImmNode>();
    ICHECK(phase_num)
        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
    int phase_num_val = phase_num->value;

    CHECK_GE(phase_num_val, 0);

    const tvm::transform::PassNode* pass_node = phase_pass[1].as<tvm::transform::PassNode>();
    tvm::transform::Pass pass = GetRef<tvm::transform::Pass>(pass_node);
    // Copy the pass into the correct phase
    if (phase_num_val == 0) {
      user_lower_phase0.push_back(pass);
    } else if (phase_num_val == 1) {
      user_lower_phase1.push_back(pass);
    } else if (phase_num_val == 2) {
      user_lower_phase2.push_back(pass);
    } else if (phase_num_val >= 3) {
      user_lower_phase3.push_back(pass);
    }
  }

  // Construct the pass list, inserting the user provided passes at the end of the phase

  // PHASE 0
  Array<tvm::transform::Pass> pass_list = user_lower_phase0;

  // PHASE 1
  if (for_te_schedule) {
    pass_list.push_back(tir::transform::InjectPrefetch());
    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
  } else {
    pass_list.push_back(tir::transform::LowerInitBlock());
    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
    pass_list.push_back(tir::transform::CompactBufferAllocation());
    pass_list.push_back(tir::transform::LowerMatchBuffer());
    pass_list.push_back(tir::transform::FlattenBuffer());
  }
  pass_list.push_back(tir::transform::BF16Legalize());
  pass_list.push_back(tir::transform::NarrowDataType(32));
  pass_list.push_back(tir::transform::Simplify());

  // Add user-defined phase-1 passes
  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());

  // PHASE 2
  if (!disable_loop_partition) {
    pass_list.push_back(tir::transform::LoopPartition());
  }
  pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
  pass_list.push_back(tir::transform::InjectVirtualThread());
  pass_list.push_back(tir::transform::InjectDoubleBuffer());
  pass_list.push_back(tir::transform::StorageRewrite());
  pass_list.push_back(tir::transform::UnrollLoop());

  // Add user-defined phase-2 passes
  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());

  // PHASE 3
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::RemoveNoOp());
  pass_list.push_back(tir::transform::RewriteUnsafeSelect());
  pass_list.push_back(tir::transform::HoistIfThenElse());

  // Add user-defined phase-3 passes
  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());

  if (instrument_bound_checkers) {
    pass_list.push_back(tir::transform::InstrumentBoundCheckers());
  }
  return pass_list;
}
1.2.3.2.2. TIR Transform
1.3. TIR -> Codegen
BuildRelay:
  ret_.mod = tvm::build(lowered_funcs, target_host_);

// driver_api.cc
runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& target_host_arg):
  for (const auto& it : inputs) {
    if (mdevice->functions.size() != 0) {
      device_modules.push_back(codegen::Build(mdevice, it.first));
  runtime::Module mhost = codegen::Build(mhost_all, target_host);

// codegen.cc @ src/target/
runtime::Module Build(IRModule mod, Target target) {
  std::string build_f_name = "target.build." + target->kind->name;
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  return (*bf)(mod, target);

// "target.build.llvm" @ src/target/llvm/llvm_module.cc
TVM_REGISTER_GLOBAL("target.build.llvm")
    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
      auto n = make_object<LLVMModuleNode>();
      n->Init(mod, target);
      return runtime::Module(n);
    });

void LLVMModuleNode::Init(const IRModule& mod, const Target& target):
  InitializeLLVM();
  tm_ = GetLLVMTargetMachine(target);
  bool system_lib = target->GetAttr<Bool>("system-lib").value_or(Bool(false));
  bool target_c_runtime = (target->GetAttr<String>("runtime").value_or("") == kTvmRuntimeCrt);
  ctx_ = std::make_shared<llvm::LLVMContext>();
  std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());
  // ...
  for (const auto& f : funcs):
    cg->AddFunction(f);
  // ....
  module_ = cg->Finish();
// codegen_llvm.cc @ src/target/llvm
AddFunction:
  void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void):
    function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
                                       global_symbol.value().operator std::string(), module_.get());
    this->VisitStmt(f->body);
    // .....
    CodeGenLLVM::VisitStmt_()
      MakeValue()
        VisitExpr()
          VisitExpr_()
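From Python, the registered "target.build.<kind>" functions are visible through the global registry:

import tvm
print(tvm.get_global_func("target.build.llvm"))
print(tvm.get_global_func("target.build.c"))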
1.3.1. Example
Since the corresponding TIR cannot be observed during relay.build (see https://discuss.tvm.apache.org/t/capture-tensor-level-ir-and-schedule-from-relay/9630), a plain TE snippet is used here to walk through the codegen process.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import te

M = 10
N = 10
A = te.placeholder((M, N), name="A")
B = te.placeholder((M, N), name="B")
C = te.compute((M, N), lambda x, y: A[x, y] + B[x, y])

s = te.create_schedule(C.op)
print(tvm.lower(s, [A, B, C]))
f = tvm.build(s, [A, B, C], target="c", name="hello")
# print(f.get_source())
primfn(A_1: handle, B_1: handle, compute_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [10, 10], []),
             compute: Buffer(compute_2: Pointer(float32), float32, [10, 10], []),
             B: Buffer(B_2: Pointer(float32), float32, [10, 10], [])}
  buffer_map = {A_1: A, B_1: B, compute_1: compute} {
  for (x: int32, 0, 10) {
    for (y: int32, 0, 10) {
      compute_2[((x*10) + y)] = ((float32*)A_2[((x*10) + y)] + (float32*)B_2[((x*10) + y)])
    }
  }
}
The log below shows codegen_c.cc generating C code for the following TIR fragment:
for (x: int32, 0, 10) {
  for (y: int32, 0, 10) {
    compute_2[((x*10) + y)] = ((float32*)A_2[((x*10) + y)] + (float32*)B_2[((x*10) + y)])
  }
}
tvm/src/target/source/codegen_c.cc:920: >>VisitStmt_:ForNode
tvm/src/target/source/codegen_c.cc:469: VisitExpr_: IntImmNode: 10
tvm/src/target/source/codegen_c.cc:922: extent: 10
tvm/src/target/source/codegen_c.cc:925: vid: x name_hint: x
tvm/src/target/source/codegen_c.cc:920: >>VisitStmt_:ForNode
tvm/src/target/source/codegen_c.cc:469: VisitExpr_: IntImmNode: 10
tvm/src/target/source/codegen_c.cc:922: extent: 10
tvm/src/target/source/codegen_c.cc:925: vid: y name_hint: y
tvm/src/target/source/codegen_c.cc:754: >>VisitStmt_:StoreNode
tvm/src/target/source/codegen_c.cc:526: >>VisitExpr_:ADD
tvm/src/target/source/codegen_c.cc:159: >>GetBufferRef
tvm/src/target/source/codegen_c.cc:526: >>VisitExpr_:ADD
tvm/src/target/source/codegen_c.cc:534: >>VisitExpr_:MUL
tvm/src/target/source/codegen_c.cc:469: VisitExpr_: IntImmNode: 10
tvm/src/target/source/codegen_c.cc:536: <<VisitExpr_:MUL
tvm/src/target/source/codegen_c.cc:528: <<VisitExpr_:ADD
tvm/src/target/source/codegen_c.cc:231: <<GetBufferRef:((float*)A)[(((x * 10) + y))]
tvm/src/target/source/codegen_c.cc:159: >>GetBufferRef
tvm/src/target/source/codegen_c.cc:526: >>VisitExpr_:ADD
tvm/src/target/source/codegen_c.cc:534: >>VisitExpr_:MUL
tvm/src/target/source/codegen_c.cc:469: VisitExpr_: IntImmNode: 10
tvm/src/target/source/codegen_c.cc:536: <<VisitExpr_:MUL
tvm/src/target/source/codegen_c.cc:528: <<VisitExpr_:ADD
tvm/src/target/source/codegen_c.cc:231: <<GetBufferRef:((float*)B)[(((x * 10) + y))]
tvm/src/target/source/codegen_c.cc:528: <<VisitExpr_:ADD
tvm/src/target/source/codegen_c.cc:758: <<VisitStmt_:StoreNode: (((float*)A)[(((x * 10) + y))] + ((float*)B)[(((x * 10) + y))])
tvm/src/target/source/codegen_c.cc:159: >>GetBufferRef
tvm/src/target/source/codegen_c.cc:526: >>VisitExpr_:ADD
tvm/src/target/source/codegen_c.cc:534: >>VisitExpr_:MUL
tvm/src/target/source/codegen_c.cc:469: VisitExpr_: IntImmNode: 10
tvm/src/target/source/codegen_c.cc:536: <<VisitExpr_:MUL
tvm/src/target/source/codegen_c.cc:528: <<VisitExpr_:ADD
tvm/src/target/source/codegen_c.cc:231: <<GetBufferRef:((float*)compute)[(((x * 10) + y))]
tvm/src/target/source/codegen_c.cc:935: <<VisitStmt_:ForNode
tvm/src/target/source/codegen_c.cc:935: <<VisitStmt_:ForNode
1.4. Build Target
1.4.1. How the target parameter is parsed and used
The target parameter passed at build time, e.g.

llvm --device=arm_cpu --mtriple=armv7a-linux-gnueabihf

is parsed by TargetInternal::FromConfig into a Target object, which prints as

llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0 -mtriple=armv7a-linux-gnueabihf

The `arm_cpu,cpu` in keys is filled in by TargetInternal::FromConfig.
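The same parsing can be observed from Python (illustrative):

import tvm

t = tvm.target.Target("llvm --device=arm_cpu -mtriple=armv7a-linux-gnueabihf")
print(t.kind.name)         # llvm
print(t.keys)              # ['arm_cpu', 'cpu'] -- filled in by FromConfig
print(t.attrs["mtriple"])  # armv7a-linux-gnueabihf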
Downstream code reads members of the target as needed; for example, the x86 TOPI code determines the SIMD width from the "mcpu" member:
def get_simd_32bit_lanes():
    mcpu = tvm.target.Target.current().mcpu
    fp32_vec_len = 4
    if target_has_avx512(mcpu):
        fp32_vec_len = 16
    elif target_has_avx2(mcpu):
        fp32_vec_len = 8
    return fp32_vec_len
The members a Target supports are defined by its target_kind; the main ones include:
- kind
- keys
- device
- mcpu
- mtriple
- runtime
- system-lib
TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option<Array<String>>("mattr") .add_attr_option<String>("mcpu") .add_attr_option<String>("mtriple") .add_attr_option<String>("mfloat-abi") .add_attr_option<String>("mabi") .add_attr_option<Bool>("system-lib") .add_attr_option<String>("runtime") .add_attr_option<Bool>("link-params", Bool(false)) .add_attr_option<Bool>("unpacked-api") .add_attr_option<String>("interface-api") .set_default_keys({"cpu"}); #define TVM_REGISTER_TARGET_KIND(TargetKindName, DeviceType) \ TVM_STR_CONCAT(TVM_TARGET_KIND_REGISTER_VAR_DEF, __COUNTER__) = \ ::tvm::TargetKindRegEntry::RegisterOrGet(TargetKindName) \ .set_name() \ .set_device_type(DeviceType) \ .add_attr_option<Array<String>>("keys") \ .add_attr_option<String>("tag") \ .add_attr_option<String>("device") \ .add_attr_option<String>("model") \ .add_attr_option<Array<String>>("libs") \ .add_attr_option<Target>("host") \ .add_attr_option<Integer>("from_device") } // namespace tvm
The two most important members of Target are kind and keys:

- kind determines the target codegen:

std::string build_f_name = "target.build." + target->kind->name;

- keys determines which strategies are looked up:

tvm/python/tvm/target/generic_func.py::for k in target.keys:
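A toy sketch of keys-driven dispatch (names are illustrative); TVM's op strategies and schedules are selected through the same generic_func machinery:

import tvm

@tvm.target.generic_func
def pick_impl():
    return "generic"

@pick_impl.register(["arm_cpu"])
def pick_impl_arm():
    return "arm_cpu"

with tvm.target.Target("llvm --device=arm_cpu"):
    print(pick_impl())   # -> "arm_cpu", chosen via Target.current().keys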
1.4.2. LLVMTargetToString
1.4.3. arm_cpu: special handling for Arm
1.4.3.1. vectorize
NEON supports vector operations of at most 16 bytes, so when TVM applies a vectorize schedule it must split with a matching factor.

The schedule_injective for arm_cpu is:
def schedule_injective(outs):
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    x = outs[0]
    if list(s[x].op.axis):
        # Split by 4, since 4 float32 values are exactly 16 bytes.
        # This code has a problem: for int8 data, a split of 16 gives the
        # best performance; see https://github.com/apache/tvm/pull/9339
        (io, ii) = s[x].split(list(s[x].op.axis)[-1], 4)
        s[x].vectorize(ii)
    tvm.te.schedule.AutoInlineInjective(s)
    if not is_empty_shape(x.shape):
        schedule_injective_from_existing(s, x)
    return s
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 1024), dtype="float32")
y = relay.add(x, x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)

with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build(
        mod,
        # The `-mtriple=armv7a-linux-gnueabihf -mattr=+neon` in the target is
        # passed straight to LLVM (LLVMTargetToString). NEON is optional on
        # armv7a, so without +neon LLVM would fall back to VFP. With an armv8l
        # triple, +neon is unnecessary, because NEON is mandatory on armv8l.
        target="llvm --device=arm_cpu -mtriple=armv7a-linux-gnueabihf -mattr=+neon",
        params=None,
    )
lib.save("/tmp/a.o")
arm-linux-gnueabihf-objdump -d /tmp/a.o|tail -n 12
000002a8 <tvmgen_default_fused_add_compute_>:
 2a8: e3a02000  mov   r2, #0
 2ac: e0813002  add   r3, r1, r2
 2b0: f4630aef  vld1.64 {d16-d17}, [r3 :128]
 2b4: e0803002  add   r3, r0, r2
 2b8: f2400de0  vadd.f32 q8, q8, q8
 2bc: e2822010  add   r2, r2, #16
 2c0: e3520a01  cmp   r2, #4096  ; 0x1000
 2c4: f4430aef  vst1.64 {d16-d17}, [r3 :128]
 2c8: 1afffff7  bne   2ac <tvmgen_default_fused_add_compute_+0x4>
 2cc: e12fff1e  bx    lr
1.4.3.2. tensorize
To speed up conv2d on arm_cpu, TVM tensorizes directly with NEON-related LLVM intrinsics (llvm_intrin), for example:
- llvm.aarch64.neon.sdot
- llvm.aarch64.neon.udot
- llvm.aarch64.neon.ummla
- llvm.aarch64.neon.smmla
- llvm.aarch64.neon.umull
- llvm.aarch64.neon.smull
- llvm.aarch64.neon.saddlp
- llvm.aarch64.neon.uaddlp
- llvm.aarch64.neon.addp
- llvm.aarch64.neon.sqrdmulh
- llvm.aarch64.neon.srshl
- llvm.arm.neon.vpadd.v8i8
- llvm.arm.neon.vpadd.v8u8
- llvm.arm.neon.vpadals.v16i8.v8i16
- llvm.arm.neon.vpadalu.v16u8.v8u16
LLVM then emits the corresponding ummla/smmla/udot/sdot instructions.
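As a rough sketch (constants and vector shapes are illustrative), this is how a tensorize intrinsic body constructs the sdot call; TVM's real versions live in topi/arm_cpu/tensor_intrin.py, where the operands are vloads from the intrinsic's buffers rather than broadcasts:

import tvm
from tvm import tir

vec_c = tir.Broadcast(tir.const(0, "int32"), 4)   # int32x4 accumulator
vec_a = tir.Broadcast(tir.const(1, "int8"), 16)   # int8x16 operand
vec_b = tir.Broadcast(tir.const(2, "int8"), 16)   # int8x16 operand
vdot = tir.call_llvm_pure_intrin(
    "int32x4", "llvm.aarch64.neon.sdot.v4i32.v16i8",
    tir.const(3, "uint32"),   # number of operands passed to the intrinsic
    vec_c, vec_a, vec_b)
print(vdot)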