TVM User Pass
Table of Contents
1. TVM User Pass
The `Broadcast` transform implemented below works around the fact that the dnnl backend cannot handle broadcasting: for selected operators (e.g. add), it inserts a broadcast_to on the lhs and/or rhs as needed so that both operands reach the operator with the same shape.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-24 10:32
import numpy as np

import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.topi.utils import get_const_tuple


@relay.transform.function_pass(opt_level=1)
class Broadcast:
    def __init__(self, *args):
        self.supported_ops = args

    def transform_function(self, func, mod, ctx):
        obj = self

        class BroadcastTo(tvm.relay.ExprMutator):
            def infer_type(self, node):
                mod = tvm.IRModule.from_expr(node)
                mod = relay.transform.InferType()(mod)
                entry = mod["main"]
                return entry if isinstance(node, relay.Function) else entry.body

            def visit_call(self, call):
                if call.op.name not in obj.supported_ops:
                    return super().visit_call(call)
                if len(call.args) != 2:
                    raise TypeError(
                        f"only 2 args are supported, {call.op.name} has {len(call.args)} args"
                    )
                lhs = self.visit(call.args[0])
                rhs = self.visit(call.args[1])
                lhs_shape = get_const_tuple(self.infer_type(lhs).checked_type.shape)
                rhs_shape = get_const_tuple(self.infer_type(rhs).checked_type.shape)
                dtype = self.infer_type(lhs).checked_type.dtype
                # let numpy work out the broadcast output shape of the two operands
                out_shape = np.broadcast(np.empty(lhs_shape), np.empty(rhs_shape)).shape
                if out_shape != lhs_shape:
                    lhs = relay.op.broadcast_to(lhs, out_shape)
                if out_shape != rhs_shape:
                    rhs = relay.op.broadcast_to(rhs, out_shape)
                return relay.expr.Call(
                    call.op, (lhs, rhs), call.attrs, call.type_args, call.span
                )

        return BroadcastTo().visit(func)


shape_a, shape_b, shape_c = (1, 10), (1,), (1, 10)
a = relay.var("a", shape=shape_a, dtype="float32")
b = relay.var("b", shape=shape_b, dtype="float32")
c = relay.var("c", shape=shape_c, dtype="float32")
out = relay.add(a, b)
out = relay.add(out, c)
func = relay.Function([a, b, c], out)
mod = tvm.IRModule.from_expr(func)

print("----------before----------")
print(mod)

print("----------after----------")
# TVM's dnnl add operator does not support automatic broadcast,
# so without the following line the build would fail
mod = Broadcast("add")(mod)
mod = relay.transform.AnnotateTarget("dnnl")(mod)
mod = relay.transform.PartitionGraph()(mod)
print(mod)

with tvm.transform.PassContext(opt_level=0):
    lib = relay.build(mod, target="llvm", params=None)

rt_mod = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
x, y, z = np.ones(shape_a), np.ones(shape_b), np.ones(shape_c)
rt_mod.set_input("a", x)
rt_mod.set_input("b", y)
rt_mod.set_input("c", z)
rt_mod.run()
result = rt_mod.get_output(0).numpy()
print(result)
----------before----------
def @main(%a: Tensor[(1, 10), float32], %b: Tensor[(1), float32], %c: Tensor[(1, 10), float32]) {
  %0 = add(%a, %b);
  add(%0, %c)
}

----------after----------
def @main(%a: Tensor[(1, 10), float32], %b: Tensor[(1), float32], %c: Tensor[(1, 10), float32]) -> Tensor[(1, 10), float32] {
  %0 = broadcast_to(%b, shape=[1, 10], dtype="") /* ty=Tensor[(1, 10), float32] */;
  %1 = @tvmgen_default_dnnl_main_0(%a, %0) /* ty=Tensor[(1, 10), float32] */;
  @tvmgen_default_dnnl_main_2(%1, %c) /* ty=Tensor[(1, 10), float32] */
}

def @tvmgen_default_dnnl_main_0(%dnnl_0_i0: Tensor[(1, 10), float32], %dnnl_0_i1: Tensor[(1, 10), float32], Inline=1, Compiler="dnnl", global_symbol="tvmgen_default_dnnl_main_0", Primitive=1) -> Tensor[(1, 10), float32] {
  add(%dnnl_0_i0, %dnnl_0_i1) /* ty=Tensor[(1, 10), float32] */
}

def @tvmgen_default_dnnl_main_2(%dnnl_2_i0: Tensor[(1, 10), float32], %dnnl_2_i1: Tensor[(1, 10), float32], Inline=1, Compiler="dnnl", global_symbol="tvmgen_default_dnnl_main_2", Primitive=1) -> Tensor[(1, 10), float32] {
  add(%dnnl_2_i0, %dnnl_2_i1) /* ty=Tensor[(1, 10), float32] */
}
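
For the all-ones inputs above, the final `print(result)` should show a (1, 10) tensor filled with 3.0, the same value NumPy would compute for `x + y + z` with its own broadcasting.

Since `Broadcast` is an ordinary pass object produced by the `@relay.transform.function_pass` decorator, it can also be chained with the built-in passes through `tvm.transform.Sequential` instead of applying each pass by hand. A minimal sketch, reusing the `Broadcast` class and the un-partitioned `mod` from the script above:

seq = tvm.transform.Sequential(
    [
        Broadcast("add"),                        # insert broadcast_to where add's operands need it
        relay.transform.AnnotateTarget("dnnl"),  # mark ops the dnnl codegen can handle
        relay.transform.PartitionGraph(),        # split annotated regions into dnnl functions
    ]
)
# opt_level must be >= 1 here, otherwise Sequential skips Broadcast (registered with opt_level=1)
with tvm.transform.PassContext(opt_level=1):
    mod = seq(mod)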