Relay Transform

Table of Contents

1. Relay Transform

tvm/tests/python/relay

1.1. relay::qnn::transform::Legalize

这个 transform 用来把 qnn.xxx 转换成 relay IR (参考 relay.qnn)

1.2. RemoveUnusedFunctions

从 entry_functions 开始, 递归地标记所有被引用到的 GlobalVarNode (函数引用), 然后删除没有被标记到的函数

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import relay

# Demo: RemoveUnusedFunctions keeps only what is reachable from "main".
SHAPE = (1, 1000)

mod = tvm.IRModule()

# fn1 and fn2 are both identity functions; only fn1 is called from main below.
for fn_name, param_name in (("fn1", "x"), ("fn2", "y")):
    param = relay.var(param_name, shape=SHAPE)
    mod[fn_name] = relay.Function([param], param)

z = relay.var("z", shape=SHAPE)
call_fn1 = relay.GlobalVar("fn1")(z)
mod["main"] = relay.Function([z], relay.add(z, call_fn1))

print("----------before----------")
print(mod.get_global_vars())
print(mod)

print("----------after----------")
prune_unused = relay.transform.RemoveUnusedFunctions()
mod = prune_unused(mod)
print(mod)

----------before----------
[GlobalVar(fn1), GlobalVar(fn2), GlobalVar(main)]
def @fn1(%x: Tensor[(1, 1000), float32]) { %x }

def @fn2(%y: Tensor[(1, 1000), float32]) { %y }

def @main(%z: Tensor[(1, 1000), float32]) { %0 = @fn1(%z); add(%z, %0) }

----------after----------
def @fn1(%x: Tensor[(1, 1000), float32]) { %x }

def @main(%z: Tensor[(1, 1000), float32]) { %0 = @fn1(%z); add(%z, %0) }

1.3. SimplifyInference

  1. 把 batchnorm 转换为 multiply/add
  2. 去掉 dropout
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import relay

# Demo: SimplifyInference rewrites batch_norm into multiply/add and drops
# dropout.
mod = tvm.IRModule()

x = relay.var("x", shape=(1, 10))
# Python names follow the batch_norm argument order
# (data, gamma, beta, moving_mean, moving_var). The relay-level names are
# arbitrary; the second "x" is what the printer shows as %x1.
gamma = relay.var("x", shape=(10,))
beta = relay.var("y", shape=(10,))
moving_mean = relay.var("a", shape=(10,))
moving_var = relay.var("b", shape=(10,))
params = [x, gamma, beta, moving_mean, moving_var]

# batch_norm yields three outputs; only the first (normalized data) is used.
normalized, _, _ = relay.nn.batch_norm(*params)

mod["main"] = relay.Function(params, relay.nn.dropout(normalized))

print("----------before----------")
print(mod)

print("----------after----------")
# InferType runs first so that types are available to SimplifyInference.
for opt_pass in (
    relay.transform.InferType(),
    relay.transform.SimplifyInference(),
):
    mod = opt_pass(mod)
print(mod)

-----–—before-----–— def @main(%x: Tensor[(1, 10), float32], %x1: Tensor[(10), float32], %y: Tensor[(10), float32], %a: Tensor[(10), float32], %b: Tensor[(10), float32]) { %0 = nn.batch_norm(%x, %x1, %y, %a, %b); %1 = %0.0; %2 = nn.dropout(%1); %2.0 }

-----–—after-----–— def @main(%x: Tensor[(1, 10), float32], %x1: Tensor[(10), float32], %y: Tensor[(10), float32], %a: Tensor[(10), float32], %b: Tensor[(10), float32]) -> Tensor[(1, 10), float32] { %0 = add(%b, 1e-05f * ty=float32 *) * ty=Tensor[(10), float32] *; %1 = sqrt(%0) * ty=Tensor[(10), float32] *; %2 = divide(1f * ty=float32 *, %1) * ty=Tensor[(10), float32] *; %3 = multiply(%2, %x1) * ty=Tensor[(10), float32] *; %4 = negative(%a) * ty=Tensor[(10), float32] *; %5 = multiply(%4, %3) * ty=Tensor[(10), float32] *; %6 = multiply(%x, %3) * ty=Tensor[(1, 10), float32] *; %7 = add(%5, %y) * ty=Tensor[(10), float32] *; add(%6, %7) * ty=Tensor[(1, 10), float32] * }

