TVM BYOC Codegen
Table of Contents
1. TVM BYOC Codegen
https://tvm.apache.org/2020/07/15/how-to-bring-your-own-codegen-to-tvm
https://gist.github.com/SrivastavaKshitij/9341a414147fbc290eff4a92b8e73acd
BYOC is somewhat similar to a TFLite delegate.
tvm/docs/dev/relay_bring_your_own_codegen.rst
1.1. DNNL BYOC Example
1.1.1. Without DNNL
import tvm
from tvm import relay
def get_demo_mod():
    d1 = relay.var("d1", shape=(1, 32, 56, 56), dtype="float32")
    w1 = relay.var("w1", shape=(32, 32, 3, 3), dtype="float32")
    b1 = relay.var("b1", shape=(32,), dtype="float32")
    conv = relay.nn.conv2d(d1, w1, strides=(1, 1), padding=(1, 1))
    bias = relay.nn.bias_add(conv, b1)
    relu = relay.nn.relu(bias)
    func = relay.Function([d1, w1, b1], relu)
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    return mod
mod = get_demo_mod()
print(mod)
def @main(%d1: Tensor[(1, 32, 56, 56), float32], %w1: Tensor[(32, 32, 3, 3), float32], %b1: Tensor[(32), float32]) -> Tensor[(1, 32, 56, 56), float32] {
%0 = nn.conv2d(%d1, %w1, padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 32, 56, 56), float32] */;
%1 = nn.bias_add(%0, %b1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
nn.relu(%1) /* ty=Tensor[(1, 32, 56, 56), float32] */
}
with tvm.transform.PassContext(opt_level=2):
    graph, lib, params = relay.build(mod, target="c", params=None)
print(graph)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
{
"nodes": [
{
"op": "null",
"name": "d1",
"inputs": []
},
{
"op": "null",
"name": "w1",
"inputs": []
},
{
"op": "null",
"name": "b1",
"inputs": []
},
{
"op": "tvm_op",
"name": "tvmgen_default_fused_nn_conv2d_nn_bias_add_nn_relu",
"attrs": {
"num_outputs": "1",
"flatten_data": "0",
"out_layout": "",
"func_name": "tvmgen_default_fused_nn_conv2d_nn_bias_add_nn_relu",
"hash": "45f3f3bf6ba0201f",
"num_inputs": "3",
"kernel_layout": "OIHW",
"data_layout": "NCHW"
},
"inputs": [
[
0,
0,
0
],
[
1,
0,
0
],
[
2,
0,
0
]
]
}
],
"arg_nodes": [0, 1, 2],
"heads": [
[
3,
0,
0
]
],
"attrs": {
"dltype": [
"list_str",
[
"float32",
"float32",
"float32",
"float32"
]
],
"shape": [
"list_shape",
[
[1, 32, 56, 56],
[32, 32, 3, 3],
[32],
[1, 32, 56, 56]
]
],
"storage_id": [
"list_int",
[0, 1, 2, 3]
]
},
"node_row_ptr": [0, 1, 2, 3, 4]
}
/tmp/ipykernel_8688/48896817.py:2: DeprecationWarning: legacy graph executor behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_executor.GraphModule for the new recommended usage.
graph, lib, params = relay.build(mod, target="c", params=None)
print(lib.get_source())
// tvm target: c -keys=cpu -link-params=0
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_nn_conv2d_nn_bias_add_nn_relu(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
void* arg0 = (((TVMValue*)args)[0].v_handle);
int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
void* arg1 = (((TVMValue*)args)[1].v_handle);
int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)];
void* arg2 = (((TVMValue*)args)[2].v_handle);
int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)];
void* arg3 = (((TVMValue*)args)[3].v_handle);
int32_t arg3_code = ((int32_t*)arg_type_ids)[(3)];
void* placeholder = (((DLTensor*)arg0)[0].data);
void* arg0_shape = (((DLTensor*)arg0)[0].shape);
void* arg0_strides = (((DLTensor*)arg0)[0].strides);
int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
void* placeholder1 = (((DLTensor*)arg1)[0].data);
void* arg1_shape = (((DLTensor*)arg1)[0].shape);
void* arg1_strides = (((DLTensor*)arg1)[0].strides);
void* placeholder2 = (((DLTensor*)arg2)[0].data);
void* arg2_shape = (((DLTensor*)arg2)[0].shape);
void* arg2_strides = (((DLTensor*)arg2)[0].strides);
void* T_relu = (((DLTensor*)arg3)[0].data);
void* arg3_shape = (((DLTensor*)arg3)[0].shape);
void* arg3_strides = (((DLTensor*)arg3)[0].strides);
if (!(arg0_strides == NULL)) {
}
if (!(arg1_strides == NULL)) {
}
if (!(arg2_strides == NULL)) {
}
if (!(arg3_strides == NULL)) {
}
void* data_vec = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)401408, 2, 32);
if (data_vec == NULL) {
return -1;
}
void* data_pad = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)430592, 2, 32);
if (data_pad == NULL) {
return -1;
}
for (int32_t bs_c_fused_h_fused = 0; bs_c_fused_h_fused < 224; ++bs_c_fused_h_fused) {
for (int32_t w = 0; w < 56; ++w) {
for (int32_t vc = 0; vc < 8; ++vc) {
((float*)data_vec)[((((bs_c_fused_h_fused * 448) + (w * 8)) + vc))] = ((float*)placeholder)[((((((bs_c_fused_h_fused / 56) * 25088) + (vc * 3136)) + ((bs_c_fused_h_fused % 56) * 56)) + w))];
}
}
}
for (int32_t i0_i1_fused_i2_fused = 0; i0_i1_fused_i2_fused < 232; ++i0_i1_fused_i2_fused) {
for (int32_t i3 = 0; i3 < 58; ++i3) {
((float8*)((float*)data_pad + (((i0_i1_fused_i2_fused * 464) + (i3 * 8)))))[0] = (((((1 <= (i0_i1_fused_i2_fused % 58)) && ((i0_i1_fused_i2_fused % 58) < 57)) && (1 <= i3)) && (i3 < 57)) ? ((float8*)((float*)data_vec + ((((((i0_i1_fused_i2_fused / 58) * 25088) + ((i0_i1_fused_i2_fused % 58) * 448)) + (i3 * 8)) - 456))))[0] : ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f)));
}
}
for (int32_t occ_k_h_fused = 0; occ_k_h_fused < 12; ++occ_k_h_fused) {
for (int32_t icc = 0; icc < 4; ++icc) {
for (int32_t k_w = 0; k_w < 3; ++k_w) {
for (int32_t icb = 0; icb < 8; ++icb) {
int32_t8 _1 = (int8)((((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*0), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*1), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*2), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*3), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*4), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*5), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*6), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*7));
((float8*)((float*)data_vec + (((((((occ_k_h_fused / 3) * 2304) + (icc * 576)) + ((occ_k_h_fused % 3) * 192)) + (k_w * 64)) + (icb * 8)))))[0] = ((float8)(((float*)placeholder1)[_1.s0],((float*)placeholder1)[_1.s1],((float*)placeholder1)[_1.s2],((float*)placeholder1)[_1.s3],((float*)placeholder1)[_1.s4],((float*)placeholder1)[_1.s5],((float*)placeholder1)[_1.s6],((float*)placeholder1)[_1.s7]));
}
}
}
}
for (int32_t ax0_ax1_outer_fused_ax2_fused = 0; ax0_ax1_outer_fused_ax2_fused < 224; ++ax0_ax1_outer_fused_ax2_fused) {
void* conv2d_NCHWc = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)1792, 2, 32);
if (conv2d_NCHWc == NULL) {
return -1;
}
void* conv2d_NCHWc_global = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)896, 2, 32);
if (conv2d_NCHWc_global == NULL) {
return -1;
}
for (int32_t ow_outer = 0; ow_outer < 2; ++ow_outer) {
((float8*)((float*)conv2d_NCHWc_global + (0)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
((float8*)((float*)conv2d_NCHWc_global + (8)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
((float8*)((float*)conv2d_NCHWc_global + (16)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
((float8*)((float*)conv2d_NCHWc_global + (24)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
((float8*)((float*)conv2d_NCHWc_global + (32)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
((float8*)((float*)conv2d_NCHWc_global + (40)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
((float8*)((float*)conv2d_NCHWc_global + (48)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
//...
for (int32_t ic_outer = 0; ic_outer < 4; ++ic_outer) {
for (int32_t kh = 0; kh < 3; ++kh) {
for (int32_t kw = 0; kw < 3; ++kw) {
for (int32_t ic_inner = 0; ic_inner < 8; ++ic_inner) {
((float8*)((float*)conv2d_NCHWc_global + (0)))[0] = (((float8*)((float*)conv2d_NCHWc_global + (0)))[0] + (((float8)(((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))])) * ((float8*)((float*)data_vec + (((((((ax0_ax1_outer_fused_ax2_fused / 56) * 2304) + (ic_outer * 576)) + (kh * 192)) + (kw * 64)) + (ic_inner * 8)))))[0]));
//...
}
}
}
for (int32_t ow_inner = 0; ow_inner < 28; ++ow_inner) {
((float8*)((float*)conv2d_NCHWc + (((ow_outer * 224) + (ow_inner * 8)))))[0] = ((float8*)((float*)conv2d_NCHWc_global + ((ow_inner * 8))))[0];
}
}
for (int32_t ax3_outer = 0; ax3_outer < 2; ++ax3_outer) {
for (int32_t ax3_inner = 0; ax3_inner < 28; ++ax3_inner) {
int32_t8 _2 = (int8)(((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*0), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*1), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*2), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*3), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*4), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*5), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*6), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*7));
float8 _3 = ((float8*)((float*)conv2d_NCHWc + (((ax3_outer * 224) + (ax3_inner * 8)))))[0] + ((float8*)((float*)placeholder2 + (((ax0_ax1_outer_fused_ax2_fused / 56) * 8))))[0];
float8 _4 = (float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f);
float8 _5 = (_3) > (_4) ? (_3) : (_4);
((float*)T_relu)[_2.s0] = _5.s0;
((float*)T_relu)[_2.s1] = _5.s1;
((float*)T_relu)[_2.s2] = _5.s2;
((float*)T_relu)[_2.s3] = _5.s3;
((float*)T_relu)[_2.s4] = _5.s4;
((float*)T_relu)[_2.s5] = _5.s5;
((float*)T_relu)[_2.s6] = _5.s6;
((float*)T_relu)[_2.s7] = _5.s7;
}
}
if (TVMBackendFreeWorkspace(1, dev_id, conv2d_NCHWc_global) != 0) {
return -1;
}
if (TVMBackendFreeWorkspace(1, dev_id, conv2d_NCHWc) != 0) {
return -1;
}
}
if (TVMBackendFreeWorkspace(1, dev_id, data_pad) != 0) {
return -1;
}
if (TVMBackendFreeWorkspace(1, dev_id, data_vec) != 0) {
return -1;
}
return 0;
}
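To actually run what was just built, the generated C source first has to be compiled into a shared library. A minimal sketch (assuming a working host C compiler; the file name /tmp/libdemo.so is chosen here for illustration):
from tvm.contrib import graph_executor
import numpy as np

lib.export_library("/tmp/libdemo.so")             # compiles the generated C source
loaded = tvm.runtime.load_module("/tmp/libdemo.so")
rt = graph_executor.create(graph, loaded, tvm.cpu(0))
rt.set_input("d1", np.random.rand(1, 32, 56, 56).astype("float32"))
rt.set_input("w1", np.random.rand(32, 32, 3, 3).astype("float32"))
rt.set_input("b1", np.random.rand(32).astype("float32"))
rt.run()
print(rt.get_output(0).shape)                     # (1, 32, 56, 56)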
1.1.2. With DNNL
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
mod = get_demo_mod()
mod = relay.transform.AnnotateTarget("dnnl")(mod)
mod = relay.transform.PartitionGraph()(mod)
print(mod)
def @main(%d1: Tensor[(1, 32, 56, 56), float32], %w1: Tensor[(32, 32, 3, 3), float32], %b1: Tensor[(32), float32]) -> Tensor[(1, 32, 56, 56), float32] {
%0 = @tvmgen_default_dnnl_main_0(%d1, %w1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
%1 = nn.bias_add(%0, %b1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
@tvmgen_default_dnnl_main_2(%1) /* ty=Tensor[(1, 32, 56, 56), float32] */
}
def @tvmgen_default_dnnl_main_0(%dnnl_0_i0: Tensor[(1, 32, 56, 56), float32], %dnnl_0_i1: Tensor[(32, 32, 3, 3), float32], Inline=1, Compiler="dnnl", global_symbol="tvmgen_default_dnnl_main_0", Primitive=1) -> Tensor[(1, 32, 56, 56), float32] {
nn.conv2d(%dnnl_0_i0, %dnnl_0_i1, padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 32, 56, 56), float32] */
}
def @tvmgen_default_dnnl_main_2(%dnnl_2_i0: Tensor[(1, 32, 56, 56), float32], Inline=1, Compiler="dnnl", global_symbol="tvmgen_default_dnnl_main_2", Primitive=1) -> Tensor[(1, 32, 56, 56), float32] {
nn.relu(%dnnl_2_i0) /* ty=Tensor[(1, 32, 56, 56), float32] */
}
with tvm.transform.PassContext(opt_level=2):
    graph, lib, params = relay.build(mod, target="c", params=None)
lib.export_library("/tmp/liba.so")
print(lib.imported_modules[0].get_source())
// tvm target: c -keys=cpu -link-params=0
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_nn_bias_add(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
void* arg0 = (((TVMValue*)args)[0].v_handle);
int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
void* arg1 = (((TVMValue*)args)[1].v_handle);
int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)];
void* arg2 = (((TVMValue*)args)[2].v_handle);
int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)];
void* placeholder = (((DLTensor*)arg0)[0].data);
void* arg0_shape = (((DLTensor*)arg0)[0].shape);
void* arg0_strides = (((DLTensor*)arg0)[0].strides);
int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
void* placeholder1 = (((DLTensor*)arg1)[0].data);
void* arg1_shape = (((DLTensor*)arg1)[0].shape);
void* arg1_strides = (((DLTensor*)arg1)[0].strides);
void* T_add = (((DLTensor*)arg2)[0].data);
void* arg2_shape = (((DLTensor*)arg2)[0].shape);
void* arg2_strides = (((DLTensor*)arg2)[0].strides);
if (!(arg0_strides == NULL)) {
}
if (!(arg1_strides == NULL)) {
}
if (!(arg2_strides == NULL)) {
}
for (int32_t ax0_ax1_fused = 0; ax0_ax1_fused < 32; ++ax0_ax1_fused) {
for (int32_t ax2 = 0; ax2 < 56; ++ax2) {
for (int32_t ax3_outer = 0; ax3_outer < 4; ++ax3_outer) {
for (int32_t ax3_inner_s = 0; ax3_inner_s < 16; ++ax3_inner_s) {
if (((ax3_outer * 16) + ax3_inner_s) < 56) {
((float*)T_add)[(((((ax0_ax1_fused * 3136) + (ax2 * 56)) + (ax3_outer * 16)) + ax3_inner_s))] = (((float*)placeholder)[(((((ax0_ax1_fused * 3136) + (ax2 * 56)) + (ax3_outer * 16)) + ax3_inner_s))] + ((float*)placeholder1)[(ax0_ax1_fused)]);
}
}
}
}
}
return 0;
}
/tmp/ipykernel_38953/3957980515.py:2: DeprecationWarning: legacy graph executor behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_executor.GraphModule for the new recommended usage.
graph, lib, params = relay.build(mod, target="c", params=None)
print(lib.imported_modules[1].get_source())
{
"nodes": [
{
"op": "input",
"name": "dnnl_2_i0",
"attrs": {
"dtype": [
[
"float32"
]
],
"shape": [
[
[1, 32, 56, 56]
]
]
}
},
{
"op": "kernel",
"name": "nn.relu",
"inputs": [[
0,
0,
0]],
"attrs": {
"num_outputs": "1",
"shape": [
[
[1, 32, 56, 56]
]
],
"num_inputs": "1",
"dtype": [
[
"float32"
]
]
}
}
],
"arg_nodes": [0],
"heads": [[
1,
0,
0]],
"node_row_ptr": [0, 1, 2]
}
readelf -p .rodata /tmp/liba.so
String dump of section '.rodata':
[ 18] metadata
[ 40] tvmgen_default_dnnl_main_0^Z
[ 62] tvmgen_default_dnnl_main_2^B
[ 9c] dnnl_json^Z
[ ad] tvmgen_default_dnnl_main_0L^G
[ cf] {^J "nodes": [^J {^J "op": "input", ^J "name": "dnnl_0_i0", ^J "attrs": {^J "dtype": [^J [^J "float32"^J ]^J ], ^J "shape": [^J [^J [1, 32, 56, 56]^J ]^J ]^J }^J }, ^J {^J "op": "input", ^J "name": "dnnl_0_i1", ^J "attrs": {^J "dtype": [^J [^J "float32"^J ]^J ], ^J "shape": [^J [^J [32, 32, 3, 3]^J ]^J ]^J }^J }, ^J {^J "op": "kernel", ^J "name": "nn.conv2d", ^J "inputs": [[^J 0, ^J 0, ^J 0], [^J 1, ^J 0, ^J 0]], ^J "attrs": {^J "num_outputs": "1", ^J "num_inputs": "2", ^J "channels": [^J [^J ""^J ]^J ], ^J "out_layout": [^J [^J ""^J ]^J ], ^J "groups": [^J [^J "1"^J ]^J ], ^J "padding": [^J [^J "1", ^J "1", ^J "1", ^J "1"^J ]^J ], ^J "kernel_layout": [^J [^J "OIHW"^J ]^J ], ^J "strides": [^J [^J "1", ^J "1"^J ]^J ], ^J "dilation": [^J [^J "1", ^J "1"^J ]^J ], ^J "dtype": [^J [^J "float32"^J ]^J ], ^J "kernel_size": [^J [^J ""^J ]^J ], ^J "data_layout": [^J [^J "NCHW"^J ]^J ], ^J "out_dtype": [^J [^J ""^J ]^J ], ^J "shape": [^J [^J [1, 32, 56, 56]^J ]^J ]^J }^J }^J ], ^J "arg_nodes": [0, 1], ^J "heads": [[^J 2, ^J 0, ^J 0]], ^J "node_row_ptr": [0, 1, 2, 3]^J}
[ 82b] dnnl_json^Z
[ 83c] tvmgen_default_dnnl_main_2^B
[ 85e] {^J "nodes": [^J {^J "op": "input", ^J "name": "dnnl_2_i0", ^J "attrs": {^J "dtype": [^J [^J "float32"^J ]^J ], ^J "shape": [^J [^J [1, 32, 56, 56]^J ]^J ]^J }^J }, ^J {^J "op": "kernel", ^J "name": "nn.relu", ^J "inputs": [[^J 0, ^J 0, ^J 0]], ^J "attrs": {^J "num_outputs": "1", ^J "shape": [^J [^J [1, 32, 56, 56]^J ]^J ], ^J "num_inputs": "1", ^J "dtype": [^J [^J "float32"^J ]^J ]^J }^J }^J ], ^J "arg_nodes": [0], ^J "heads": [[^J 1, ^J 0, ^J 0]], ^J "node_row_ptr": [0, 1, 2]^J}
[ b46] _lib^L
[ b52] _import_tree^E
Since the generated lib contains only the dnnl JSON data, with no code that consumes that JSON, the lib itself is not self-contained: the runtime must also carry the corresponding dnnl code.
hello_world/hello_tvm/graph_runner
1.2. BYOC Impl
1.2.1. Runtime
The runtime's main tasks are:
- locating the target runtime based on information embedded in the lib
- having the target runtime find the target codegen's output for an op (e.g. a JSON file) and execute it
1.2.1.1. load module
The runtime's first step is to load the .so with LoadFromFile. This ELF contains a __tvm_dev_mblob symbol, and the data stored at its address causes the corresponding target runtimes to be loaded. (Besides LoadFromFile there is also the system-lib approach: Micro TVM & BYOC.)
LoadFromFile(name, fmt):
  std::string fmt = GetFileFormat(file_name, format);
  std::string load_f_name = "runtime.module.loadfile_" + fmt;
  // assume fmt is "so"
  // dso_library.cc
  TVM_REGISTER_GLOBAL("runtime.module.loadfile_so")
      .set_body([](TVMArgs args, TVMRetValue* rv) {
        auto n = make_object<DSOLibrary>();
        n->Init(args[0]);
        *rv = CreateModuleFromLibrary(n);
      });

CreateModuleFromLibrary:
  // tvm_dev_mblob is __tvm_dev_mblob; the elf produced by build contains this
  // symbol. It actually shares the address of .rodata, which is why
  // `readelf -p .rodata liba.so` above shows all the json data.
  const char* dev_mblob = reinterpret_cast<const char*>(
      lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
  ProcessModuleBlob(dev_mblob, lib, &root_mod, &dso_ctx_addr);
    for (uint64_t i = 0; i < size; ++i) {
      // for dnnl the tkey is dnnl_json; a blob can contain several different
      // tkeys, so several different target runtimes may be loaded, and each
      // tkey (or function name) corresponds to a separate Module (i.e. a
      // target runtime)
      std::string tkey;
      ICHECK(stream->Read(&tkey));
      auto m = LoadModuleFromBinary(tkey, stream);
      // the final result may be a root_module containing several modules,
      // each of which implements exactly one symbol
      modules.emplace_back(m);
    }

LoadModuleFromBinary("dnnl_json", stream):
  std::string loadkey = "runtime.module.loadbinary_";
  std::string fkey = loadkey + type_key;
  const PackedFunc* f = Registry::Get(fkey);
  // the assembled name is runtime.module.loadbinary_dnnl_json, which is
  // registered in dnnl_json_runtime.cc
  // the symbol names here match the names in the graph, e.g.
  // tvmgen_default_dnnl_main_0, and graph_json is the json data dnnl needs
  // to implement tvmgen_default_dnnl_main_0
  TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json")
      .set_body_typed(JSONRuntimeBase::LoadFromBinary<DNNLJSONRuntime>);
  ICHECK(stream->Read(&symbol)) << "Loading symbol name failed";
  ICHECK(stream->Read(&graph_json)) << "Loading graph json failed";
  ICHECK(stream->Read(&consts)) << "Loading the const name list failed";
  auto n = make_object<T>(symbol, graph_json, const_names);
  return Module(n);
1.2.1.1.1. run the module
The module returned by LoadFromFile is already runnable: for example, module.GetFunction("tvmgen_default_dnnl_main_0") lets dnnl execute the graph behind the tvmgen_default_dnnl_main_0 symbol.
Different modules implement GetFunction differently, for example:
- for LibraryModuleNode, GetFunction is essentially a wrapper around dlsym
- for DNNLJSONRuntime, GetFunction simply compares the requested name with its own symbol, since each dnnl symbol_name_ has its own DNNLJSONRuntime
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override {
  if (name == "get_symbol") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; });
  } else if (this->symbol_name_ == name) {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      ICHECK(this->initialized_) << "The module has not been initialized";
      this->SetInputOutputBuffers(args);
      // !! Run() is where the real interaction with dnnl happens
      this->Run();
    });
  }
  // ....
}
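Seen from Python, the same mechanism looks like this; a sketch reusing the /tmp/liba.so exported above (get_function with query_imports=True also searches the imported dnnl_json modules):
lib = tvm.runtime.load_module("/tmp/liba.so")
for m in lib.imported_modules:
    print(m.type_key)        # the dnnl subgraphs show up as "dnnl_json" modules
# resolves to the PackedFunc returned by DNNLJSONRuntime::GetFunction above
f = lib.get_function("tvmgen_default_dnnl_main_0", query_imports=True)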
1.2.1.2. TVM Graph Executor
1.2.1.3. To Summarize
The lib produced by codegen embeds the __tvm_dev_mblob blob, whose type_key (e.g. dnnl_json) indicates which target runtime should handle each symbol.
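A quick way to check this dispatch from Python (a sketch): the loader resolves each tkey by looking up runtime.module.loadbinary_<tkey> in the global registry, so a target runtime is available exactly when that function exists:
print(tvm.get_global_func("runtime.module.loadbinary_dnnl_json", True) is not None)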
1.2.2. Relay Codegen
The main tasks of Relay codegen are:
- registering the target codegen
- partitioning and annotating the graph according to the ops the target supports
- saving the target codegen's output (e.g. in __tvm_dev_mblob) for the target runtime to use
1.2.2.1. AnnotateTarget
The goals of AnnotateTarget are:
- to mark nodes in the Relay IR according to which ops the Relay codegen supports, e.g. by attaching a Compiler=xxx tag
- to merge certain IR nodes as required by the Relay codegen (see the pattern-table sketch after the example below)
mod = get_demo_mod()
print(mod)
def @main(%d1: Tensor[(1, 32, 56, 56), float32], %w1: Tensor[(32, 32, 3, 3), float32], %b1: Tensor[(32), float32]) -> Tensor[(1, 32, 56, 56), float32] {
%0 = nn.conv2d(%d1, %w1, padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 32, 56, 56), float32] */;
%1 = nn.bias_add(%0, %b1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
nn.relu(%1) /* ty=Tensor[(1, 32, 56, 56), float32] */
}
@tvm.ir.register_op_attr("nn.conv2d", "target.my")
def _my_conv2d_wrapper(expr):
    return True
mod = relay.transform.AnnotateTarget("my")(mod)
mod = relay.transform.PartitionGraph()(mod)
print(mod)
def @main(%d1: Tensor[(1, 32, 56, 56), float32], %w1: Tensor[(32, 32, 3, 3), float32], %b1: Tensor[(32), float32]) -> Tensor[(1, 32, 56, 56), float32] {
%0 = @tvmgen_default_my_main_0(%d1, %w1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
%1 = nn.bias_add(%0, %b1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
nn.relu(%1) /* ty=Tensor[(1, 32, 56, 56), float32] */
}
def @tvmgen_default_my_main_0(%my_0_i0: Tensor[(1, 32, 56, 56), float32], %my_0_i1: Tensor[(32, 32, 3, 3), float32], Inline=1, Compiler="my", global_symbol="tvmgen_default_my_main_0", Primitive=1) -> Tensor[(1, 32, 56, 56), float32] {
nn.conv2d(%my_0_i0, %my_0_i1, padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 32, 56, 56), float32] */
}
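For the merging side, a codegen can register a pattern table so that MergeComposite fuses matching subgraphs into one composite function before annotation. A sketch for the hypothetical "my" target above (the pattern name "my.conv2d_relu" is illustrative):
from tvm.relay.dataflow_pattern import is_op, wildcard
from tvm.relay.op.contrib.register import register_pattern_table

@register_pattern_table("my")
def _my_patterns():
    # match nn.relu(nn.conv2d(*, *)) as a single composite
    conv_relu = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
    return [("my.conv2d_relu", conv_relu)]

mod = get_demo_mod()
seq = tvm.transform.Sequential([
    relay.transform.MergeComposite(_my_patterns()),
    relay.transform.AnnotateTarget("my"),
    relay.transform.PartitionGraph(),
])
mod = seq(mod)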
1.2.2.2. Registering the codegen
TVM_REGISTER_GLOBAL("relay.ext.dnnl").set_body_typed(DNNLCompiler);
1.2.2.3. Compiling external functions
After relay.transform.AnnotateTarget (with DNNL), the generated Relay carries an extra Compiler="dnnl" attribute on external functions (such as the dnnl ones); at build time, LowerExternalFunctions uses this attribute to invoke the matching compiler for each such function.
// te_compiler.cc
Array<tvm::runtime::Module> LowerExternalFunctions():
  // Compiler=xxx
  if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
    // code_gen="dnnl"
    auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
    // "relay.ext.dnnl"
    std::string ext_name = "relay.ext." + code_gen_name;
    auto pf = tvm::runtime::Registry::Get(ext_name);
    runtime::Module ext_mod = (*pf)(src_func);
    ret.push_back(ext_mod);
  }

TVM_REGISTER_GLOBAL("relay.ext.dnnl").set_body_typed(DNNLCompiler);

// ref is src_func
runtime::Module DNNLCompiler(const ObjectRef& ref) {
  ICHECK(ref->IsInstance<FunctionNode>());
  auto func = Downcast<Function>(ref);
  auto func_name = GetExtSymbol(func);
  // !! the serializer is the code that does the actual compilation: it
  // converts the function into dnnl json, yielding graph_json below
  DNNLJSONSerializer serializer(func_name, func);
  serializer.serialize();
  std::string graph_json = serializer.GetJSON();
  auto params = serializer.GetParams();
  // a dnnl runtime is created from graph_json here, not in order to run it,
  // but merely to reuse the dnnl runtime's SaveToBinary
  const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate");
  ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
  auto mod = (*pf)(func_name, graph_json, params);
  return mod;
}
1.2.2.4. SaveToBinary
tvm/docs/dev/introduction_to_module_serialization.rst
SerializeModule(dmlc::Stream* stream):
  for (const auto& group : mod_group_vec_) {
    // type_key is the tkey seen in the elf, e.g. dnnl_json
    std::string mod_type_key = group[0]->type_key();
    stream->Write(mod_type_key);
    group[0]->SaveToBinary(stream);
  }

JSONRuntimeBase::SaveToBinary(dmlc::Stream* stream) override {
  // Save the symbol
  stream->Write(symbol_name_);
  // Save the graph
  stream->Write(graph_json_);
  // Save the required const names
  std::vector<std::string> consts;
  for (const auto& it : const_names_) {
    consts.push_back(it);
  }
  stream->Write(consts);
}
1.2.3. To Summarize
Implementing a BYOC requires:
- declaring, in Python, which ops the codegen supports via the tvm.ir.register_op_attr and register_pattern_table annotations; at build time the compile_engine uses them to find the matching codegen, e.g. tvm/python/tvm/relay/op/contrib/dnnl.py::_register_external_op_helper("nn.batch_norm")
- implementing a codegen and registering it as relay.ext.xxx, e.g. tvm/src/relay/backend/contrib/dnnl/codegen.cc::runtime::Module DNNLCompiler(const ObjectRef& ref)
- implementing a runtime, whose main functions are LoadFromBinary, SaveToBinary, and GetFunction, e.g. tvm/src/runtime/contrib/dnnl/dnnl_json_runtime.cc::TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json")
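Putting the pieces together for dnnl (a sketch; it assumes TVM was built with the DNNL codegen and runtime enabled):
from tvm.relay.op.contrib import dnnl            # registers the target.dnnl op attrs
mod = get_demo_mod()
mod = relay.transform.AnnotateTarget("dnnl")(mod)
mod = relay.transform.PartitionGraph()(mod)
with tvm.transform.PassContext(opt_level=2):
    graph, lib, params = relay.build(mod, target="c", params=None)
lib.export_library("/tmp/liba.so")               # SaveToBinary embeds the dnnl_json blobs
lib2 = tvm.runtime.load_module("/tmp/liba.so")   # loadbinary_dnnl_json restores them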
Backlinks
OpenMP (OpenMP > libgomp > target): the offload process is similar to DPC++, TVM BYOC Codegen, and ComputeCpp; taking nvptx as an example, the main steps are:
