Relay Transform
Table of Contents
1. Relay Transform
tvm/tests/python/relay
1.1. relay::qnn::transform::Legalize
This transform lowers qnn.xxx ops (the relay.qnn dialect) into regular Relay IR.
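A minimal sketch of the idea (my own example, not from the TVM test suite; on the Python side the lowering is split across relay.qnn.transform.Legalize and relay.qnn.transform.CanonicalizeOps, and the exact ops produced depend on the TVM version):
#!/usr/bin/env python3
import tvm
from tvm import relay

# Build a tiny graph that quantizes to int8 and dequantizes back to float32.
x = relay.var("x", shape=(1, 8), dtype="float32")
q = relay.qnn.op.quantize(x, relay.const(0.5, "float32"), relay.const(0, "int32"), out_dtype="int8")
dq = relay.qnn.op.dequantize(q, relay.const(0.5, "float32"), relay.const(0, "int32"))
mod = tvm.IRModule.from_expr(relay.Function([x], dq))
print(mod)  # still contains qnn.quantize / qnn.dequantize

mod = relay.transform.InferType()(mod)
mod = relay.qnn.transform.Legalize()(mod)         # target-specific legalization, if any
mod = relay.qnn.transform.CanonicalizeOps()(mod)  # lower remaining qnn ops to plain relay ops
print(mod)  # qnn ops replaced by cast/round/clip/multiply/... sequences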
1.2. RemoveUnusedFunctions
Starting from entry_functions, recursively mark all GlobalVarNodes (function references), then delete any functions that were never marked.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import relay

mod = tvm.IRModule()

x = relay.var("x", shape=(1, 1000))
mod["fn1"] = relay.Function([x], x)

y = relay.var("y", shape=(1, 1000))
mod["fn2"] = relay.Function([y], y)

z = relay.var("z", shape=(1, 1000))
mod["main"] = relay.Function([z], relay.add(z, relay.GlobalVar("fn1")(z)))

print("----------before----------")
print(mod.get_global_vars())
print(mod)

print("----------after----------")
mod = relay.transform.RemoveUnusedFunctions()(mod)
print(mod)
----------before----------
[GlobalVar(fn1), GlobalVar(fn2), GlobalVar(main)]
def @fn1(%x: Tensor[(1, 1000), float32]) { %x }
def @fn2(%y: Tensor[(1, 1000), float32]) { %y }
def @main(%z: Tensor[(1, 1000), float32]) { %0 = @fn1(%z); add(%z, %0) }
----------after----------
def @fn1(%x: Tensor[(1, 1000), float32]) { %x }
def @main(%z: Tensor[(1, 1000), float32]) { %0 = @fn1(%z); add(%z, %0) }
1.3. SimplifyInference
- Convert batch_norm into multiply/add
- Remove dropout
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import relay

mod = tvm.IRModule()

x = relay.var("x", shape=(1, 10))
alpha, gamma, mean, var = (
    relay.var("x", shape=(10,)),
    relay.var("y", shape=(10,)),
    relay.var("a", shape=(10,)),
    relay.var("b", shape=(10,)),
)
# fmt:off
bn, _, _ = relay.nn.batch_norm(x, alpha, gamma, mean, var)
# fmt:on
mod["main"] = relay.Function(
    [x, alpha, gamma, mean, var],
    relay.nn.dropout(bn),
)

print("----------before----------")
print(mod)
print("----------after----------")
mod = relay.transform.InferType()(mod)
mod = relay.transform.SimplifyInference()(mod)
print(mod)
----------before----------
def @main(%x: Tensor[(1, 10), float32], %x1: Tensor[(10), float32], %y: Tensor[(10), float32], %a: Tensor[(10), float32], %b: Tensor[(10), float32]) {
  %0 = nn.batch_norm(%x, %x1, %y, %a, %b);
  %1 = %0.0;
  %2 = nn.dropout(%1);
  %2.0
}
----------after----------
def @main(%x: Tensor[(1, 10), float32], %x1: Tensor[(10), float32], %y: Tensor[(10), float32], %a: Tensor[(10), float32], %b: Tensor[(10), float32]) -> Tensor[(1, 10), float32] {
  %0 = add(%b, 1e-05f /* ty=float32 */) /* ty=Tensor[(10), float32] */;
  %1 = sqrt(%0) /* ty=Tensor[(10), float32] */;
  %2 = divide(1f /* ty=float32 */, %1) /* ty=Tensor[(10), float32] */;
  %3 = multiply(%2, %x1) /* ty=Tensor[(10), float32] */;
  %4 = negative(%a) /* ty=Tensor[(10), float32] */;
  %5 = multiply(%4, %3) /* ty=Tensor[(10), float32] */;
  %6 = multiply(%x, %3) /* ty=Tensor[(1, 10), float32] */;
  %7 = add(%5, %y) /* ty=Tensor[(10), float32] */;
  add(%6, %7) /* ty=Tensor[(1, 10), float32] */
}
if (const auto* call = new_n->tuple.as<CallNode>()) {
  if (call->op == batch_norm_op_) {
    return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                  call->args[3], call->args[4], ty_map_.at(call->args[0]));
  } else if (call->op == dropout_op_) {
    return call->args[0];
  }
}

Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean,
                            Expr moving_var, Type tdata) {
  // ... scale = 1 / sqrt(moving_var + epsilon) is computed above (omitted here)
  if (param->scale) {
    scale = Multiply(scale, gamma);
  }
  Expr neg_mean = Negative(moving_mean);
  Expr shift = Multiply(neg_mean, scale);
  if (param->center) {
    shift = Add(shift, beta);
  }
  Expr out = Multiply(data, scale);
  out = Add(out, shift);
  return out;
}
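The rewrite is just algebra: (x - mean) / sqrt(var + eps) * gamma + beta == x * scale + shift, with scale = gamma / sqrt(var + eps) and shift = beta - mean * scale. A quick numpy check (illustrative only, the variable names are mine):
import numpy as np

x = np.random.randn(1, 10).astype("float32")
gamma, beta = np.random.randn(10), np.random.randn(10)
mean, var = np.random.randn(10), np.abs(np.random.randn(10))
eps = 1e-5

ref = (x - mean) / np.sqrt(var + eps) * gamma + beta  # batch_norm at inference time
scale = gamma / np.sqrt(var + eps)                    # what BatchNormToInferUnpack folds into multiply
shift = beta - mean * scale                           # ... and into add
assert np.allclose(ref, x * scale + shift, atol=1e-5)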
1.4. Inline
Conditions for a function to be inlined:
- it has the `Inline` attribute
- it is not recursive
- the other functions it calls can also be inlined
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import relay

mod = tvm.IRModule()

x1 = relay.var("x1", shape=(1, 10))
fn1 = relay.Function([x1], x1)
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
g1 = relay.GlobalVar("g1")
mod[g1] = fn1

x2 = relay.var("x2", shape=(1, 10))
fn2 = relay.Function([x2], x2)
# fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
g2 = relay.GlobalVar("g2")
mod[g2] = fn2

p0 = relay.var("p0", shape=(1, 10))
mod["main"] = relay.Function([p0], relay.add(g1(p0), g2(p0)))

print("----------before----------")
print(mod)
mod = relay.transform.Inline()(mod)
print("----------after----------")
print(mod)
----------before----------
def @g1(%x1: Tensor[(1, 10), float32], Inline=1) { %x1 }
def @g2(%x2: Tensor[(1, 10), float32]) { %x2 }
def @main(%p0: Tensor[(1, 10), float32]) { %0 = @g1(%p0); %1 = @g2(%p0); add(%0, %1) }
----------after----------
def @main(%p0: Tensor[(1, 10), float32]) { %0 = @g2(%p0); add(%p0, %0) }
def @g2(%x2: Tensor[(1, 10), float32]) { %x2 }
1.5. RunDeviceAnnotationPass
RunDeviceAnnotationPass handles on_device annotations; graph_pack in the VTA build flow relies on RunDeviceAnnotationPass running during build in order to work.
NOTE: in newer versions of the code this functionality has been moved into PlanDevices.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-03 11:11
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
add = relay.add(x, y)
sqrt = relay.sqrt(add)
_sqrt = relay.annotation.on_device(sqrt, "cuda")
log = relay.log(add)
subtract = relay.subtract(_sqrt, log)
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, "cuda")

func = relay.Function([x, y], _exp)
mod = tvm.IRModule.from_expr(func)

print("----------before----------")
print(mod)
print("----------after----------")
mod = relay.transform.RewriteAnnotatedOps(1)(mod)
print(mod["main"])
----------before----------
def @main(%x: Tensor[(1, 10), float32], %y: Tensor[(1, 10), float32]) {
  %0 = add(%x, %y);
  %1 = sqrt(%0);
  %2 = on_device(%1, meta[relay.attrs.OnDeviceAttrs][0]);
  %3 = log(%0);
  %4 = subtract(%2, %3);
  %5 = exp(%4);
  on_device(%5, meta[relay.attrs.OnDeviceAttrs][1])
}
----------after----------
fn (%x: Tensor[(1, 10), float32], %y: Tensor[(1, 10), float32]) -> Tensor[(1, 10), float32] {
  %0 = add(%x, %y) /* ty=Tensor[(1, 10), float32] */;
  %1 = device_copy(%0, meta[relay.attrs.DeviceCopyAttrs][0]) /* ty=Tensor[(1, 10), float32] */;
  %2 = sqrt(%1) /* ty=Tensor[(1, 10), float32] */;
  %3 = device_copy(%2, meta[relay.attrs.DeviceCopyAttrs][1]) /* ty=Tensor[(1, 10), float32] */;
  %4 = log(%0) /* ty=Tensor[(1, 10), float32] */;
  %5 = subtract(%3, %4) /* ty=Tensor[(1, 10), float32] */;
  %6 = device_copy(%5, meta[relay.attrs.DeviceCopyAttrs][2]) /* ty=Tensor[(1, 10), float32] */;
  exp(%6) /* ty=Tensor[(1, 10), float32] */
}
1.5.1. device_copy
device_copy is handled at LowerTE time:
- on one hand, it is used to determine which target an expr runs on, and therefore which topi strategy to pick
- on the other hand, at runtime the graph_executor performs the actual data copy based on this marker.
tvm/src/relay/backend/te_compiler_cache.cc::// Set the name to `__copy`. It will be detected in graph runtime to perform
For example, when graph_executor encounters `__copy` during execution, it calls the device-specific CopyDataFromTo; for OpenCL this ends up in clEnqueueCopyBuffer:
void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
  if (IsOpenCLDevice(from->device) && IsOpenCLDevice(to->device)) {
    const auto* from_desc = static_cast<const cl::BufferDescriptor*>(from->data);
    auto* to_desc = static_cast<cl::BufferDescriptor*>(to->data);
    clEnqueueCopyBuffer(this->GetQueue(to->device), from_desc->buffer, to_desc->buffer,
                        from->byte_offset, to->byte_offset, nbytes, 0, nullptr, nullptr);
  }
  // ...
}
Besides the `__copy` case above, targets such as OpenCL also exchange data with the device through CopyDataFromTo during set_input/get_output.
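A hedged sketch of that path (it falls back to CPU when this TVM build has no OpenCL device; the transfer goes through the same DeviceAPI CopyDataFromTo shown above):
import numpy as np
import tvm

# Pick an OpenCL device if one is available, otherwise fall back to CPU.
dev = tvm.opencl(0) if tvm.opencl(0).exist else tvm.cpu(0)

host = np.random.rand(1, 10).astype("float32")
dev_arr = tvm.nd.array(host, device=dev)  # host -> device copy (CopyDataFromTo)
back = dev_arr.numpy()                    # device -> host copy (CopyDataFromTo)
assert np.allclose(host, back)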
1.6. FuseOps
1.6.1. Example
1.6.1.1. Without FuseOps:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-08 18:19
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import numpy as np


def test_fuse_simple():
    """Simple testcase."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, x)
        y = relay.add(y, y)
        y = relay.add(y, y)
        y = relay.add(y, y)
        y = relay.add(y, y)
        z = relay.exp(y)
        return relay.Function([x], z)

    z = before()
    print(z)
    # z = run_opt_pass(z, transform.FuseOps())
    # print(z)
    with tvm.transform.PassContext(opt_level=0):
        graph, lib, params = relay.build(z, target="llvm", params=None)
    print(graph)


if __name__ == "__main__":
    test_fuse_simple()
fn (%x: Tensor[(10, 20), float32]) {
  %0 = add(%x, %x);
  %1 = add(%0, %0);
  %2 = add(%1, %1);
  %3 = add(%2, %2);
  %4 = add(%3, %3);
  exp(%4)
}
{
  "nodes": [
    {"op": "null", "name": "x", "inputs": []},
    {"op": "tvm_op", "name": "tvmgen_default_fused_add",
     "attrs": {"num_outputs": "1", "num_inputs": "1", "flatten_data": "0",
               "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4"},
     "inputs": [[0, 0, 0]]},
    {"op": "tvm_op", "name": "tvmgen_default_fused_add1",
     "attrs": {"num_outputs": "1", "num_inputs": "1", "flatten_data": "0",
               "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4"},
     "inputs": [[1, 0, 0]]},
    {"op": "tvm_op", "name": "tvmgen_default_fused_add2",
     "attrs": {"num_outputs": "1", "num_inputs": "1", "flatten_data": "0",
               "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4"},
     "inputs": [[2, 0, 0]]},
    {"op": "tvm_op", "name": "tvmgen_default_fused_add3",
     "attrs": {"num_outputs": "1", "num_inputs": "1", "flatten_data": "0",
               "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4"},
     "inputs": [[3, 0, 0]]},
    {"op": "tvm_op", "name": "tvmgen_default_fused_add4",
     "attrs": {"num_outputs": "1", "num_inputs": "1", "flatten_data": "0",
               "func_name": "tvmgen_default_fused_add", "hash": "aadf70b47b6beaf4"},
     "inputs": [[4, 0, 0]]},
    {"op": "tvm_op", "name": "tvmgen_default_fused_exp",
     "attrs": {"num_outputs": "1", "num_inputs": "1", "flatten_data": "0",
               "func_name": "tvmgen_default_fused_exp", "hash": "de3b50c71256954a"},
     "inputs": [[5, 0, 0]]}
  ],
  "arg_nodes": [0],
  "heads": [[6, 0, 0]],
  "attrs": {
    "dltype": ["list_str", ["float32", "float32", "float32", "float32", "float32", "float32", "float32"]],
    "device_index": ["list_int", [1, 1, 1, 1, 1, 1, 1]],
    "storage_id": ["list_int", [0, 1, 2, 1, 2, 1, 2]],
    "shape": ["list_shape", [[10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]]]
  },
  "node_row_ptr": [0, 1, 2, 3, 4, 5, 6, 7]
}
We can see that each add operation corresponds to its own node in the graph.
1.6.1.2. With FuseOps
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-08 18:19
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import numpy as np


def test_fuse_simple():
    """Simple testcase."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, x)
        y = relay.add(y, y)
        y = relay.add(y, y)
        y = relay.add(y, y)
        y = relay.add(y, y)
        z = relay.exp(y)
        return relay.Function([x], z)

    z = before()
    z = run_opt_pass(z, transform.FuseOps())
    print(z)
    with tvm.transform.PassContext(opt_level=0):
        graph, lib, params = relay.build(z, target="llvm", params=None)
    print(graph)


if __name__ == "__main__":
    test_fuse_simple()
fn (%x: Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] {
  %5 = fn (%p0: Tensor[(10, 20), float32], Primitive=1) -> Tensor[(10, 20), float32] {
    %0 = add(%p0, %p0) /* ty=Tensor[(10, 20), float32] */;
    %1 = add(%0, %0) /* ty=Tensor[(10, 20), float32] */;
    %2 = add(%1, %1) /* ty=Tensor[(10, 20), float32] */;
    %3 = add(%2, %2) /* ty=Tensor[(10, 20), float32] */;
    %4 = add(%3, %3) /* ty=Tensor[(10, 20), float32] */;
    exp(%4) /* ty=Tensor[(10, 20), float32] */
  };
  %5(%x) /* ty=Tensor[(10, 20), float32] */
}
{
  "nodes": [
    {"op": "null", "name": "x", "inputs": []},
    {"op": "tvm_op", "name": "tvmgen_default_fused_add_add_add_add_add_exp",
     "attrs": {"num_outputs": "1", "num_inputs": "1", "flatten_data": "0",
               "func_name": "tvmgen_default_fused_add_add_add_add_add_exp", "hash": "bcd04c940b541895"},
     "inputs": [[0, 0, 0]]}
  ],
  "arg_nodes": [0],
  "heads": [[1, 0, 0]],
  "attrs": {
    "dltype": ["list_str", ["float32", "float32"]],
    "device_index": ["list_int", [1, 1]],
    "storage_id": ["list_int", [0, 1]],
    "shape": ["list_shape", [[10, 20], [10, 20]]]
  },
  "node_row_ptr": [0, 1, 2]
}
All of the ops have been fused into the single function tvmgen_default_fused_add_add_add_add_add_exp.
1.6.2. Why FuseOps is needed
FuseOps obviously reduces function-call overhead. So why not simply fuse all ops into a single function every time?
Each op has its own schedule; if the schedules are not `compatible`, the ops cannot be fused together. For example:
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import numpy as np


def test_fuse_simple():
    def before():
        x = relay.var("x", shape=(20, 20))
        y1 = relay.add(x, x)
        y2 = relay.add(y1, y1)
        y3 = relay.sum(y2, axis=-1)
        y = relay.add(y3, y3)
        return relay.Function([x], y)

    z = before()
    z = run_opt_pass(z, transform.FuseOps())
    print(z)
    with tvm.transform.PassContext(opt_level=0):
        graph, lib, params = relay.build(z, target="c", params=None)
    print(lib.get_source())


if __name__ == "__main__":
    test_fuse_simple()
fn (%x: Tensor[(20, 20), float32]) -> Tensor[(20), float32] {
  %2 = fn (%p01: Tensor[(20, 20), float32], Primitive=1) -> Tensor[(20), float32] {
    %0 = add(%p01, %p01) /* ty=Tensor[(20, 20), float32] */;
    %1 = add(%0, %0) /* ty=Tensor[(20, 20), float32] */;
    sum(%1, axis=[-1]) /* ty=Tensor[(20), float32] */
  };
  %3 = %2(%x) /* ty=Tensor[(20), float32] */;
  %4 = fn (%p0: Tensor[(20), float32], Primitive=1) -> Tensor[(20), float32] {
    add(%p0, %p0) /* ty=Tensor[(20), float32] */
  };
  %4(%3) /* ty=Tensor[(20), float32] */
}
// tvm target: c -keys=cpu -link-params=0
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_add_add_sum(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
void* arg0 = (((TVMValue*)args)[0].v_handle);
int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
void* arg1 = (((TVMValue*)args)[1].v_handle);
int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)];
void* placeholder = (((DLTensor*)arg0)[0].data);
void* arg0_shape = (((DLTensor*)arg0)[0].shape);
void* arg0_strides = (((DLTensor*)arg0)[0].strides);
int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
void* T_add_red = (((DLTensor*)arg1)[0].data);
void* arg1_shape = (((DLTensor*)arg1)[0].shape);
void* arg1_strides = (((DLTensor*)arg1)[0].strides);
if (!(arg0_strides == NULL)) {
}
if (!(arg1_strides == NULL)) {
}
for (int32_t ax0 = 0; ax0 < 20; ++ax0) {
((float*)T_add_red)[(ax0)] = 0.000000e+00f;
for (int32_t k1 = 0; k1 < 20; ++k1) {
((float*)T_add_red)[(ax0)] = (((float*)T_add_red)[(ax0)] + ((((float*)placeholder)[(((ax0 * 20) + k1))] + ((float*)placeholder)[(((ax0 * 20) + k1))]) + (((float*)placeholder)[(((ax0 * 20) + k1))] + ((float*)placeholder)[(((ax0 * 20) + k1))])));
}
}
return 0;
}
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_add(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
void* arg0 = (((TVMValue*)args)[0].v_handle);
int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
void* arg1 = (((TVMValue*)args)[1].v_handle);
int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)];
void* placeholder = (((DLTensor*)arg0)[0].data);
void* arg0_shape = (((DLTensor*)arg0)[0].shape);
void* arg0_strides = (((DLTensor*)arg0)[0].strides);
int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
void* T_add = (((DLTensor*)arg1)[0].data);
void* arg1_shape = (((DLTensor*)arg1)[0].shape);
void* arg1_strides = (((DLTensor*)arg1)[0].strides);
if (!(arg0_strides == NULL)) {
}
if (!(arg1_strides == NULL)) {
}
for (int32_t ax0_outer = 0; ax0_outer < 2; ++ax0_outer) {
for (int32_t ax0_inner_s = 0; ax0_inner_s < 16; ++ax0_inner_s) {
if (((ax0_outer * 16) + ax0_inner_s) < 20) {
((float*)T_add)[(((ax0_outer * 16) + ax0_inner_s))] = (((float*)placeholder)[(((ax0_outer * 16) + ax0_inner_s))] + ((float*)placeholder)[(((ax0_outer * 16) + ax0_inner_s))]);
}
}
}
return 0;
}
The schedules for y1, y2, and y3 all handle a 20x20 loop, but y's schedule handles a 20x1 loop, so they cannot be fused into a single op.
1.7. FoldScaleAxis
FoldScaleAxis can fold Conv2D and BatchNorm together.
Benefits of operator fusion:
- less intermediate memory is allocated between operators
- the loops of the two operators can be merged into a single loop
Besides the benefits above, another reason for FoldScaleAxis is that many accelerators do not support scalar multiplication, such as the one introduced by BatchNorm.
TVM's FoldScaleAxis actually consists of three transforms:
Pass FoldScaleAxis() {
  Pass pass = Sequential(
      {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, "FoldScaleAxis");
  return pass;
}
Where:
- BackwardFoldScaleAxis folds a scale that comes after an op (e.g. Conv2D) into that op
- ForwardFoldScaleAxis folds a scale that comes before the op into the op
- ForwardFoldScaleAxis and BackwardFoldScaleAxis only turn x*scale or y*scale into w*scale; FoldConstant is then needed to fold w*scale into a single constant (provided both w and scale are constants)
1.7.1. Example
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-10-14 22:36
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform

I = N = 1
O = C = 2
H = W = 3


def run_opt_pass(expr, opt_pass):
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_fold():
    def get_model(x, weight, scale):
        args = [x]
        # x = relay.multiply(x, scale)
        y = relay.nn.conv2d(
            x,
            weight,
            channels=C,
            kernel_size=(3, 3),
            padding=(1, 1),
        )
        y = relay.multiply(y, scale)
        return relay.Function(args, y)

    def check(shape):
        x = relay.var("x", shape=shape)
        weight = relay.const(np.random.randn(O, I, H, W).astype("float32"))
        scale = relay.const(np.random.randn(C, 1, 1).astype("float32"))
        y = get_model(x, weight, scale)
        y = run_opt_pass(y, transform.InferType())
        print(y)
        y_folded = run_opt_pass(y, transform.BackwardFoldScaleAxis())
        print(y_folded)
        y_folded = run_opt_pass(y_folded, transform.FoldConstant())
        print(y_folded)

    check((1, N, 10, 10))


if __name__ == "__main__":
    test_fold()
fn (%x: Tensor[(1, 1, 10, 10), float32]) -> Tensor[(1, 2, 10, 10), float32] {
  %0 = nn.conv2d(%x, meta[relay.Constant][0] /* ty=Tensor[(2, 1, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) /* ty=Tensor[(1, 2, 10, 10), float32] */;
  multiply(%0, meta[relay.Constant][1] /* ty=Tensor[(2, 1, 1), float32] */) /* ty=Tensor[(1, 2, 10, 10), float32] */
}
fn (%x: Tensor[(1, 1, 10, 10), float32]) -> Tensor[(1, 2, 10, 10), float32] {
  %0 = squeeze(meta[relay.Constant][1] /* ty=Tensor[(2, 1, 1), float32] */, axis=[1, 2]) /* ty=Tensor[(2), float32] */;
  %1 = expand_dims(%0, axis=1, num_newaxis=3) /* ty=Tensor[(2, 1, 1, 1), float32] */;
  %2 = multiply(meta[relay.Constant][0] /* ty=Tensor[(2, 1, 3, 3), float32] */, %1) /* ty=Tensor[(2, 1, 3, 3), float32] */;
  nn.conv2d(%x, %2, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) /* ty=Tensor[(1, 2, 10, 10), float32] */
}
fn (%x: Tensor[(1, 1, 10, 10), float32]) -> Tensor[(1, 2, 10, 10), float32] {
  nn.conv2d(%x, meta[relay.Constant][0] /* ty=Tensor[(2, 1, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) /* ty=Tensor[(1, 2, 10, 10), float32] */
}
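The forward direction is the commented-out `x = relay.multiply(x, scale)` case above: the scale sits on the conv2d input and gets folded into the kernel's input-channel axis. A hedged sketch reusing the same shapes (positive scale, as in TVM's own tests):
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass

x = relay.var("x", shape=(1, 2, 10, 10))
weight = relay.const(np.random.randn(2, 2, 3, 3).astype("float32"))
scale = relay.const(np.random.uniform(0.5, 2.0, size=(2, 1, 1)).astype("float32"))

y = relay.nn.conv2d(relay.multiply(x, scale), weight,
                    channels=2, kernel_size=(3, 3), padding=(1, 1))
f = relay.Function([x], y)
f = run_opt_pass(f, transform.InferType())
f = run_opt_pass(f, transform.ForwardFoldScaleAxis())
f = run_opt_pass(f, transform.FoldConstant())
print(f)  # the multiply is gone; only nn.conv2d with a pre-scaled weight remains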
1.7.2. Impl
// multiply is visited before conv2d; record the scale here
Expr MultiplyBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                               const BackwardTransformer& transformer) {
  ICHECK(!message.defined()) << "outstanding scale";
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  if (lhs_message.defined()) {
    Expr rhs = call->args[1];
    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) &&
        (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) {
      // e.g. mul(conv2d, scale): call->args[0] is the conv2d, rhs is the scale
      return transformer->Transform(call->args[0], lhs_message, rhs);
    }
  } else if (rhs_message.defined()) {
    Expr lhs = call->args[0];
    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) &&
        (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) {
      // mul(scale, conv2d): call->args[1] is the conv2d, lhs is the scale
      return transformer->Transform(call->args[1], rhs_message, lhs);
    }
  }
  return transformer->NormalCallTransform(call.operator->());
}

Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                             const BackwardTransformer& transformer) {
  // ...
  Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
  Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
  // scale on input for depthwise.
  Expr wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
  weight = Multiply(weight, wscale);
  return Call(call->op, {data, weight}, call->attrs, call->type_args);
}
1.8. CombineParallelDense
Transforms such as FuseOps and FoldScaleAxis perform vertical fusion, while the CombineXXX transforms perform horizontal fusion.
1.8.1. Example
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-10-14 22:36
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform


def run_opt_pass(expr, opt_pass):
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = tvm.relay.transform.InferType()(mod)
    mod = opt_pass(mod)
    return mod["main"]


def test_combine_parallel_dense():
    """Simple testcase."""

    def before(x, w1, w2):
        args = [x, w1, w2]
        y1 = relay.nn.dense(x, w1)
        y2 = relay.nn.dense(x, w2)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def check(i, j, k):
        x = relay.var("x", shape=(i, k))
        w1 = relay.var("w1", shape=(j, k))
        w2 = relay.var("w2", shape=(j, k))
        y_before = before(x, w1, w2)
        print(y_before)
        y = run_opt_pass(y_before, transform.CombineParallelDense(min_num_branches=2))
        print(y)

    # a 3x4 matrix is multiplied with two 4x5 matrices, producing two 3x5 results
    check(3, 5, 4)


if __name__ == "__main__":
    test_combine_parallel_dense()
fn (%x: Tensor[(3, 4), float32], %w1: Tensor[(5, 4), float32], %w2: Tensor[(5, 4), float32]) {
  %0 = nn.dense(%x, %w1, units=None);
  %1 = nn.dense(%x, %w2, units=None);
  (%0, %1)
}
fn (%x: Tensor[(3, 4), float32], %w1: Tensor[(5, 4), float32], %w2: Tensor[(5, 4), float32]) -> (Tensor[(3, 5), float32], Tensor[(3, 5), float32]) {
  %0 = (%x, %x);
  %1 = (%w1, %w2);
  %2 = stack(%0) /* ty=Tensor[(2, 3, 4), float32] */;
  %3 = stack(%1) /* ty=Tensor[(2, 5, 4), float32] */;
  %4 = nn.batch_matmul(%2, %3, transpose_b=True) /* ty=Tensor[(2, 3, 5), float32] */;
  %5 = split(%4, indices_or_sections=2) /* ty=(Tensor[(1, 3, 5), float32], Tensor[(1, 3, 5), float32]) */;
  %6 = %5.0;
  %7 = %5.1;
  %8 = squeeze(%6, axis=[0]) /* ty=Tensor[(3, 5), float32] */;
  %9 = squeeze(%7, axis=[0]) /* ty=Tensor[(3, 5), float32] */;
  (%8, %9)
}
1.9. CombineParallelConv2D
1.9.1. Example
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-10-22 16:46
import tvm
from tvm import relay
from tvm.relay import transform


def test_combine_parallel_conv2d():
    def before(x, w1, w2):
        args = [x, w1, w2]
        y1 = relay.nn.conv2d(x, w1)
        y2 = relay.nn.conv2d(x, w2)
        y3 = relay.nn.max_pool2d(x)
        y = relay.Tuple((y1, y2, y3))
        func = relay.Function(args, y)
        mod = tvm.IRModule.from_expr(func)
        mod = tvm.relay.transform.InferType()(mod)
        return mod

    def check(x_shape, channels1, channels2):
        x = relay.var("x", shape=x_shape)
        in_c = x_shape[1]
        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
        mod = before(x, w1, w2)
        print("------before------")
        print(mod)
        mod = transform.CombineParallelConv2D(min_num_branches=2)(mod)
        print("------after------")
        print(mod)

    check((1, 4, 16, 16), 4, 4)


if __name__ == "__main__":
    test_combine_parallel_conv2d()
------before------
def @main(%x: Tensor[(1, 4, 16, 16), float32], %w1: Tensor[(4, 4, 1, 1), float32], %w2: Tensor[(4, 4, 1, 1), float32]) -> (Tensor[(1, 4, 16, 16), float32], Tensor[(1, 4, 16, 16), float32], Tensor[(1, 4, 16, 16), float32]) {
  %0 = nn.conv2d(%x, %w1, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 4, 16, 16), float32] */;
  %1 = nn.conv2d(%x, %w2, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 4, 16, 16), float32] */;
  %2 = nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 4, 16, 16), float32] */;
  (%0, %1, %2)
}
------after------
def @main(%x: Tensor[(1, 4, 16, 16), float32], %w1: Tensor[(4, 4, 1, 1), float32], %w2: Tensor[(4, 4, 1, 1), float32]) -> (Tensor[(1, 4, 16, 16), float32], Tensor[(1, 4, 16, 16), float32], Tensor[(1, 4, 16, 16), float32]) {
  %0 = (%w1, %w2);
  %1 = concatenate(%0) /* ty=Tensor[(8, 4, 1, 1), float32] */;
  %2 = nn.conv2d(%x, %1, padding=[0, 0, 0, 0], channels=8) /* ty=Tensor[(1, 8, 16, 16), float32] */;
  %3 = strided_slice(%2, begin=[0, 0], end=[-1, 4], strides=[1, 1], slice_mode="size", axes=None) /* ty=Tensor[(1, 4, 16, 16), float32] */;
  %4 = strided_slice(%2, begin=[0, 4], end=[-1, 4], strides=[1, 1], slice_mode="size", axes=None) /* ty=Tensor[(1, 4, 16, 16), float32] */;
  %5 = nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 4, 16, 16), float32] */;
  (%3, %4, %5)
}