TVM BYOC Codegen

Table of Contents

1. TVM BYOC Codegen

1.1. DNNL BYOC Example

1.1.1. 不使用 DNNL

import tvm
from tvm import relay


def get_demo_mod():
    """Build the demo Relay module: conv2d (3x3, pad 1) -> bias_add -> relu.

    Inputs are the free variables ``d1`` (1, 32, 56, 56), ``w1`` (32, 32, 3, 3)
    and ``b1`` (32,), all float32. Type inference is run before returning so
    every expression in the printed IR carries its checked type.
    """
    f32 = "float32"
    data = relay.var("d1", shape=(1, 32, 56, 56), dtype=f32)
    weight = relay.var("w1", shape=(32, 32, 3, 3), dtype=f32)
    bias = relay.var("b1", shape=(32,), dtype=f32)

    # Inner calls are evaluated first, so nodes are created conv -> bias -> relu.
    out = relay.nn.relu(
        relay.nn.bias_add(
            relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1)),
            bias,
        )
    )
    module = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))
    return relay.transform.InferType()(module)


# Build the demo module and print its Relay IR (output shown below).
mod = get_demo_mod()
print(mod)
def @main(%d1: Tensor[(1, 32, 56, 56), float32], %w1: Tensor[(32, 32, 3, 3), float32], %b1: Tensor[(32), float32]) -> Tensor[(1, 32, 56, 56), float32] {
  %0 = nn.conv2d(%d1, %w1, padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 32, 56, 56), float32] */;
  %1 = nn.bias_add(%0, %b1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
  nn.relu(%1) /* ty=Tensor[(1, 32, 56, 56), float32] */
}
# Compile for the "c" target so the generated kernel can be inspected as C source.
# NOTE: unpacking (graph, lib, params) is the legacy graph-executor API; TVM emits
# a DeprecationWarning for it (see the captured warning below).
with tvm.transform.PassContext(opt_level=2):
    graph, lib, params = relay.build(mod, target="c", params=None)

print(graph)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
{
  "nodes": [
    {
      "op": "null",
      "name": "d1",
      "inputs": []
    },
    {
      "op": "null",
      "name": "w1",
      "inputs": []
    },
    {
      "op": "null",
      "name": "b1",
      "inputs": []
    },
    {
      "op": "tvm_op",
      "name": "tvmgen_default_fused_nn_conv2d_nn_bias_add_nn_relu",
      "attrs": {
        "num_outputs": "1",
        "flatten_data": "0",
        "out_layout": "",
        "func_name": "tvmgen_default_fused_nn_conv2d_nn_bias_add_nn_relu",
        "hash": "45f3f3bf6ba0201f",
        "num_inputs": "3",
        "kernel_layout": "OIHW",
        "data_layout": "NCHW"
      },
      "inputs": [
        [
          0,
          0,
          0
        ],
        [
          1,
          0,
          0
        ],
        [
          2,
          0,
          0
        ]
      ]
    }
  ],
  "arg_nodes": [0, 1, 2],
  "heads": [
    [
      3,
      0,
      0
    ]
  ],
  "attrs": {
    "dltype": [
      "list_str",
      [
        "float32",
        "float32",
        "float32",
        "float32"
      ]
    ],
    "shape": [
      "list_shape",
      [
        [1, 32, 56, 56],
        [32, 32, 3, 3],
        [32],
        [1, 32, 56, 56]
      ]
    ],
    "storage_id": [
      "list_int",
      [0, 1, 2, 3]
    ]
  },
  "node_row_ptr": [0, 1, 2, 3, 4]
}
/tmp/ipykernel_8688/48896817.py:2: DeprecationWarning: legacy graph executor behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_executor.GraphModule for the  new recommended usage.
  graph, lib, params = relay.build(mod, target="c", params=None)
