TVM User Pass

Table of Contents

1. TVM User Pass

The `Broadcast` transform implemented below works around the fact that dnnl cannot handle broadcasting: for selected ops (e.g. add), it wraps the lhs and rhs in `broadcast_to` as needed, so that both operands end up with the same shape.
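
The pass does not reimplement broadcasting rules itself; it asks numpy for the result shape via `np.broadcast`. A quick illustration of that helper, using the same shapes as the example at the end:

import numpy as np

# np.empty only allocates; np.broadcast combines the two shapes
# without touching any data
out_shape = np.broadcast(np.empty((1, 10)), np.empty((1,))).shape
print(out_shape)  # (1, 10)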

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-24 10:32
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.topi.utils import get_const_tuple


@relay.transform.function_pass(opt_level=1)
class Broadcast:
    def __init__(self, *args):
        # names of the binary ops (e.g. "add") whose operands should be
        # broadcast to a common shape
        self.supported_ops = args

    def transform_function(self, func, mod, ctx):
        obj = self  # capture the pass instance for the nested mutator

        class BroadcastTo(tvm.relay.ExprMutator):
            def infer_type(self, node):
                # wrap the expression in a throwaway module and run
                # InferType so that checked_type is populated
                mod = tvm.IRModule.from_expr(node)
                mod = relay.transform.InferType()(mod)
                entry = mod["main"]
                return entry if isinstance(node, relay.Function) else entry.body

            def visit_call(self, call):
                # leave calls to non-operator expressions and to ops we
                # were not asked to handle untouched
                if not isinstance(call.op, tvm.ir.Op) or call.op.name not in obj.supported_ops:
                    return super().visit_call(call)

                if len(call.args) != 2:
                    raise TypeError(
                        f"only 2 args are supported, {call.op.name} has {len(call.args)} args"
                    )
                lhs = self.visit(call.args[0])
                rhs = self.visit(call.args[1])
                lhs_shape = get_const_tuple(self.infer_type(lhs).checked_type.shape)
                rhs_shape = get_const_tuple(self.infer_type(rhs).checked_type.shape)

                # let numpy compute the broadcast result shape; np.empty
                # only allocates, so no real data is involved
                out_shape = np.broadcast(np.empty(lhs_shape), np.empty(rhs_shape)).shape

                # insert an explicit broadcast_to on whichever side needs it
                if out_shape != lhs_shape:
                    lhs = relay.op.broadcast_to(lhs, out_shape)
                if out_shape != rhs_shape:
                    rhs = relay.op.broadcast_to(rhs, out_shape)

                return relay.expr.Call(
                    call.op, (lhs, rhs), call.attrs, call.type_args, call.span
                )

        return BroadcastTo().visit(func)


shape_a, shape_b, shape_c = (1, 10), (1,), (1, 10)
a = relay.var("a", shape=shape_a, dtype="float32")
b = relay.var("b", shape=shape_b, dtype="float32")
c = relay.var("c", shape=shape_c, dtype="float32")
out = relay.add(a, b)
out = relay.add(out, c)

func = relay.Function([a, b, c], out)
mod = tvm.IRModule.from_expr(func)

print("----------before----------")
print(mod)

print("----------after----------")
# tvm's dnnl add kernel does not broadcast automatically, so without the
# following line the build would fail
mod = Broadcast("add")(mod)
mod = relay.transform.AnnotateTarget("dnnl")(mod)
mod = relay.transform.PartitionGraph()(mod)
print(mod)

with tvm.transform.PassContext(opt_level=0):
    lib = relay.build(mod, target="llvm", params=None)

rt_mod = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
x, y, z = np.ones(shape_a), np.ones(shape_b), np.ones(shape_c)
rt_mod.set_input("a", x)
rt_mod.set_input("b", y)
rt_mod.set_input("c", z)
rt_mod.run()
result = rt_mod.get_output(0).numpy()
print(result)
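
Since all three inputs are ones, every element of the result should be 3.0. A quick sanity check against plain numpy (a small addition, reusing the `x`, `y`, `z` arrays from the script above):

np.testing.assert_allclose(result, x + y + z)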

----------before----------
def @main(%a: Tensor[(1, 10), float32], %b: Tensor[(1), float32], %c: Tensor[(1, 10), float32]) {
  %0 = add(%a, %b);
  add(%0, %c)
}

----------after----------
def @main(%a: Tensor[(1, 10), float32], %b: Tensor[(1), float32], %c: Tensor[(1, 10), float32]) -> Tensor[(1, 10), float32] {
  %0 = broadcast_to(%b, shape=[1, 10], dtype="") /* ty=Tensor[(1, 10), float32] */;
  %1 = @tvmgen_default_dnnl_main_0(%a, %0) /* ty=Tensor[(1, 10), float32] */;
  @tvmgen_default_dnnl_main_2(%1, %c) /* ty=Tensor[(1, 10), float32] */
}

def @tvmgen_default_dnnl_main_0(%dnnl_0_i0: Tensor[(1, 10), float32], %dnnl_0_i1: Tensor[(1, 10), float32], Inline=1, Compiler="dnnl", global_symbol="tvmgen_default_dnnl_main_0", Primitive=1) -> Tensor[(1, 10), float32] {
  add(%dnnl_0_i0, %dnnl_0_i1) /* ty=Tensor[(1, 10), float32] */
}

def @tvmgen_default_dnnl_main_2(%dnnl_2_i0: Tensor[(1, 10), float32], %dnnl_2_i1: Tensor[(1, 10), float32], Inline=1, Compiler="dnnl", global_symbol="tvmgen_default_dnnl_main_2", Primitive=1) -> Tensor[(1, 10), float32] {
  add(%dnnl_2_i0, %dnnl_2_i1) /* ty=Tensor[(1, 10), float32] */
}
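
Note that `broadcast_to` itself stays in `@main`: dnnl has no annotation rule for it, so only the two now shape-compatible `add` calls are offloaded. As a final sketch, the three passes can also be chained with `tvm.transform.Sequential` instead of being applied one by one; since `Broadcast` is registered with `opt_level=1`, the surrounding `PassContext` must allow at least that level for it to run:

seq = tvm.transform.Sequential(
    [
        Broadcast("add"),
        relay.transform.AnnotateTarget("dnnl"),
        relay.transform.PartitionGraph(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)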

Author: [email protected]
Date: 2021-09-26 Sun 00:00
Last updated: 2022-01-24 Mon 19:34