// Excerpt: the tuple being indexed is inspected; if it is produced by a call
// to batch_norm or dropout, the whole expression is replaced.
if (const auto* call = new_n->tuple.as<CallNode>()) {
    if (call->op == batch_norm_op_) {
        // batch_norm: args[0..4] are forwarded as data, gamma, beta,
        // moving_mean, moving_var, together with the recorded type of the
        // data argument.
        return BatchNormToInferUnpack(
            call->attrs, call->args[0], call->args[1], call->args[2],
            call->args[3], call->args[4], ty_map_.at(call->args[0]));
    } else if (call->op == dropout_op_) {
        // dropout: replaced by its first argument, i.e. the op is removed.
        return call->args[0];
    }
}

// Rewrites an inference-time batch_norm into `data * scale + shift`.
//
// NOTE: this is an abbreviated excerpt. `param` and the initial value of
// `scale` are not declared here; judging by the printed IR of the example,
// `scale` starts as 1 / sqrt(moving_var + epsilon) -- confirm against the
// upstream source.
Expr BatchNormToInferUnpack(
    const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean,
    Expr moving_var, Type tdata) {
    // Fold gamma into the scale only when the op was configured with scale.
    if (param->scale) {
        scale = Multiply(scale, gamma);
    }
    // shift = -moving_mean * scale (+ beta when centering is enabled).
    Expr neg_mean = Negative(moving_mean);
    Expr shift = Multiply(neg_mean, scale);
    if (param->center) {
        shift = Add(shift, beta);
    }

    // out = data * scale + shift
    Expr out = Multiply(data, scale);
    out = Add(out, shift);
    return out;
}

1.4. Inline

函数可以被 inline 的条件:

  1. 有 `Inline` 属性
  2. 不是递归函数
  3. 调用的其它函数也是可以 inline 的
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import relay

# Demo: the Inline pass only inlines functions that carry the "Inline"
# attribute.
mod = tvm.IRModule()

# g1: identity function marked with Inline=1.
x1 = relay.var("x1", shape=(1, 10))
inline_flag = tvm.tir.IntImm("int32", 1)
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1], x1).with_attr("Inline", inline_flag)

# g2: identical identity function, deliberately left WITHOUT the "Inline"
# attribute so that the call to it survives the pass.
x2 = relay.var("x2", shape=(1, 10))
g2 = relay.GlobalVar("g2")
mod[g2] = relay.Function([x2], x2)

p0 = relay.var("p0", shape=(1, 10))
main_body = relay.add(g1(p0), g2(p0))
mod["main"] = relay.Function([p0], main_body)

print("----------before----------")
print(mod)
inline_pass = relay.transform.Inline()
mod = inline_pass(mod)
print("----------after----------")
print(mod)

----------before----------
def @g1(%x1: Tensor[(1, 10), float32], Inline=1) { %x1 }

def @g2(%x2: Tensor[(1, 10), float32]) { %x2 }

def @main(%p0: Tensor[(1, 10), float32]) { %0 = @g1(%p0); %1 = @g2(%p0); add(%0, %1) }

----------after----------
def @main(%p0: Tensor[(1, 10), float32]) { %0 = @g2(%p0); add(%p0, %0) }

def @g2(%x2: Tensor[(1, 10), float32]) { %x2 }

1.5. RunDeviceAnnotationPass

RunDeviceAnnotationPass 是为了处理 on_device annotation, vta build 时的 graph_pack 依赖 build 时的 RunDeviceAnnotationPass 才能工作

NOTE: 新的代码里相关功能放在了 PlanDevices 里了

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import relay

# Demo: on_device annotations are rewritten into explicit device_copy ops.
ANNOTATED_DEVICE = "cuda"

x, y = (relay.var(name, shape=(1, 10)) for name in ("x", "y"))
total = relay.add(x, y)

# sqrt and exp are annotated; add, log and subtract carry no annotation.
sqrt_on_dev = relay.annotation.on_device(relay.sqrt(total), ANNOTATED_DEVICE)
log_of_total = relay.log(total)
difference = relay.subtract(sqrt_on_dev, log_of_total)
exp_on_dev = relay.annotation.on_device(relay.exp(difference), ANNOTATED_DEVICE)

mod = tvm.IRModule.from_expr(relay.Function([x, y], exp_on_dev))

print("----------before----------")
print(mod)

print("----------after----------")
# NOTE(review): 1 is presumably the fallback device type for unannotated
# ops (kDLCPU) -- confirm against the RewriteAnnotatedOps documentation.
rewrite = relay.transform.RewriteAnnotatedOps(1)
mod = rewrite(mod)
print(mod["main"])