# Show the C source TVM generated for the fused conv2d + bias_add + relu kernel.
print(lib.get_source())
// tvm target: c -keys=cpu -link-params=0
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_nn_conv2d_nn_bias_add_nn_relu(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
  void* arg0 = (((TVMValue*)args)[0].v_handle);
  int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
  void* arg1 = (((TVMValue*)args)[1].v_handle);
  int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)];
  void* arg2 = (((TVMValue*)args)[2].v_handle);
  int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)];
  void* arg3 = (((TVMValue*)args)[3].v_handle);
  int32_t arg3_code = ((int32_t*)arg_type_ids)[(3)];
  void* placeholder = (((DLTensor*)arg0)[0].data);
  void* arg0_shape = (((DLTensor*)arg0)[0].shape);
  void* arg0_strides = (((DLTensor*)arg0)[0].strides);
  int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
  void* placeholder1 = (((DLTensor*)arg1)[0].data);
  void* arg1_shape = (((DLTensor*)arg1)[0].shape);
  void* arg1_strides = (((DLTensor*)arg1)[0].strides);
  void* placeholder2 = (((DLTensor*)arg2)[0].data);
  void* arg2_shape = (((DLTensor*)arg2)[0].shape);
  void* arg2_strides = (((DLTensor*)arg2)[0].strides);
  void* T_relu = (((DLTensor*)arg3)[0].data);
  void* arg3_shape = (((DLTensor*)arg3)[0].shape);
  void* arg3_strides = (((DLTensor*)arg3)[0].strides);
  if (!(arg0_strides == NULL)) {
  }
  if (!(arg1_strides == NULL)) {
  }
  if (!(arg2_strides == NULL)) {
  }
  if (!(arg3_strides == NULL)) {
  }
  void* data_vec = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)401408, 2, 32);
  if (data_vec == NULL) {
    return -1;
  }
  void* data_pad = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)430592, 2, 32);
  if (data_pad == NULL) {
    return -1;
  }
  for (int32_t bs_c_fused_h_fused = 0; bs_c_fused_h_fused < 224; ++bs_c_fused_h_fused) {
    for (int32_t w = 0; w < 56; ++w) {
      for (int32_t vc = 0; vc < 8; ++vc) {
        ((float*)data_vec)[((((bs_c_fused_h_fused * 448) + (w * 8)) + vc))] = ((float*)placeholder)[((((((bs_c_fused_h_fused / 56) * 25088) + (vc * 3136)) + ((bs_c_fused_h_fused % 56) * 56)) + w))];
      }
    }
  }
  for (int32_t i0_i1_fused_i2_fused = 0; i0_i1_fused_i2_fused < 232; ++i0_i1_fused_i2_fused) {
    for (int32_t i3 = 0; i3 < 58; ++i3) {
      ((float8*)((float*)data_pad + (((i0_i1_fused_i2_fused * 464) + (i3 * 8)))))[0] = (((((1 <= (i0_i1_fused_i2_fused % 58)) && ((i0_i1_fused_i2_fused % 58) < 57)) && (1 <= i3)) && (i3 < 57)) ? ((float8*)((float*)data_vec + ((((((i0_i1_fused_i2_fused / 58) * 25088) + ((i0_i1_fused_i2_fused % 58) * 448)) + (i3 * 8)) - 456))))[0] : ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f)));
    }
  }
  for (int32_t occ_k_h_fused = 0; occ_k_h_fused < 12; ++occ_k_h_fused) {
    for (int32_t icc = 0; icc < 4; ++icc) {
      for (int32_t k_w = 0; k_w < 3; ++k_w) {
        for (int32_t icb = 0; icb < 8; ++icb) {
          int32_t8 _1 = (int8)((((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*0), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*1), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*2), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*3), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*4), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*5), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*6), (((((((occ_k_h_fused / 3) * 2304) + (icc * 72)) + (icb * 9)) + ((occ_k_h_fused % 3) * 3)) + k_w))+(288*7));
          ((float8*)((float*)data_vec + (((((((occ_k_h_fused / 3) * 2304) + (icc * 576)) + ((occ_k_h_fused % 3) * 192)) + (k_w * 64)) + (icb * 8)))))[0] = ((float8)(((float*)placeholder1)[_1.s0],((float*)placeholder1)[_1.s1],((float*)placeholder1)[_1.s2],((float*)placeholder1)[_1.s3],((float*)placeholder1)[_1.s4],((float*)placeholder1)[_1.s5],((float*)placeholder1)[_1.s6],((float*)placeholder1)[_1.s7]));
        }
      }
    }
  }
  for (int32_t ax0_ax1_outer_fused_ax2_fused = 0; ax0_ax1_outer_fused_ax2_fused < 224; ++ax0_ax1_outer_fused_ax2_fused) {
    void* conv2d_NCHWc = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)1792, 2, 32);
    if (conv2d_NCHWc == NULL) {
      return -1;
    }
    void* conv2d_NCHWc_global = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)896, 2, 32);
    if (conv2d_NCHWc_global == NULL) {
      return -1;
    }
    for (int32_t ow_outer = 0; ow_outer < 2; ++ow_outer) {
      ((float8*)((float*)conv2d_NCHWc_global + (0)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
      ((float8*)((float*)conv2d_NCHWc_global + (8)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
      ((float8*)((float*)conv2d_NCHWc_global + (16)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
      ((float8*)((float*)conv2d_NCHWc_global + (24)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
      ((float8*)((float*)conv2d_NCHWc_global + (32)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
      ((float8*)((float*)conv2d_NCHWc_global + (40)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
      ((float8*)((float*)conv2d_NCHWc_global + (48)))[0] = ((float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f));
      //...
      for (int32_t ic_outer = 0; ic_outer < 4; ++ic_outer) {
        for (int32_t kh = 0; kh < 3; ++kh) {
          for (int32_t kw = 0; kw < 3; ++kw) {
            for (int32_t ic_inner = 0; ic_inner < 8; ++ic_inner) {
              ((float8*)((float*)conv2d_NCHWc_global + (0)))[0] = (((float8*)((float*)conv2d_NCHWc_global + (0)))[0] + (((float8)(((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))], ((float*)data_pad)[(((((((ic_outer * 26912) + (kh * 464)) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 464)) + (ow_outer * 224)) + (kw * 8)) + ic_inner))])) * ((float8*)((float*)data_vec + (((((((ax0_ax1_outer_fused_ax2_fused / 56) * 2304) + (ic_outer * 576)) + (kh * 192)) + (kw * 64)) + (ic_inner * 8)))))[0]));
              //...
          }
        }
      }
      for (int32_t ow_inner = 0; ow_inner < 28; ++ow_inner) {
        ((float8*)((float*)conv2d_NCHWc + (((ow_outer * 224) + (ow_inner * 8)))))[0] = ((float8*)((float*)conv2d_NCHWc_global + ((ow_inner * 8))))[0];
      }
    }
    for (int32_t ax3_outer = 0; ax3_outer < 2; ++ax3_outer) {
      for (int32_t ax3_inner = 0; ax3_inner < 28; ++ax3_inner) {
          int32_t8 _2 = (int8)(((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*0), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*1), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*2), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*3), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*4), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*5), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*6), ((((((ax0_ax1_outer_fused_ax2_fused / 56) * 25088) + ((ax0_ax1_outer_fused_ax2_fused % 56) * 56)) + (ax3_outer * 28)) + ax3_inner))+(3136*7));
          float8 _3 = ((float8*)((float*)conv2d_NCHWc + (((ax3_outer * 224) + (ax3_inner * 8)))))[0] + ((float8*)((float*)placeholder2 + (((ax0_ax1_outer_fused_ax2_fused / 56) * 8))))[0];
          float8 _4 = (float8)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f);
          float8 _5 = (_3) > (_4) ? (_3) : (_4);
          ((float*)T_relu)[_2.s0] = _5.s0;
          ((float*)T_relu)[_2.s1] = _5.s1;
          ((float*)T_relu)[_2.s2] = _5.s2;
          ((float*)T_relu)[_2.s3] = _5.s3;
          ((float*)T_relu)[_2.s4] = _5.s4;
          ((float*)T_relu)[_2.s5] = _5.s5;
          ((float*)T_relu)[_2.s6] = _5.s6;
          ((float*)T_relu)[_2.s7] = _5.s7;
      }
    }
    if (TVMBackendFreeWorkspace(1, dev_id, conv2d_NCHWc_global) != 0) {
      return -1;
    }
    if (TVMBackendFreeWorkspace(1, dev_id, conv2d_NCHWc) != 0) {
      return -1;
    }
  }
  if (TVMBackendFreeWorkspace(1, dev_id, data_pad) != 0) {
    return -1;
  }
  if (TVMBackendFreeWorkspace(1, dev_id, data_vec) != 0) {
    return -1;
  }
  return 0;
}

1.1.2. 使用 DNNL

# Only run the DNNL example when TVM was built with the DNNL codegen, i.e. when a
# "relay.ext.dnnl" global function is registered. Previously the message was
# printed but annotation/partitioning still ran; now the work is actually skipped.
if not tvm.get_global_func("relay.ext.dnnl", True):
    print("skip because DNNL codegen is not available")
else:
    mod = get_demo_mod()
    # Mark the ops DNNL supports, then split them out into separate functions
    # carrying a Compiler="dnnl" attribute.
    mod = relay.transform.AnnotateTarget("dnnl")(mod)
    mod = relay.transform.PartitionGraph()(mod)
    print(mod)

def @main(%d1: Tensor[(1, 32, 56, 56), float32], %w1: Tensor[(32, 32, 3, 3), float32], %b1: Tensor[(32), float32]) -> Tensor[(1, 32, 56, 56), float32] {
  %0 = @tvmgen_default_dnnl_main_0(%d1, %w1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
  %1 = nn.bias_add(%0, %b1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
  @tvmgen_default_dnnl_main_2(%1) /* ty=Tensor[(1, 32, 56, 56), float32] */
}

def @tvmgen_default_dnnl_main_0(%dnnl_0_i0: Tensor[(1, 32, 56, 56), float32], %dnnl_0_i1: Tensor[(32, 32, 3, 3), float32], Inline=1, Compiler="dnnl", global_symbol="tvmgen_default_dnnl_main_0", Primitive=1) -> Tensor[(1, 32, 56, 56), float32] {
  nn.conv2d(%dnnl_0_i0, %dnnl_0_i1, padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 32, 56, 56), float32] */
}

def @tvmgen_default_dnnl_main_2(%dnnl_2_i0: Tensor[(1, 32, 56, 56), float32], Inline=1, Compiler="dnnl", global_symbol="tvmgen_default_dnnl_main_2", Primitive=1) -> Tensor[(1, 32, 56, 56), float32] {
  nn.relu(%dnnl_2_i0) /* ty=Tensor[(1, 32, 56, 56), float32] */
}
# Build the partitioned module: the Compiler="dnnl" functions go through the
# external codegen, the remaining ops (here nn.bias_add) through the "c" target.
with tvm.transform.PassContext(opt_level=2):
    graph, lib, params = relay.build(mod, target="c", params=None)

# Export the shared library that is inspected with readelf further down.
lib.export_library("/tmp/liba.so")
# imported_modules[0] is the C source module holding the host-side kernels.
print(lib.imported_modules[0].get_source())
// tvm target: c -keys=cpu -link-params=0
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_nn_bias_add(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
  void* arg0 = (((TVMValue*)args)[0].v_handle);
  int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
  void* arg1 = (((TVMValue*)args)[1].v_handle);
  int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)];
  void* arg2 = (((TVMValue*)args)[2].v_handle);
  int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)];
  void* placeholder = (((DLTensor*)arg0)[0].data);
  void* arg0_shape = (((DLTensor*)arg0)[0].shape);
  void* arg0_strides = (((DLTensor*)arg0)[0].strides);
  int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
  void* placeholder1 = (((DLTensor*)arg1)[0].data);
  void* arg1_shape = (((DLTensor*)arg1)[0].shape);
  void* arg1_strides = (((DLTensor*)arg1)[0].strides);
  void* T_add = (((DLTensor*)arg2)[0].data);
  void* arg2_shape = (((DLTensor*)arg2)[0].shape);
  void* arg2_strides = (((DLTensor*)arg2)[0].strides);
  if (!(arg0_strides == NULL)) {
  }
  if (!(arg1_strides == NULL)) {
  }
  if (!(arg2_strides == NULL)) {
  }
  for (int32_t ax0_ax1_fused = 0; ax0_ax1_fused < 32; ++ax0_ax1_fused) {
    for (int32_t ax2 = 0; ax2 < 56; ++ax2) {
      for (int32_t ax3_outer = 0; ax3_outer < 4; ++ax3_outer) {
        for (int32_t ax3_inner_s = 0; ax3_inner_s < 16; ++ax3_inner_s) {
          if (((ax3_outer * 16) + ax3_inner_s) < 56) {
            ((float*)T_add)[(((((ax0_ax1_fused * 3136) + (ax2 * 56)) + (ax3_outer * 16)) + ax3_inner_s))] = (((float*)placeholder)[(((((ax0_ax1_fused * 3136) + (ax2 * 56)) + (ax3_outer * 16)) + ax3_inner_s))] + ((float*)placeholder1)[(ax0_ax1_fused)]);
          }
        }
      }
    }
  }
  return 0;
}


