TVM TIR

Table of Contents

1. TVM TIR

1.1. ir_builder

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-09-08 18:19
import tvm
from tvm import te
import numpy as np

N = 10

dtype = "float32"
# Pass ``dtype`` explicitly so the inputs always agree with the extern output,
# which is also declared with ``dtype=dtype``; previously A and B ignored this
# variable, so changing it would have silently mismatched inputs and output.
A = te.placeholder((N,), dtype=dtype, name="A")
B = te.placeholder((N,), dtype=dtype, name="B")


def test_ir_builder(A, B, C):
    """Emit TIR for the element-wise add ``C[i] = A[i] + B[i]``.

    Intended as the body returned from a ``te.extern`` fcompute callback:
    ``A`` and ``B`` are the input buffers and ``C`` is the output buffer.

    The loop extent is read from ``C.shape[0]`` rather than a module-level
    constant, so the builder follows the buffer it actually writes.
    Assumes A, B and C are 1-D and A/B hold at least ``C.shape[0]``
    elements -- TODO confirm for any new caller.

    Returns:
        The TIR statement produced by the IR builder.
    """
    ib = tvm.tir.ir_builder.create()
    a_ptr = ib.buffer_ptr(A)
    b_ptr = ib.buffer_ptr(B)
    c_ptr = ib.buffer_ptr(C)
    # Name the IR loop var "i" so printed IR matches the Python handle.
    with ib.for_range(0, C.shape[0], name="i") as i:
        c_ptr[i] = a_ptr[i] + b_ptr[i]
    return ib.get()


# Wrap the hand-written TIR in an extern op: te.extern passes the builder the
# list of input buffers (A, B) and the list holding the single output buffer.
C = te.extern(
    A.shape,
    [A, B],
    lambda ins, outs: test_ir_builder(*ins, *outs),
    name="add",
    dtype=dtype,
)
s = te.create_schedule(C.op)

# Compile for the host CPU through the LLVM backend.
fadd = tvm.build(s, [A, B, C], target="llvm")

# Run on all-ones inputs; the output should come back as all twos.
dev = tvm.cpu(0)
a = tvm.nd.array(np.ones(N, dtype=A.dtype), dev)
b = tvm.nd.array(np.ones(N, dtype=B.dtype), dev)
c = tvm.nd.array(np.zeros(N, dtype=C.dtype), dev)
fadd(a, b, c)

for arr in (a, b, c):
    print(arr)

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]

1.2. tir.call_extern

import tvm
from tvm import te
import numpy as np


def test_ir_builder(C):
    """Emit TIR that fills ``C`` from an external C function over a 10x10 grid.

    For every (i, j) in [0, 10) x [0, 10), store the result of the external
    call ``get_value(i * 3 + j * 1)`` at flat index ``i * 10 + j`` of ``C``,
    i.e. indices 0..99.  ``get_value`` is only referenced by name through
    ``tir.call_extern``; presumably it is provided when the generated C is
    compiled and linked.

    Returns:
        The TIR statement produced by the IR builder.
    """
    builder = tvm.tir.ir_builder.create()
    out = builder.buffer_ptr(C)
    with builder.for_range(0, 10, name="i") as row, builder.for_range(0, 10, name="j") as col:
        value = tvm.tir.call_extern("int32", "get_value", row * 3 + col * 1)
        out[row * 10 + col] = value
    return builder.get()


# An extern op with no inputs: the builder only receives the output buffer.
C = te.extern(
    (100,),
    [],
    lambda _ins, outs: test_ir_builder(outs[0]),
    name="test",
    dtype="int32",
)

# Lower with the C source backend and dump the generated code.
s = te.create_schedule(C.op)
fadd = tvm.build(s, [C], target="c")
print(fadd.get_source())

// tvm target: c -keys=cpu -link-params=0
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t default_function(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
  void* arg0 = (((TVMValue*)args)[0].v_handle);
  int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
  void* test = (((DLTensor*)arg0)[0].data);
  void* arg0_shape = (((DLTensor*)arg0)[0].shape);
  void* arg0_strides = (((DLTensor*)arg0)[0].strides);
  int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
  if (!(arg0_strides == NULL)) {
  }
  for (int32_t i = 0; i < 10; ++i) {
    for (int32_t j = 0; j < 10; ++j) {
      ((int32_t*)test)[(((i * 10) + j))] = get_value(((i * 3) + j));
    }
  }
  return 0;
}

// CodegenC: NOTE: Auto-generated entry function
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t __tvm_main__(void* args, int* arg_type_ids, int num_args, void* out_ret_value, int* out_ret_tcode, void* resource_handle) {
  return default_function(args, arg_type_ids, num_args, out_ret_value, out_ret_tcode, resource_handle);
}

1.3. tir.call_packed

Author: [email protected]
Date: 2021-10-13 Wed 00:00
Last updated: 2021-10-13 Wed 01:06

知识共享许可协议