TVM Graph Executor

Table of Contents

1. TVM Graph Executor

对于一个神经网络 (nn) 来说, codegen 会把它编译成多个 function, 这些 function 之间的调用关系 (数据依赖) 由 relay.build 返回的 graph (一个 JSON) 来描述.

graph executor 的作用是按照 graph 中 node 的顺序, 依次调用各个 module 中对应的 function (即 graph 中的 tvm_op 节点)

1.1. Example

import numpy as np
import tvm
from tvm import relay

# Build out = (a + b) + c over three (1, 10) float32 inputs.
a = relay.var("a", shape=(1, 10), dtype="float32")
b = relay.var("b", shape=(1, 10), dtype="float32")
c = relay.var("c", shape=(1, 10), dtype="float32")
out = relay.add(a, b)
out = relay.add(out, c)

func = relay.Function([a, b, c], out)
mod = tvm.IRModule.from_expr(func)

print(mod)
with tvm.transform.PassContext(opt_level=0):
    # relay.build returns a single factory module. Unpacking it into
    # (graph, lib, params) is deprecated legacy behavior and triggers a
    # DeprecationWarning, so keep the module and read the graph JSON from it.
    lib = relay.build(mod, target="c", params=None)
print(lib.graph_json)

# x, y, z = np.ones((1, 10)), np.ones((1, 10)), np.ones((1, 10))
# intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm")
# op_res = intrp.evaluate(func)(x, y, z)

def @main(%a: Tensor[(1, 10), float32], %b: Tensor[(1, 10), float32], %c: Tensor[(1, 10), float32]) { %0 = add(%a, %b); add(%0, %c) }

{ "nodes": [ { "op": "null", "name": "a", "inputs": [] }, { "op": "null", "name": "b", "inputs": [] }, { "op": "null", "name": "c", "inputs": [] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add", "attrs": { "num_inputs": "2", "num_outputs": "1", "hash": "8ec3c08331cd92b5", "flatten_data": "0", "func_name": "tvmgen_default_fused_add" }, "inputs": [ [ 0, 0, 0 ], [ 1, 0, 0 ] ] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add1", "attrs": { "num_inputs": "2", "num_outputs": "1", "hash": "8ec3c08331cd92b5", "flatten_data": "0", "func_name": "tvmgen_default_fused_add" }, "inputs": [ [ 3, 0, 0 ], [ 2, 0, 0 ] ] } ], "arg_nodes": [0, 1, 2], "heads": [ [ 4, 0, 0 ] ], "attrs": { "dltype": [ "list_str", [ "float32", "float32", "float32", "float32", "float32" ] ], "shape": [ "list_shape", [ [1, 10], [1, 10], [1, 10], [1, 10], [1, 10] ] ], "storage_id": [ "list_int", [0, 1, 2, 3, 4] ] }, "node_row_ptr": [0, 1, 2, 3, 4, 5] } /tmp/ipykernel_218993/2066100685.py:16: DeprecationWarning: legacy graph executor behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_executor.GraphModule for the new recommended usage. graph, lib, params = relay.build(mod, target="c", params=None)

import numpy as np

# Create the inputs with the float32 dtype declared for a, b and c
# (np.ones would otherwise default to float64).
shape = (1, 10)
x = np.ones(shape, dtype="float32")
y = np.ones(shape, dtype="float32")
z = np.ones(shape, dtype="float32")
intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x, y, z)
print(op_res)

1.2. Impl

1.2.1. init

// Look up the globally registered "tvm.graph_executor.create" PackedFunc
// (presumably it forwards to GraphExecutorCreate shown below — confirm).
const runtime::PackedFunc* graph_executor = tvm::runtime::Registry::Get("tvm.graph_executor.create");
  // Create a GraphExecutor, initialize it from the graph JSON (sym_json), the
  // compiled module m and the target devices, and return it wrapped as a Module.
  // (Abridged excerpt: the function's braces are omitted.)
  Module GraphExecutorCreate(const std::string& sym_json, const tvm::runtime::Module& m,
                           const std::vector<Device>& devs,
                           const PackedFunc lookup_linked_param_func)
    auto exec = make_object<GraphExecutor>();
    exec->Init(sym_json, m, devs, lookup_linked_param_func);
    return Module(exec);

// Init (abridged): first set up storage for every graph entry, then create the
// per-node op executors that refer to that storage.
Init:
  this->SetupStorage();
  this->SetupOpExecs();
1.2.1.1. graph format
# 一共有 5 个 node
{
  "nodes": [
    {
      "op": "null", 
      "name": "a", 
      "inputs": []
    }, 
    {
      "op": "null", 
      "name": "b", 
      "inputs": []
    }, 
    {
      "op": "null", 
      "name": "c", 
      "inputs": []
    }, 
    {
      "op": "tvm_op", 
      "name": "tvmgen_default_fused_add", 
      "attrs": {
        "num_outputs": "1", 
        "num_inputs": "2", 
        "flatten_data": "0", 
        "func_name": "tvmgen_default_fused_add", 
        "hash": "8ec3c08331cd92b5"
      },
      # 第一个 add 操作有两个 input:
      # 第1个为 [0,0,0], 表示 node[0] 的 第 0 个输出, 即 a
      # 第2个为 [1,0,0], 表示 node[1] 的 第 0 个输出, 即 b
      # 这三个数的意义为 [node,index,version]
      "inputs": [
        [
          0, 
          0, 
          0
        ], 
        [
          1, 
          0, 
          0
        ]
      ]
    }, 
    {
      "op": "tvm_op", 
      "name": "tvmgen_default_fused_add1", 
      "attrs": {
        "num_outputs": "1", 
        "num_inputs": "2", 
        "flatten_data": "0", 
        "func_name": "tvmgen_default_fused_add", 
        "hash": "8ec3c08331cd92b5"
      },
      # 第二个 add 操作的两个 input:
      # 第1个为 [3,0,0], 表示 node[3] 的 第 0 个输出, 即第一个 add 操作的唯一的输出
      # 第2个为 [2,0,0], 表示 node[2] 的 第 0 个输出, 即 c
      # 这三个数的意义为 [node,index,version]      
      "inputs": [
        [
          3, 
          0, 
          0
        ], 
        [
          2, 
          0, 
          0
        ]
      ]
    }
  ],
  # arg_nodes 表示模型的输入 node, 即 a,b,c 三个 node 是输入
  "arg_nodes": [0, 1, 2],
  # heads 表示模型的输出, [4,0,0] 即 node[4] 的第 0 个输出, 即第二个 add 操作的输出
  "heads": [
    [
      4, 
      0, 
      0
    ]
  ], 
  "attrs": {
    "dltype": [
      "list_str", 
      [
        "float32", 
        "float32", 
        "float32", 
        "float32", 
        "float32"
      ]
    ],
    # storage_id 表示每个输出所使用的存储块 (storage) 的编号. 如果两个输出的生命周期
    # 不重叠 (不会同时被使用), 它们可能被分配相同的 storage_id 以节省内存. 由于
    # storage_id 在生成 graph 时已经计算好, SetupStorage 的代码会比较简单
    "storage_id": [
      "list_int", 
      [0, 1, 2, 3, 4]
    ], 
    "shape": [
      "list_shape", 
      [
        [1, 10], 
        [1, 10], 
        [1, 10], 
        [1, 10], 
        [1, 10]
      ]
    ]
  },
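  # node_row_ptr 用来快速给每一个输出 (即 [node, index]) 分配一个一维编号 (entry id),
  # 进而通过 storage_id[entry_id] 获得它对应的 storage. 例如
  # 三个 node 的输出个数分别为 1, 3, 1, 则 node_row_ptr 为
  # [0, 1, 4, 5]
  # 于是 entry_id(node, index) = node_row_ptr[node] + index, 例如:
  # [0,0] -> 0+0 = 0, [1,2] -> 1+2 = 3, [2,0] -> 4+0 = 4
  # 相当于把二维的 [node, index] 转换为一维, 参考 entry_id 函数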
  "node_row_ptr": [0, 1, 2, 3, 4, 5]
}

1.2.1.2. SetupStorage

SetupStorage 的作用基本上是:

1.2.1.2.1. device_index

根据 graph 中的 attrs.device_index 确定 expr 使用的 device, 将来会在 device 上分配 DLTensor

其中 attrs.device_index 是由 GraphExecutorCodegen 生成的 (根据 on_device annotation)

// For every storage pool entry, pick the device whose type matches the entry's
// device_type (falling back to devices_[0] when none matches) and allocate one
// flat float32 NDArray on it to back that storage.
for (const auto& pit : pool_entry) {
    // This for loop is very fast since there are usually only a couple of
    // devices available on the same hardware.
    const auto& cit =
        std::find_if(devices_.begin(), devices_.end(), [&pit](const Device& d) {
            return pit.device_type == static_cast<int>(d.device_type);
        });
    Device dev = cit == devices_.end() ? devices_[0] : *cit;

    // The buffer is typed float32 (4 bytes per element), so pit.size appears
    // to be a byte count that is rounded up to whole float32 elements.
    std::vector<int64_t> shape;
    shape.push_back(static_cast<int64_t>(pit.size + 3) / 4);
    storage_pool_.push_back(
        NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, dev));
}

/*!
 * Allocate an NDArray of the given shape and dtype on `dev`.
 * The NDArray container is created first. The device's DeviceAPI then
 * provides the backing data buffer (honouring the optional memory scope),
 * which is stored into the tensor's data pointer.
 */
NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev,
                       Optional<String> mem_scope) {
  NDArray result = Internal::Create(shape, dtype, dev);
  const auto target = result->device;
  auto device_api = DeviceAPI::Get(target);
  void* buffer = device_api->AllocDataSpace(target, shape.size(), shape.data(),
                                            result->dtype, mem_scope);
  result.get_mutable()->dl_tensor.data = buffer;
  return result;
}

// opencl example
/*!
 * Allocate a 1-D OpenCL buffer of `size` bytes and return it wrapped in a
 * heap-allocated cl::BufferDescriptor, passed back as an opaque void*.
 * `alignment` and `type_hint` are not used by this implementation.
 */
void* OpenCLWorkspace::AllocDataSpace(
    Device dev, size_t size, size_t alignment, DLDataType type_hint) {
    // clCreateBuffer reports failure only through err_code, so it must be
    // declared and checked; otherwise a descriptor holding an invalid buffer
    // would be returned to the caller.
    cl_int err_code = CL_SUCCESS;
    cl_mem buffer = clCreateBuffer(
        this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
    // Check before allocating the descriptor so a failure leaks nothing.
    ICHECK_EQ(err_code, CL_SUCCESS)
        << "clCreateBuffer failed with error code " << err_code;
    cl::BufferDescriptor* desc = new cl::BufferDescriptor;
    desc->buffer = buffer;
    desc->layout = cl::BufferDescriptor::MemoryLayout::kBuffer1D;
    return desc;
}
1.2.1.2.2. storage_id

根据 graph 中的 attrs.storage_id 给每一个输出设置一个 DLTensor, 保存在 data_entry_ 中

attrs.storage_id 也是由 GraphExecutorCodegen 生成的

// Give every node entry (one per graph output slot) a view onto the storage
// pool buffer selected by its storage_id; entries that share a storage_id
// therefore share the same memory.
data_entry_.resize(num_node_entries());
data_alignment_.resize(num_node_entries());
for (size_t i = 0; i < data_entry_.size(); ++i) {
    int storage_id = attrs_.storage_id[i];
    ICHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
    // View the shared buffer with this entry's own shape and dtype.
    // NOTE(review): vtype is not defined in this excerpt; presumably it holds
    // the per-entry dtypes parsed from attrs.dltype — confirm.
    data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);

    // NOTE(review): `tmp` is unused in this excerpt, and data_alignment_ is
    // resized above but never written here. The full source presumably
    // records each entry's alignment from `tmp` — confirm.
    const DLTensor* tmp = data_entry_[i].operator->();
}
  1. storage_id example
    import numpy as np
    import tvm
    from tvm import relay

    # out1 is only needed to compute out2, so its storage can be reused by the
    # final output (see storage_id in the printed graph).
    a = relay.var("a", shape=(1, 10), dtype="float32")
    b = relay.var("b", shape=(1, 10), dtype="float32")
    out1 = relay.add(a, b)
    out2 = relay.add(out1, a)
    out = relay.add(out2, b)

    func = relay.Function([a, b], out)
    mod = tvm.IRModule.from_expr(func)

    with tvm.transform.PassContext(opt_level=0):
        # Keep the factory module returned by relay.build; unpacking it into
        # (graph, lib, params) is deprecated legacy behavior.
        lib = relay.build(mod, target="c", params=None)
    print(lib.graph_json)
    
    

    { "nodes": [ { "op": "null", "name": "a", "inputs": [] }, { "op": "null", "name": "b", "inputs": [] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add", "attrs": { "num_outputs": "1", "num_inputs": "2", "flatten_data": "0", "func_name": "tvmgen_default_fused_add", "hash": "8ec3c08331cd92b5" }, "inputs": [ [ 0, 0, 0 ], [ 1, 0, 0 ] ] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add1", "attrs": { "num_outputs": "1", "num_inputs": "2", "flatten_data": "0", "func_name": "tvmgen_default_fused_add", "hash": "8ec3c08331cd92b5" }, "inputs": [ [ 2, 0, 0 ], [ 0, 0, 0 ] ] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add2", "attrs": { "num_outputs": "1", "num_inputs": "2", "flatten_data": "0", "func_name": "tvmgen_default_fused_add", "hash": "8ec3c08331cd92b5" }, "inputs": [ [ 3, 0, 0 ], [ 1, 0, 0 ] ] } ], "arg_nodes": [0, 1], "heads": [ [ 4, 0, 0 ] ], "attrs": { "dltype": [ "list_str", [ "float32", "float32", "float32", "float32", "float32" ] ], "storage_id": [ "list_int", [0, 1, 2, 3, 2] ], "shape": [ "list_shape", [ [1, 10], [1, 10], [1, 10], [1, 10], [1, 10] ] ] }, "node_row_ptr": [0, 1, 2, 3, 4, 5] }

    storage_id 为 [0, 1, 2, 3, 2], 因为 out1 (即 node[2] 的输出 [2,0,0]) 在计算完 out2 之后就不再被使用, 所以最终的输出 out (即 heads 中的 [4,0,0]) 可以复用 out1 的 storage, 两者使用相同的 storage_id 2 (各自仍有自己的 DLTensor view)

1.2.1.3. GraphExecutorCodegen

SetupStorage 所需的 graph 信息, 包括 device_index 和 storage_id, 都来自 GraphExecutorCodegen, 具体来说, 来自 StaticMemoryPlan

例如:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-08 18:19
import tvm
from tvm import relay
from tvm.contrib import graph_executor
import numpy as np


def test_on_device():
    """Annotate x + y with on_device(opencl), apply log to the result, build for
    {cpu: llvm, opencl: opencl}, and print the Relay function and graph JSON."""
    shape = (1, 5)
    x = relay.var("x", shape=shape)
    y = relay.var("y", shape=shape)
    x_data = np.random.rand(*shape).astype("float32")
    y_data = np.random.rand(*shape).astype("float32")

    cpu_dev = tvm.device("cpu")  # not used below
    opencl_dev = tvm.device("opencl")

    def build_relay_func():
        # Only the add is annotated with a device; log is left unannotated.
        summed = relay.annotation.on_device(relay.add(x, y), opencl_dev)
        return relay.Function([x, y], relay.log(summed))

    func = build_relay_func()
    print(func)

    targets = {"cpu": "llvm", "opencl": "opencl"}
    bound_params = {"x": x_data, "y": y_data}
    with tvm.transform.PassContext(opt_level=1):
        lib = relay.build(func, targets, params=bound_params)
        print(lib.graph_json)
        # rt_mod = graph_executor.GraphModule(
        #     lib["default"](tvm.cpu(0), tvm.device("opencl"))
        # )
        # rt_mod.run()
        # tvm_res = rt_mod.get_output(0).numpy()
        # print(tvm_res)


if __name__ == "__main__":
    test_on_device()

fn (%x: Tensor[(1, 5), float32], %y: Tensor[(1, 5), float32]) { %0 = add(%x, %y); %1 = on_device(%0, device_type=4); log(%1) } { "nodes": [ { "op": "null", "name": "p0", "inputs": [] }, { "op": "null", "name": "p1", "inputs": [] }, { "op": "tvm_op", "name": "tvmgen_default_fused_add", "attrs": { "num_outputs": "1", "num_inputs": "2", "flatten_data": "0", "func_name": "tvmgen_default_fused_add", "hash": "77c3516efdd817bd" }, "inputs": [ [ 0, 0, 0 ], [ 1, 0, 0 ] ] }, { "op": "tvm_op", "name": "__copy", "attrs": { "num_outputs": "1", "num_inputs": "1", "flatten_data": "0", "func_name": "__copy", "hash": "b78177ac726e601e" }, "inputs": [ [ 2, 0, 0 ] ] }, { "op": "tvm_op", "name": "tvmgen_default_fused_log", "attrs": { "num_outputs": "1", "num_inputs": "1", "flatten_data": "0", "func_name": "tvmgen_default_fused_log", "hash": "8a435bddfed54dab" }, "inputs": [ [ 3, 0, 0 ] ] } ], "arg_nodes": [0, 1], "heads": [ [ 4, 0, 0 ] ], "attrs": { "dltype": [ "list_str", [ "float32", "float32", "float32", "float32", "float32" ] ], "device_index": [ "list_int", [4, 4, 4, 1, 1] ], "storage_id": [ "list_int", [0, 1, 2, 3, 4] ], "shape": [ "list_shape", [ [1, 5], [1, 5], [1, 5], [1, 5], [1, 5] ] ] }, "node_row_ptr": [0, 1, 2, 3, 4, 5] }

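device_index 为 [4,4,4,1,1] (取值为 DLPack 的 DLDeviceType: 1 为 kDLCPU, 4 为 kDLOpenCL), 意味着 node[2] (即 add) 的输出会被分配在 device 4 (opencl) 上, 同时它的输入 (node[0], node[1], 即 x, y) 也需要分配在 opencl 上; 而 node[3] (__copy) 把结果拷贝回 cpu, node[4] (log) 的输出也在 cpu 上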

GraphExecutor 的 set_input 最终会调用 opencl 的 CopyDataFromTo, 把输入数据复制到 x, y 对应的 (位于 opencl 上的) DLTensor 中 (本例中 x, y 作为 params 被绑定, 所以在 graph 中它们的名字是 p0, p1)

1.2.1.4. SetupOpExecs

SetupOpExecs 的作用有两个:

  1. 找到所有 PackedFunc (通过 module->GetFunction)
  2. 确定每个 PackedFunc 的参数和返回值对应的具体的 DLTensor (使用 SetupStorage 时分配的 data_entry_)
void GraphExecutor::SetupOpExecs() {
    std::unordered_set<uint32_t> input_node_eids;
    for (size_t i = 0; i < input_nodes_.size(); i++) {
        uint32_t nid = input_nodes_[i];
        input_node_eids.insert(entry_id(nid, 0));
    }

    for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
        const auto& inode = nodes_[nid];
        std::vector<DLTensor> args;
        // 输入参数
        for (const auto& e : inode.inputs) {
            uint32_t eid = this->entry_id(e);
            // 使用 data_entry_ 中分配给 eid 的 DLTensor, 有可能不同的 eid 最终
            // 会使用相同的 DLTensor, 因为它们有相同的 storage_id
            args.push_back(*(data_entry_[eid].operator->()));
        }
        // 输出
        for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
            uint32_t eid = this->entry_id(nid, index);
            args.push_back(*(data_entry_[eid].operator->()));
        }

        // 输入和输出的 DLTensor 都打包在同一个 args 中做为 tvm op 的参数
        std::tie(op_execs_[nid], op_args) =
            CreateTVMOp(inode.param, args, inode.inputs.size());
    }
}

1.2.2. run

执行 GraphExecutor 的代码大约是这样:

// Fetch the executor's entry points from the runtime module.
PackedFunc set_input = mod.GetFunction("set_input", false);
PackedFunc run = mod.GetFunction("run", false);
PackedFunc get_output = mod.GetFunction("get_output", false);
// Inputs are addressed by the graph's input node names, which for the example
// graph are the relay.var names "a", "b" and "c" (case-sensitive).
set_input("a", a_val);
set_input("b", b_val);
set_input("c", c_val);
run();
// Output 0 is the first entry of the graph's "heads".
tvm::runtime::NDArray out = get_output(0);

其中 "run" 的实现为:

void GraphExecutor::Run() {
  for (size_t i = 0; i < op_execs_.size(); ++i) {
    if (op_execs_[i]) op_execs_[i]();
  }
}

而 op_execs_ 即各个 module 中针对 graph 中的 symbol 的具体的实现, 例如 DNNLJSONRuntime 中的 tvmgen_default_dnnl_main_0

由于各个 op_execs_ 的输入输出在上一步的 SetupOpExecs 时已经设置好了, 这时只负责执行即可, 不需要再考虑参数和返回值的问题.

另外, 由于 op_execs_ 是按下标线性执行的, 所以生成 graph 时需要保证 nodes 的顺序是一个拓扑排序 (每个 node 的输入都在它之前被计算出来)

Author: [email protected]
Date: 2021-09-24 Fri 00:00
Last updated: 2022-01-24 Mon 19:34

知识共享许可协议