/tmp/ipykernel_38953/3957980515.py:2: DeprecationWarning: legacy graph executor behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_executor.GraphModule for the  new recommended usage.
  graph, lib, params = relay.build(mod, target="c", params=None)
# imported_modules[1] is one of the DNNL JSON runtime modules; its "source" is
# the serialized graph JSON (here the one for tvmgen_default_dnnl_main_2).
print(lib.imported_modules[1].get_source())
{
  "nodes": [
    {
      "op": "input",
      "name": "dnnl_2_i0",
      "attrs": {
        "dtype": [
          [
            "float32"
          ]
        ],
        "shape": [
          [
            [1, 32, 56, 56]
          ]
        ]
      }
    },
    {
      "op": "kernel",
      "name": "nn.relu",
      "inputs": [[
          0,
          0,
          0]],
      "attrs": {
        "num_outputs": "1",
        "shape": [
          [
            [1, 32, 56, 56]
          ]
        ],
        "num_inputs": "1",
        "dtype": [
          [
            "float32"
          ]
        ]
      }
    }
  ],
  "arg_nodes": [0],
  "heads": [[
      1,
      0,
      0]],
  "node_row_ptr": [0, 1, 2]
}
# Dump the strings stored in the .rodata section of the exported library.
readelf -p .rodata /tmp/liba.so
String dump of section '.rodata':
  [    18]  metadata
  [    40]  tvmgen_default_dnnl_main_0^Z
  [    62]  tvmgen_default_dnnl_main_2^B
  [    9c]  dnnl_json^Z
  [    ad]  tvmgen_default_dnnl_main_0L^G
  [    cf]  {^J  "nodes": [^J    {^J      "op": "input", ^J      "name": "dnnl_0_i0", ^J      "attrs": {^J        "dtype": [^J          [^J            "float32"^J          ]^J        ], ^J        "shape": [^J          [^J            [1, 32, 56, 56]^J          ]^J        ]^J      }^J    }, ^J    {^J      "op": "input", ^J      "name": "dnnl_0_i1", ^J      "attrs": {^J        "dtype": [^J          [^J            "float32"^J          ]^J        ], ^J        "shape": [^J          [^J            [32, 32, 3, 3]^J          ]^J        ]^J      }^J    }, ^J    {^J      "op": "kernel", ^J      "name": "nn.conv2d", ^J      "inputs": [[^J          0, ^J          0, ^J          0], [^J          1, ^J          0, ^J          0]], ^J      "attrs": {^J        "num_outputs": "1", ^J        "num_inputs": "2", ^J        "channels": [^J          [^J            ""^J          ]^J        ], ^J        "out_layout": [^J          [^J            ""^J          ]^J        ], ^J        "groups": [^J          [^J            "1"^J          ]^J        ], ^J        "padding": [^J          [^J            "1", ^J            "1", ^J            "1", ^J            "1"^J          ]^J        ], ^J        "kernel_layout": [^J          [^J            "OIHW"^J          ]^J        ], ^J        "strides": [^J          [^J            "1", ^J            "1"^J          ]^J        ], ^J        "dilation": [^J          [^J            "1", ^J            "1"^J          ]^J        ], ^J        "dtype": [^J          [^J            "float32"^J          ]^J        ], ^J        "kernel_size": [^J          [^J            ""^J          ]^J        ], ^J        "data_layout": [^J          [^J            "NCHW"^J          ]^J        ], ^J        "out_dtype": [^J          [^J            ""^J          ]^J        ], ^J        "shape": [^J          [^J            [1, 32, 56, 56]^J          ]^J        ]^J      }^J    }^J  ], ^J  "arg_nodes": [0, 1], ^J  "heads": [[^J      2, ^J      0, ^J      0]], ^J  "node_row_ptr": [0, 1, 2, 
3]^J}
  [   82b]  dnnl_json^Z
  [   83c]  tvmgen_default_dnnl_main_2^B
  [   85e]  {^J  "nodes": [^J    {^J      "op": "input", ^J      "name": "dnnl_2_i0", ^J      "attrs": {^J        "dtype": [^J          [^J            "float32"^J          ]^J        ], ^J        "shape": [^J          [^J            [1, 32, 56, 56]^J          ]^J        ]^J      }^J    }, ^J    {^J      "op": "kernel", ^J      "name": "nn.relu", ^J      "inputs": [[^J          0, ^J          0, ^J          0]], ^J      "attrs": {^J        "num_outputs": "1", ^J        "shape": [^J          [^J            [1, 32, 56, 56]^J          ]^J        ], ^J        "num_inputs": "1", ^J        "dtype": [^J          [^J            "float32"^J          ]^J        ]^J      }^J    }^J  ], ^J  "arg_nodes": [0], ^J  "heads": [[^J      1, ^J      0, ^J      0]], ^J  "node_row_ptr": [0, 1, 2]^J}
  [   b46]  _lib^L
  [   b52]  _import_tree^E