-----–—before-----–— def @main(%x: Tensor[(1, 10), float32], %y: Tensor[(1, 10), float32]) { %0 = add(%x, %y); %1 = sqrt(%0); %2 = on_device(%1, meta[relay.attrs.OnDeviceAttrs][0]); %3 = log(%0); %4 = subtract(%2, %3); %5 = exp(%4); on_device(%5, meta[relay.attrs.OnDeviceAttrs][1]) }

-----–—after-----–— fn (%x: Tensor[(1, 10), float32], %y: Tensor[(1, 10), float32]) -> Tensor[(1, 10), float32] { %0 = add(%x, %y) * ty=Tensor[(1, 10), float32] *; %1 = device_copy(%0, meta[relay.attrs.DeviceCopyAttrs][0]) * ty=Tensor[(1, 10), float32] *; %2 = sqrt(%1) * ty=Tensor[(1, 10), float32] *; %3 = device_copy(%2, meta[relay.attrs.DeviceCopyAttrs][1]) * ty=Tensor[(1, 10), float32] *; %4 = log(%0) * ty=Tensor[(1, 10), float32] *; %5 = subtract(%3, %4) * ty=Tensor[(1, 10), float32] *; %6 = device_copy(%5, meta[relay.attrs.DeviceCopyAttrs][2]) * ty=Tensor[(1, 10), float32] *; exp(%6) * ty=Tensor[(1, 10), float32] * }

1.5.1. device_copy

device_copy 是在 LowerTE 时处理的:

  1. 一方面它被用来确定 expr 所在的 target, 以确定 topi strategy
  2. 另一方面运行时 graph_executor 会根据这个标记做真正的数据拷贝.

tvm/src/relay/backend/te_compiler_cache.cc::// Set the name to `__copy`. It will be detected in graph runtime to perform

例如, graph_executor 在执行时碰到 `__copy` 时会调用设备相关的 CopyDataFromTo 等函数, 对于 opencl 来说就是 clEnqueueCopyBuffer:

// Copies tensor data between devices. Only the device-to-device branch is
// shown; the remaining cases are elided below.
//
// NOTE: abbreviated excerpt -- `nbytes` is not declared in the lines shown.
void OpenCLWorkspace::CopyDataFromTo(
    DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
    // Both ends live on an OpenCL device: copy buffer to buffer.
    if (IsOpenCLDevice(from->device) && IsOpenCLDevice(to->device)) {
        // DLTensor::data is interpreted as an OpenCL buffer descriptor.
        const auto* from_desc =
            static_cast<const cl::BufferDescriptor*>(from->data);
        auto* to_desc = static_cast<cl::BufferDescriptor*>(to->data);
        // The copy is enqueued on the queue of the destination device and
        // honours the byte offsets of both tensors.
        clEnqueueCopyBuffer(
            this->GetQueue(to->device), from_desc->buffer, to_desc->buffer,
            from->byte_offset, to->byte_offset, nbytes, 0, nullptr, nullptr);
    }
    // ...
}

除了上面 __copy, 使用 opencl 等 target 在 set_input/get_output 时也会通过 CopyDataFromTo 与设备交换数据.

1.6. FuseOps

1.6.1. Example

1.6.1.1. 不使用 FuseOps:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-08 18:19
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import numpy as np


def test_fuse_simple():
    """Build add x5 -> exp and compile it without running FuseOps first.

    Prints the relay function and the graph JSON produced by relay.build
    at opt_level=0.
    """

    def before():
        x = relay.var("x", shape=(10, 20))
        acc = relay.add(x, x)
        # Four more doublings, five adds in total.
        for _ in range(4):
            acc = relay.add(acc, acc)
        return relay.Function([x], relay.exp(acc))

    func = before()
    print(func)
    # FuseOps is intentionally NOT applied to func before building.
    with tvm.transform.PassContext(opt_level=0):
        graph, _lib, _params = relay.build(func, target="llvm", params=None)

    print(graph)


if __name__ == "__main__":
    test_fuse_simple()