生成的 lib 中只能看到 dnnl 相应的 json 数据, 却找不到使用这段 json 的代码, 所以 lib 本身并非 self-contained: runtime 里也必须包含 dnnl 对应的代码 (即 DNNL JSON runtime) 才能执行它

hello_world/hello_tvm/graph_runner

1.2. BYOC Impl

1.2.1. Runtime

runtime 主要的任务是:

  • 根据某些信息定位到 target runtime
  • target runtime 找到 op 对应的 target codegen 的结果 (例如嵌入在 so 中的 json 数据) 并执行它

what is runtime

1.2.1.1. load module

runtime 的第一步是用 LoadFromFile 加载 so. 这个 elf 中有一个 __tvm_dev_mblob 符号, 它指向的数据 (位于 .rodata 段中) 决定了哪些 target runtime 会被加载进来 (除了 LoadFromFile 外, 还有 system-lib 方式: Micro TVM & BYOC)

LoadFromFile(name, fmt):
  std::string fmt = GetFileFormat(file_name, format);
  std::string load_f_name = "runtime.module.loadfile_" + fmt;
  // 假设 fmt 为 so
  // dso_library.cc
    TVM_REGISTER_GLOBAL("runtime.module.loadfile_so")
      .set_body([](TVMArgs args, TVMRetValue* rv) {
        auto n = make_object<DSOLibrary>();
        n->Init(args[0]);
        *rv = CreateModuleFromLibrary(n);
      });