fn (%x: Tensor[(10, 20), float32]) { %0 = add(%x, %x); %1 = add(%0, %0); %2 = add(%1, %1); %3 = add(%2, %2); %4 = add(%3, %3); exp(%4) } { "nodes": [ { "op": "null", "name": "x", "inputs": [] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add", "attrs": { "num_outputs": "1", "num_inputs": "1", "flatten_data": "0", "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4" }, "inputs": [ [ 0, 0, 0 ] ] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add1", "attrs": { "num_outputs": "1", "num_inputs": "1", "flatten_data": "0", "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4" }, "inputs": [ [ 1, 0, 0 ] ] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add2", "attrs": { "num_outputs": "1", "num_inputs": "1", "flatten_data": "0", "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4" }, "inputs": [ [ 2, 0, 0 ] ] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add3", "attrs": { "num_outputs": "1", "num_inputs": "1", "flatten_data": "0", "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4" }, "inputs": [ [ 3, 0, 0 ] ] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add4", "attrs": { "num_outputs": "1", "num_inputs": "1", "flatten_data": "0", "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4" }, "inputs": [ [ 4, 0, 0 ] ] }, { "op": "tvm_op", "name": "tvmgen_default_fused_exp", "attrs": { "num_outputs": "1", "num_inputs": "1", "flatten_data": "0", "func_name": "tvmgen_default_fused_exp", "hash": "de3b50c71256954a" }, "inputs": [ [ 5, 0, 0 ] ] } ], "arg_nodes": [0], "heads": [ [ 6, 0, 0 ] ], "attrs": { "dltype": [ "list_str", [ "float32", "float32", "float32", "float32", "float32", "float32", "float32" ] ], "device_index": [ "list_int", [1, 1, 1, 1, 1, 1, 1] ], "storage_id": [ "list_int", [0, 1, 2, 1, 2, 1, 2] ], "shape": [ "list_shape", [ [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20] ] ] }, "node_row_ptr": [0, 1, 2, 3, 4, 5, 6, 7] }

可以看到每个 add 操作都会对应 graph 中的一个 node

1.6.1.2. 使用 FuseOps
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-08 18:19
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass

import numpy as np


def test_fuse_simple():
    """Build add x5 -> exp, run FuseOps on it, then compile.

    Prints the fused relay function and the graph JSON produced by
    relay.build at opt_level=0.
    """

    def before():
        x = relay.var("x", shape=(10, 20))
        acc = relay.add(x, x)
        # Four more doublings, five adds in total.
        for _ in range(4):
            acc = relay.add(acc, acc)
        return relay.Function([x], relay.exp(acc))

    fused = run_opt_pass(before(), transform.FuseOps())
    print(fused)
    with tvm.transform.PassContext(opt_level=0):
        graph, _lib, _params = relay.build(fused, target="llvm", params=None)

    print(graph)


if __name__ == "__main__":
    test_fuse_simple()

fn (%x: Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] { %5 = fn (%p0: Tensor[(10, 20), float32], Primitive=1) -> Tensor[(10, 20), float32] { %0 = add(%p0, %p0) * ty=Tensor[(10, 20), float32] *; %1 = add(%0, %0) * ty=Tensor[(10, 20), float32] *; %2 = add(%1, %1) * ty=Tensor[(10, 20), float32] *; %3 = add(%2, %2) * ty=Tensor[(10, 20), float32] *; %4 = add(%3, %3) * ty=Tensor[(10, 20), float32] *; exp(%4) * ty=Tensor[(10, 20), float32] * }; %5(%x) * ty=Tensor[(10, 20), float32] * } { "nodes": [ { "op": "null", "name": "x", "inputs": [] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add_add_add_add_add_exp", "attrs": { "num_outputs": "1", "num_inputs": "1", "flatten_data": "0", "func_name": "tvmgen_default_fused_add_add_add_add_add_exp", "hash": "bcd04c940b541895" }, "inputs": [ [ 0, 0, 0 ] ] } ], "arg_nodes": [0], "heads": [ [ 1, 0, 0 ] ], "attrs": { "dltype": [ "list_str", [ "float32", "float32" ] ], "device_index": [ "list_int", [1, 1] ], "storage_id": [ "list_int", [0, 1] ], "shape": [ "list_shape", [ [10, 20], [10, 20] ] ] }, "node_row_ptr": [0, 1, 2] }

所有的 op 都被合并到同一个函数 tvmgen_default_fused_add_add_add_add_add_exp

1.6.2. 为什么需要 FuseOps

通过 FuseOps 显然可以减少函数调用的开销. 那么, 每次都把所有的 op 都合并到一个函数中不就可以了么?

每个 op 都有一个 schedule, 如果 schedule 不是 `相容` 的, 则它们是无法合并在一起的, 例如:

import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import numpy as np


def test_fuse_simple():
    """Fuse add -> add -> sum -> add and print the generated C source."""

    def before():
        x = relay.var("x", shape=(20, 20))
        doubled = relay.add(x, x)
        quadrupled = relay.add(doubled, doubled)
        # Reduction over the last axis: (20, 20) -> (20,).
        row_sum = relay.sum(quadrupled, axis=-1)
        return relay.Function([x], relay.add(row_sum, row_sum))

    fused = run_opt_pass(before(), transform.FuseOps())
    print(fused)

    # Target "c" so that the lowered kernels can be read as C source.
    with tvm.transform.PassContext(opt_level=0):
        _graph, lib, _params = relay.build(fused, target="c", params=None)

    print(lib.get_source())


if __name__ == "__main__":
    test_fuse_simple()

fn (%x: Tensor[(20, 20), float32]) -> Tensor[(20), float32] { %2 = fn (%p01: Tensor[(20, 20), float32], Primitive=1) -> Tensor[(20), float32] { %0 = add(%p01, %p01) * ty=Tensor[(20, 20), float32] *; %1 = add(%0, %0) * ty=Tensor[(20, 20), float32] *; sum(%1, axis=[-1]) * ty=Tensor[(20), float32] * }; %3 = %2(%x) * ty=Tensor[(20), float32] *; %4 = fn (%p0: Tensor[(20), float32], Primitive=1) -> Tensor[(20), float32] { add(%p0, %p0) * ty=Tensor[(20), float32] * }; %4(%3) * ty=Tensor[(20), float32] * } // tvm target: c -keys=cpu -link-params=0 #define TVM_EXPORTS #include "tvm/runtime/c_runtime_api.h" #include "tvm/runtime/c_backend_api.h" #include <math.h> #ifdef __cplusplus extern "C" #endif TVM_DLL int32_t tvmgen_default_fused_add_add_sum(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) { void* arg0 = (((TVMValue*)args)[0].v_handle); int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)]; void* arg1 = (((TVMValue*)args)[1].v_handle); int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)]; void* placeholder = (((DLTensor*)arg0)[0].data); void* arg0_shape = (((DLTensor*)arg0)[0].shape); void* arg0_strides = (((DLTensor*)arg0)[0].strides); int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id); void* T_add_red = (((DLTensor*)arg1)[0].data); void* arg1_shape = (((DLTensor*)arg1)[0].shape); void* arg1_strides = (((DLTensor*)arg1)[0].strides); if (!(arg0_strides = NULL)) { } if (!(arg1_strides = NULL)) { } for (int32_t ax0 = 0; ax0 < 20; ++ax0) { ((float*)T_add_red)[(ax0)] = 0.000000e+00f; for (int32_t k1 = 0; k1 < 20; ++k1) { ((float*)T_add_red)[(ax0)] = (((float*)T_add_red)[(ax0)] + ((((float*)placeholder)[(((ax0 * 20) + k1))] + ((float*)placeholder)[(((ax0 * 20) + k1))]) + (((float*)placeholder)[(((ax0 * 20) + k1))] + ((float*)placeholder)[(((ax0 * 20) + k1))]))); } } return 0; }

#ifdef __cplusplus extern "C" #endif TVM_DLL int32_t tvmgen_default_fused_add(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) { void* arg0 = (((TVMValue*)args)[0].v_handle); int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)]; void* arg1 = (((TVMValue*)args)[1].v_handle); int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)]; void* placeholder = (((DLTensor*)arg0)[0].data); void* arg0_shape = (((DLTensor*)arg0)[0].shape); void* arg0_strides = (((DLTensor*)arg0)[0].strides); int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id); void* T_add = (((DLTensor*)arg1)[0].data); void* arg1_shape = (((DLTensor*)arg1)[0].shape); void* arg1_strides = (((DLTensor*)arg1)[0].strides); if (!(arg0_strides = NULL)) { } if (!(arg1_strides = NULL)) { } for (int32_t ax0_outer = 0; ax0_outer < 2; ++ax0_outer) { for (int32_t ax0_inner_s = 0; ax0_inner_s < 16; ++ax0_inner_s) { if (((ax0_outer * 16) + ax0_inner_s) < 20) { ((float*)T_add)[(((ax0_outer * 16) + ax0_inner_s))] = (((float*)placeholder)[(((ax0_outer * 16) + ax0_inner_s))] + ((float*)placeholder)[(((ax0_outer * 16) + ax0_inner_s))]); } } } return 0; }

y1, y2, y3 三个操作使用的 schedule 都是要处理一个 20x20 的循环, 但 y 的 schedule 处理的是 20x1 的循环, 所以它们无法合并成一个 op.

1.7. FoldScaleAxis

FoldScaleAxis 可以把 Conv2D 和 BatchNorm 融合在一起.

算子融合的好处:

  1. 减少算子间的中间内存分配
  2. 两个算子的 loop 可以做成同一个 loop

FoldScaleAxis 除了上面的好处, 还有一个原因是许多加速器并不支持 scalar multiplication, 例如 BatchNorm

TVM 的 FoldScaleAxis 实际上包含三个 transform:

// FoldScaleAxis is a sequential pass made of three sub-passes, run in this
// order: backward fold, forward fold, then constant folding.
Pass FoldScaleAxis() {
    return Sequential(
        {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
        "FoldScaleAxis");
}

其中:

  • BackwardFoldScaleAxis 是把某个 Op (例如 Conv2D) 后面的 scale 与 Op fold 在一起
  • ForwardFoldScaleAxis 是把 Op 前面的 scale 与 Op fold
  • ForwardFoldScaleAxis 或 BackwardFoldScaleAxis 只是负责把 x*scale 或 y*scale 变成 w*scale, 最终还需要 FoldConstant 把 w*scale fold 成一个 constant (前提是 w, scale 都是常量)

1.7.1. Example

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-10-14 22:36
import numpy as np

import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform

I = N = 1
O = C = 2
H = W = 3


def run_opt_pass(expr, opt_pass):
    """Apply ``opt_pass`` to ``expr`` wrapped in a fresh IRModule.

    Returns the transformed ``main`` function when ``expr`` is a
    ``relay.Function``; otherwise returns only the body of ``main``.
    """
    assert isinstance(opt_pass, tvm.transform.Pass)
    transformed = opt_pass(tvm.IRModule.from_expr(expr))
    main_fn = transformed["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body


def test_fold():
    """Show BackwardFoldScaleAxis followed by FoldConstant on conv2d * scale.

    Prints the function after InferType, after BackwardFoldScaleAxis, and
    after FoldConstant.
    """

    def get_model(x, weight, scale):
        # Disabled variant: scale the input before the convolution instead.
        # x = relay.multiply(x, scale)
        conv = relay.nn.conv2d(
            x,
            weight,
            channels=C,
            kernel_size=(3, 3),
            padding=(1, 1),
        )
        # The scale is applied to the conv2d OUTPUT.
        return relay.Function([x], relay.multiply(conv, scale))

    def check(shape):
        x = relay.var("x", shape=shape)

        # Weight and scale are constants, drawn in this order.
        weight = relay.const(np.random.randn(O, I, H, W).astype("float32"))
        scale = relay.const(np.random.randn(C, 1, 1).astype("float32"))

        typed = run_opt_pass(get_model(x, weight, scale), transform.InferType())
        print(typed)

        folded = run_opt_pass(typed, transform.BackwardFoldScaleAxis())
        print(folded)

        folded = run_opt_pass(folded, transform.FoldConstant())
        print(folded)

    check((1, N, 10, 10))


if __name__ == "__main__":
    test_fold()

fn (%x: Tensor[(1, 1, 10, 10), float32]) -> Tensor[(1, 2, 10, 10), float32] { %0 = nn.conv2d(%x, meta[relay.Constant][0] * ty=Tensor[(2, 1, 3, 3), float32] *, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) * ty=Tensor[(1, 2, 10, 10), float32] *; multiply(%0, meta[relay.Constant][1] * ty=Tensor[(2, 1, 1), float32] *) * ty=Tensor[(1, 2, 10, 10), float32] * }

fn (%x: Tensor[(1, 1, 10, 10), float32]) -> Tensor[(1, 2, 10, 10), float32] { %0 = squeeze(meta[relay.Constant][1] * ty=Tensor[(2, 1, 1), float32] *, axis=[1, 2]) * ty=Tensor[(2), float32] *; %1 = expand_dims(%0, axis=1, num_newaxis=3) * ty=Tensor[(2, 1, 1, 1), float32] *; %2 = multiply(meta[relay.Constant][0] * ty=Tensor[(2, 1, 3, 3), float32] *, %1) * ty=Tensor[(2, 1, 3, 3), float32] *; nn.conv2d(%x, %2, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) * ty=Tensor[(1, 2, 10, 10), float32] * }

fn (%x: Tensor[(1, 1, 10, 10), float32]) -> Tensor[(1, 2, 10, 10), float32] { nn.conv2d(%x, meta[relay.Constant][0] * ty=Tensor[(2, 1, 3, 3), float32] *, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) * ty=Tensor[(1, 2, 10, 10), float32] * }

1.7.2. Impl

// multiply is visited before conv2d; this is where the scale is recorded.
//
// Backward transform for multiply: if one operand carries a message (the
// axes on which it can absorb a scale) and the other operand broadcasts onto
// those axes, the multiply is dropped and the other operand is passed down
// as the scale. Otherwise the call is transformed as a normal call.
Expr MultiplyBackwardTransform(
    const Call& call, const Message& message, const Expr& scale,
    const BackwardTransformer& transformer) {
    // The multiply itself must not have received a message.
    ICHECK(!message.defined()) << "outstanding scale";
    const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
    const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
    Message lhs_message = transformer->GetMessage(call->args[0]);
    Message rhs_message = transformer->GetMessage(call->args[1]);
    if (lhs_message.defined()) {
        Expr rhs = call->args[1];
        // rhs must broadcast onto the lhs axes, and must be an all-positive
        // constant when the lhs message requires a positive scale.
        if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) &&
            (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) {
            // e.g. mul(conv2d, scale): call->args[0] is the conv2d, rhs is
            // the scale.
            return transformer->Transform(call->args[0], lhs_message, rhs);
        }
    } else if (rhs_message.defined()) {
        Expr lhs = call->args[0];
        // Mirror image of the branch above with the operands swapped.
        if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) &&
            (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) {
            // mul(scale, conv2d): call->args[1] is the conv2d, lhs is the
            // scale.
            return transformer->Transform(call->args[1], rhs_message, lhs);
        }
    }
    // No operand can absorb the scale: keep the multiply.
    return transformer->NormalCallTransform(call.operator->());
}

// Backward transform for conv2d: multiplies the incoming `scale` into the
// weight and rebuilds the call with the original op, attrs and type args.
//
// NOTE: abbreviated excerpt -- `kernel_layout` and `big_ko_axis` come from
// the elided code (presumably the kernel layout and its output-channel
// axis; confirm against the upstream source).
Expr Conv2DBackwardTransform(
    const Call& call, const Message& message, const Expr& scale,
    const BackwardTransformer& transformer) {
    // ...
    // Both operands are transformed with a null message and a null scale.
    Expr data = transformer->Transform(
        call->args[0], NullValue<Message>(), NullValue<Expr>());
    Expr weight = transformer->Transform(
        call->args[1], NullValue<Message>(), NullValue<Expr>());
    // scale on input for depthwise.
    // Expand the scale to the kernel's rank along big_ko_axis so that it
    // can be multiplied into the weight.
    Expr wscale =
        ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
    weight = Multiply(weight, wscale);
    return Call(call->op, {data, weight}, call->attrs, call->type_args);
}

1.8. CombineParallelDense

FuseOps, FoldScaleAxis 等 transform 属于纵向的融合, 而 CombineXXX 属于横向的融合

1.8.1. Example

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-10-14 22:36
import numpy as np

import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform


def run_opt_pass(expr, opt_pass):
    """Run InferType and then ``opt_pass`` on ``expr``; return ``main``."""
    assert isinstance(opt_pass, tvm.transform.Pass)
    typed_mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr))
    return opt_pass(typed_mod)["main"]


def test_combine_parallel_dense():
    """Two dense ops on the same input with same-shaped weights.

    Prints the function before and after
    CombineParallelDense(min_num_branches=2).
    """

    def before(x, w1, w2):
        # Two parallel branches that consume the same input x.
        branches = (relay.nn.dense(x, w1), relay.nn.dense(x, w2))
        return relay.Function([x, w1, w2], relay.Tuple(branches))

    def check(i, j, k):
        x = relay.var("x", shape=(i, k))
        w1 = relay.var("w1", shape=(j, k))
        w2 = relay.var("w2", shape=(j, k))

        original = before(x, w1, w2)
        print(original)
        combine = transform.CombineParallelDense(min_num_branches=2)
        print(run_opt_pass(original, combine))

    # A 3x4 input multiplied with two 4x5 matrices (weights are declared as
    # 5x4), giving two 3x5 results.
    check(3, 5, 4)


if __name__ == "__main__":
    test_combine_parallel_dense()

fn (%x: Tensor[(3, 4), float32], %w1: Tensor[(5, 4), float32], %w2: Tensor[(5, 4), float32]) { %0 = nn.dense(%x, %w1, units=None); %1 = nn.dense(%x, %w2, units=None); (%0, %1) } fn (%x: Tensor[(3, 4), float32], %w1: Tensor[(5, 4), float32], %w2: Tensor[(5, 4), float32]) -> (Tensor[(3, 5), float32], Tensor[(3, 5), float32]) { %0 = (%x, %x); %1 = (%w1, %w2); %2 = stack(%0) * ty=Tensor[(2, 3, 4), float32] *; %3 = stack(%1) * ty=Tensor[(2, 5, 4), float32] *; %4 = nn.batch_matmul(%2, %3, transpose_b=True) * ty=Tensor[(2, 3, 5), float32] *; %5 = split(%4, indices_or_sections=2) * ty=(Tensor[(1, 3, 5), float32], Tensor[(1, 3, 5), float32]) *; %6 = %5.0; %7 = %5.1; %8 = squeeze(%6, axis=[0]) * ty=Tensor[(3, 5), float32] *; %9 = squeeze(%7, axis=[0]) * ty=Tensor[(3, 5), float32] *; (%8, %9) }

1.9. CombineParallelConv2D

1.9.1. Example

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-10-22 16:46

import tvm
from tvm import relay
from tvm.relay import transform


def test_combine_parallel_conv2d():
    """Two conv2d ops and a max_pool2d share one input.

    Prints the type-inferred module before and after
    CombineParallelConv2D(min_num_branches=2).
    """

    def before(x, w1, w2):
        conv_a = relay.nn.conv2d(x, w1)
        conv_b = relay.nn.conv2d(x, w2)
        pooled = relay.nn.max_pool2d(x)
        func = relay.Function([x, w1, w2], relay.Tuple((conv_a, conv_b, pooled)))
        # Return a module that has already been through InferType.
        return tvm.relay.transform.InferType()(tvm.IRModule.from_expr(func))

    def check(x_shape, channels1, channels2):
        x = relay.var("x", shape=x_shape)
        # Index 1 of x_shape is used as the input-channel count of the
        # 1x1 kernels.
        in_c = x_shape[1]
        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))

        mod = before(x, w1, w2)
        print("------before------")
        print(mod)
        combine = transform.CombineParallelConv2D(min_num_branches=2)
        mod = combine(mod)
        print("------after------")
        print(mod)

    check((1, 4, 16, 16), 4, 4)


if __name__ == "__main__":
    test_combine_parallel_conv2d()

-–—before-–— def @main(%x: Tensor[(1, 4, 16, 16), float32], %w1: Tensor[(4, 4, 1, 1), float32], %w2: Tensor[(4, 4, 1, 1), float32]) -> (Tensor[(1, 4, 16, 16), float32], Tensor[(1, 4, 16, 16), float32], Tensor[(1, 4, 16, 16), float32]) { %0 = nn.conv2d(%x, %w1, padding=[0, 0, 0, 0]) * ty=Tensor[(1, 4, 16, 16), float32] *; %1 = nn.conv2d(%x, %w2, padding=[0, 0, 0, 0]) * ty=Tensor[(1, 4, 16, 16), float32] *; %2 = nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) * ty=Tensor[(1, 4, 16, 16), float32] *; (%0, %1, %2) }

-–—after-–— def @main(%x: Tensor[(1, 4, 16, 16), float32], %w1: Tensor[(4, 4, 1, 1), float32], %w2: Tensor[(4, 4, 1, 1), float32]) -> (Tensor[(1, 4, 16, 16), float32], Tensor[(1, 4, 16, 16), float32], Tensor[(1, 4, 16, 16), float32]) { %0 = (%w1, %w2); %1 = concatenate(%0) * ty=Tensor[(8, 4, 1, 1), float32] *; %2 = nn.conv2d(%x, %1, padding=[0, 0, 0, 0], channels=8) * ty=Tensor[(1, 8, 16, 16), float32] *; %3 = strided_slice(%2, begin=[0, 0], end=[-1, 4], strides=[1, 1], slice_mode="size", axes=None) * ty=Tensor[(1, 4, 16, 16), float32] *; %4 = strided_slice(%2, begin=[0, 4], end=[-1, 4], strides=[1, 1], slice_mode="size", axes=None) * ty=Tensor[(1, 4, 16, 16), float32] *; %5 = nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) * ty=Tensor[(1, 4, 16, 16), float32] *; (%3, %4, %5) }

Author: [email protected]
Date: 2021-10-25 Mon 00:00
Last updated: 2023-12-01 Fri 18:28

知识共享许可协议