CreateModuleFromLibrary :
  // tvm_dev_mblob 为 __tvm_dev_mblob, build 生成的 elf 中有这个符号
  // 实际上这个符号与 .rodata 的地址相同, 所以前面用 readelf -p .rodata
  // liba.so 能看到许多 json 数据:
  const char* dev_mblob = reinterpret_cast<const char*>(
        lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
  ProcessModuleBlob(dev_mblob, lib, &root_mod, &dso_ctx_addr);
    for (uint64_t i = 0; i < size; ++i) {
      // 对 dnnl 来说, tkey 为 dnnl_json
      // blob 中可以包含多个不同的 tkey, 导致可以有多个不同的 target runtime
      // 被加载, 每个 tkey (或 function name) 对应一个单独的 Module (或者叫 target
      // runtime)
      std::string tkey;
      ICHECK(stream->Read(&tkey));
      auto m = LoadModuleFromBinary(tkey, stream);
      // 最终可能会返回一个包含多个 module 的 root_module, 每个 module 都只会包
      // 含一个 symbol 的实现
      modules.emplace_back(m);
    }

LoadModuleFromBinary("dnnl_json", stream):
  std::string loadkey = "runtime.module.loadbinary_";
  std::string fkey = loadkey + type_key;
  const PackedFunc* f = Registry::Get(fkey);
  // 拼成的函数是 runtime.module.loadbinary_dnnl_json, 这个函数在
  // dnnl_json_runtime.cc 中定义
  // 这里的 symbol 名与 graph 中的名字是一致的, 例如 tvmgen_default_dnnl_main_0,
  // graph_json 中的 json 数据则是 dnnl 实现 tvmgen_default_dnnl_main_0 需要的 json 数据
    TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json")
      .set_body_typed(JSONRuntimeBase::LoadFromBinary<DNNLJSONRuntime>);
      ICHECK(stream->Read(&symbol)) << "Loading symbol name failed";
      ICHECK(stream->Read(&graph_json)) << "Loading graph json failed";
      ICHECK(stream->Read(&consts)) << "Loading the const name list failed";
      auto n = make_object<T>(symbol, graph_json, const_names);
      return Module(n);

1.2.1.2. run the module

LoadFromFile 返回的 module 已经可以直接运行: 例如通过 module.GetFunction("tvmgen_default_dnnl_main_0") 拿到的函数会通过 dnnl 执行 tvmgen_default_dnnl_main_0 这个 symbol 对应的 graph

不同的 module 有不同的 GetFunction 实现, 例如:

  • 对于 LibraryModuleNode, GetFunction 实际就是对 dlsym 的封装
  • 对于 DNNLJSONRuntime, GetFunction 只是检查请求的 name 是否等于自己的 symbol_name_, 因为每个 DNNLJSONRuntime 只对应一个 dnnl symbol

    // DNNLJSONRuntime::GetFunction (excerpt): each runtime instance serves a single
    // symbol, so the lookup is a string comparison against symbol_name_.
    PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override {
        if (name == "get_symbol") {
          // Report which symbol this runtime implements.
          return PackedFunc(
              [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; });
        } else if (this->symbol_name_ == name) {
          return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
            ICHECK(this->initialized_) << "The module has not been initialized";
    
            // Presumably binds the caller's argument tensors to the graph's
            // inputs/outputs (inferred from the name; not shown in this excerpt).
            this->SetInputOutputBuffers(args);
            // !! Run() is where the runtime actually talks to DNNL.
            this->Run();
          });
        }
        // ....
      }
    
    
1.2.1.3. To Summarize

codegen 时生成的 lib 里包含了 __tvm_dev_mblob 信息, 其中 type_key (例如 dnnl_json) 用来指示 symbol 需要由哪个 target runtime 来处理

1.2.2. Relay Codegen

relay codegen 主要的任务:

  • 注册 target codegen
  • 根据 target 支持的 op 对 graph 进行 partition 和 annotate
  • 保存 target codegen 的结果 (例如 __tvm_dev_mblob) 以便 target runtime 使用
1.2.2.1. AnnotateTarget

AnnotateTarget 的目的是:

  1. 根据 Relay Codegen 支持哪些 op, 在 Relay IR 中给这些 op 加上 annotation (compiler_begin / compiler_end); 随后 PartitionGraph 会把被标注的区域切分成带 Compiler=xxx 属性的独立函数
  2. 根据 Relay Codegen 的要求合并某些 IR
# Start again from the unpartitioned demo module.
mod = get_demo_mod()
print(mod)
def @main(%d1: Tensor[(1, 32, 56, 56), float32], %w1: Tensor[(32, 32, 3, 3), float32], %b1: Tensor[(32), float32]) -> Tensor[(1, 32, 56, 56), float32] {
  %0 = nn.conv2d(%d1, %w1, padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 32, 56, 56), float32] */;
  %1 = nn.bias_add(%0, %b1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
  nn.relu(%1) /* ty=Tensor[(1, 32, 56, 56), float32] */
}

# Declare that the toy target "my" supports nn.conv2d by registering the
# "target.my" op attribute, which AnnotateTarget("my") below consults.
@tvm.ir.register_op_attr("nn.conv2d", "target.my")
def _my_conv2d_wrapper(expr):
    """Return True so that every nn.conv2d call is offloaded to target "my"."""
    return True

# Annotate the supported ops, then split them into functions with Compiler="my".
mod = relay.transform.AnnotateTarget("my")(mod)
mod = relay.transform.PartitionGraph()(mod)
print(mod)
def @main(%d1: Tensor[(1, 32, 56, 56), float32], %w1: Tensor[(32, 32, 3, 3), float32], %b1: Tensor[(32), float32]) -> Tensor[(1, 32, 56, 56), float32] {
  %0 = @tvmgen_default_my_main_0(%d1, %w1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
  %1 = nn.bias_add(%0, %b1) /* ty=Tensor[(1, 32, 56, 56), float32] */;
  nn.relu(%1) /* ty=Tensor[(1, 32, 56, 56), float32] */
}

def @tvmgen_default_my_main_0(%my_0_i0: Tensor[(1, 32, 56, 56), float32], %my_0_i1: Tensor[(32, 32, 3, 3), float32], Inline=1, Compiler="my", global_symbol="tvmgen_default_my_main_0", Primitive=1) -> Tensor[(1, 32, 56, 56), float32] {
  nn.conv2d(%my_0_i0, %my_0_i1, padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 32, 56, 56), float32] */
}

1.2.2.2. 注册 codegen
// Register DNNLCompiler as the "dnnl" external codegen; LowerExternalFunctions
// looks it up by the name "relay.ext." + <Compiler attribute>.
TVM_REGISTER_GLOBAL("relay.ext.dnnl").set_body_typed(DNNLCompiler);
1.2.2.3. 编译 external functions

执行 relay.transform.AnnotateTarget 和 PartitionGraph 之后 (使用 DNNL), 生成的 relay 中 external functions (如 dnnl) 会带有额外的 Compiler="dnnl" 属性, 编译时 LowerExternalFunctions 函数会根据这个属性调用对应的 compiler (relay.ext.dnnl) 来编译该 function

// te_compiler.cc
Array<tvm::runtime::Module> LowerExternalFunctions():
  // Only functions carrying a Compiler=xxx attribute are lowered here.
  if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
    // code_gen = "dnnl"
    auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
    std::string code_gen_name = code_gen.value();
    // ext_name = "relay.ext.dnnl"
    std::string ext_name = "relay.ext." + code_gen_name;
    auto pf = tvm::runtime::Registry::Get(ext_name);
    // Fails here when TVM was built without this codegen (this is what the
    // relay.ext.dnnl availability check in the Python example guards against).
    ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
    runtime::Module ext_mod = (*pf)(src_func);
    ret.push_back(ext_mod);

TVM_REGISTER_GLOBAL("relay.ext.dnnl").set_body_typed(DNNLCompiler);

// ref is the src_func passed in by LowerExternalFunctions above.
runtime::Module DNNLCompiler(const ObjectRef& ref) {
  ICHECK(ref->IsInstance<FunctionNode>());
  auto func = Downcast<Function>(ref);
  // External symbol name, e.g. tvmgen_default_dnnl_main_0.
  auto func_name = GetExtSymbol(func);
  // !! The serializer does the actual compilation: it converts the Relay
  // function into DNNL JSON, producing graph_json below.
  DNNLJSONSerializer serializer(func_name, func);
  serializer.serialize();
  std::string graph_json = serializer.GetJSON();
  auto params = serializer.GetParams();

  // A DNNL runtime module is created from graph_json here, not to run it but
  // so that the DNNL runtime's SaveToBinary can later serialize it into the lib.
  const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate");
  ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
  auto mod = (*pf)(func_name, graph_json, params);
  return mod;
}
1.2.2.4. SaveToBinary

tvm/docs/dev/introduction_to_module_serialization.rst

// Module serializer excerpt: writes each module group as <type_key, payload>.
SerializeModule(dmlc::Stream* stream):
  for (const auto& group : mod_group_vec_) {
    // type_key is the tkey stored in the ELF blob, e.g. dnnl_json.
    // NOTE(review): the readelf dump above also contains "_lib" and
    // "_import_tree" entries; the branches writing them are not shown here.
    std::string mod_type_key = group[0]->type_key();
    stream->Write(mod_type_key);
    // The module writes its own payload.
    group[0]->SaveToBinary(stream);

// JSON runtimes write three fields, in this order: symbol name, graph JSON,
// and the list of required constant names.
JSONRuntimeBase::SaveToBinary(dmlc::Stream* stream) override {
  // Save the symbol
  stream->Write(symbol_name_);
  // Save the graph
  stream->Write(graph_json_);
  // Save the required const names
  std::vector<std::string> consts;
  for (const auto& it : const_names_) {
    consts.push_back(it);
  }
  stream->Write(consts);
}

1.2.3. To Summarize

实现一个 BYOC 需要:

  1. 在 python 中通过 tvm.ir.register_op_attr 和 register_pattern_table 两种 annotation 声明 codegen 支持哪些 op (或 op pattern). 后续 AnnotateTarget / MergeComposite / PartitionGraph 会据此把这些 op 切分成带 Compiler=xxx 属性的函数, 编译时再根据该属性找到对应的 codegen

    tvm/python/tvm/relay/op/contrib/dnnl.py::_register_external_op_helper("nn.batch_norm")

  2. 实现一个 codegen, 并通过 relay.ext.xxx 注册

    tvm/src/relay/backend/contrib/dnnl/codegen.cc::runtime::Module DNNLCompiler(const ObjectRef& ref)

  3. 实现一个 runtime, 主要的函数是 LoadFromBinary, SaveToBinary, GetFunction

    tvm/src/runtime/contrib/dnnl/dnnl_json_runtime.cc::TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json")

Backlinks

OpenMP (OpenMP > libgomp > target): offload 过程与 DPC++, TVM BYOC Codegen 以及 ComputeCpp 类似, 以 nvptx 为例, 主要步骤是:

Author: [email protected]
Date: 2021-08-03 Tue 00:00
Last updated: 2023-12-01 Fri 18:28

知识共享许可协议