TVM Schedule

Table of Contents

1. TVM Schedule

1.1. Schedule

1.1.1. Overview

http://tvm.apache.org/docs/api/python/te.html

https://github.com/StrongSpoon/tvm.schedule

TE (tensor expression) sits between Relay IR and TIR. TE consists of compute + schedule: compute uses TIR to describe the core computation, while the schedule describes how that computation is carried out.

Q: What is a schedule, and why do we need scheduling?

Halide: A Language and Compiler for Optimizing Parallelism, Locality, and Recomputation in Image Processing Pipelines

Learning to Optimize Tensor Programs

1.1.2. Example

Relay IR is first optimized via tvm.relay.transform, then lowered to TE, and a schedule is then applied to produce the best-performing code (for example, the vectorize schedule uses SIMD instructions for acceleration).

import tvm
from tvm import te

M = 1024
N = 1024
A = te.placeholder((M, N), name="A")
B = te.placeholder((M, N), name="B")
C = te.compute((M, N), lambda x, y: A[x, y] + B[x, y], name="C")

s = te.create_schedule(C.op)
print(tvm.lower(s, [A, B, C]))

f = tvm.build(s, [A, B, C], "c")
print(f.get_source())
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (x: int32, 0, 1024) {
    for (y: int32, 0, 1024) {
      C_2[((x*1024) + y)] = ((float32*)A_2[((x*1024) + y)] + (float32*)B_2[((x*1024) + y)])
    }
  }
}


// tvm target: c -keys=cpu -link-params=0
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t default_function(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
  void* arg0 = (((TVMValue*)args)[0].v_handle);
  int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
  void* arg1 = (((TVMValue*)args)[1].v_handle);
  int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)];
  void* arg2 = (((TVMValue*)args)[2].v_handle);
  int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)];
  void* A = (((DLTensor*)arg0)[0].data);
  void* arg0_shape = (((DLTensor*)arg0)[0].shape);
  void* arg0_strides = (((DLTensor*)arg0)[0].strides);
  int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
  void* B = (((DLTensor*)arg1)[0].data);
  void* arg1_shape = (((DLTensor*)arg1)[0].shape);
  void* arg1_strides = (((DLTensor*)arg1)[0].strides);
  void* C = (((DLTensor*)arg2)[0].data);
  void* arg2_shape = (((DLTensor*)arg2)[0].shape);
  void* arg2_strides = (((DLTensor*)arg2)[0].strides);
  if (!(arg0_strides == NULL)) {
  }
  if (!(arg1_strides == NULL)) {
  }
  if (!(arg2_strides == NULL)) {
  }
  for (int32_t x = 0; x < 1024; ++x) {
    for (int32_t y = 0; y < 1024; ++y) {
      ((float*)C)[(((x * 1024) + y))] = (((float*)A)[(((x * 1024) + y))] + ((float*)B)[(((x * 1024) + y))]);
    }
  }
  return 0;
}

After applying the split schedule:

s[C].split(C.op.axis[0], factor=20)
print(tvm.lower(s, [A, B, C]))
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (x.outer: int32, 0, 52) {
    for (x.inner: int32, 0, 20) {
      if (((x.outer*20) + x.inner) < 1024) {
        for (y: int32, 0, 1024) {
          C_2[(((x.outer*20480) + (x.inner*1024)) + y)] = ((float32*)A_2[(((x.outer*20480) + (x.inner*1024)) + y)] + (float32*)B_2[(((x.outer*20480) + (x.inner*1024)) + y)])
        }
      }
    }
  }
}

After applying tile and then the vectorize schedule:

s = te.create_schedule(C.op)
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], 32, 32)

print(tvm.lower(s, [A, B, C], simple_mode=True))
print("---------cutting line---------")

s[C].vectorize(yi)

print(tvm.lower(s, [A, B, C], simple_mode=True))
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (x.outer: int32, 0, 32) {
    for (y.outer: int32, 0, 32) {
      for (x.inner: int32, 0, 32) {
        for (y.inner: int32, 0, 32) {
          C_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = ((float32*)A_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] + (float32*)B_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)])
        }
      }
    }
  }
}


---------cutting line---------
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (x.outer: int32, 0, 32) {
    for (y.outer: int32, 0, 32) {
      for (x.inner: int32, 0, 32) {
        C_2[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] = ((float32x32*)A_2[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] + (float32x32*)B_2[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)])
      }
    }
  }
}

1.1.3. TVM Schedule Primitives

https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html

Schedules mainly fall into the following categories:

  1. memory tiling
  2. loop transformation (splitting, reordering, unrolling)
  3. vectorization/tensorization
  4. parallelization
1.1.3.1. split
1.1.3.2. tensorize

https://tvm.apache.org/docs/tutorials/language/tensorize.html

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-10-22 16:46
import tvm
from tvm import te

# a = 1024 x 64
# b = 512  x 64
# c = a @ b.T = 1024 x 512
N, M, L = 1024, 512, 64
A = te.placeholder((N, L), name="A")
B = te.placeholder((M, L), name="B")
k = te.reduce_axis((0, L), name="k")
C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")
s = te.create_schedule(C.op)

factor = 32
x, y = C.op.axis
(z,) = C.op.reduce_axis
yo, yi = s[C].split(y, factor=factor)
# s[C].reorder(x, yo, yi, z)
print("------ before ------")
print(tvm.lower(s, [A, B, C], simple_mode=True))


def intrin_gemv(m, l):
    a = te.placeholder((l,), name="a")
    b = te.placeholder((m, l), name="b")
    k = te.reduce_axis((0, l), name="k")
    c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name="c")

    Abuf = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1])
    Bbuf = tvm.tir.decl_buffer(
        b.shape, b.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1]
    )
    Cbuf = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        aa, bb = ins
        cc = outs[0]
        ib.emit(
            tvm.tir.call_extern(
                "int32",
                "gemv_update",
                cc.access_ptr("w"),
                aa.access_ptr("r"),
                bb.access_ptr("r"),
                m,
                l,
                bb.strides[0],
            )
        )
        return ib.get()

    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Abuf, b: Bbuf, c: Cbuf})


# def gemv_impl():
#     cc_code = """
#       extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) {
#         for (int i = 0; i < m; ++i) {
#             for (int j = 0; j < l; ++j) {
#                 cc[i] += aa[j] * bb[i * stride + j];
#             }
#         }
#         return 0;
#       }
#     """
#     from tvm.contrib import utils, clang
#     temp = utils.tempdir()
#     ll_path = temp.relpath("temp.ll")
#     # Create LLVM ir from c source code
#     ll_code = clang.create_llvm(cc_code, output=ll_path)
#     return ll_code
#
# s[C].pragma(x, "import_llvm", gemv_impl())

print("------ after ------")
gemv = intrin_gemv(factor, L)
s[C].tensorize(yi, gemv)
print(tvm.lower(s, [A, B, C], simple_mode=True))
------ before ------
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 512], []),
             B: Buffer(B_2: Pointer(float32), float32, [512, 64], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 64], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, 1024) {
    for (j.outer: int32, 0, 16) {
      for (j.inner: int32, 0, 32) {
        C_2[(((i*512) + (j.outer*32)) + j.inner)] = 0f32
        for (k: int32, 0, 64) {
          C_2[(((i*512) + (j.outer*32)) + j.inner)] = ((float32*)C_2[(((i*512) + (j.outer*32)) + j.inner)] + ((float32*)A_2[((i*64) + k)]*(float32*)B_2[(((j.outer*2048) + (j.inner*64)) + k)]))
        }
      }
    }
  }
}


------ after ------
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 512], []),
             B: Buffer(B_2: Pointer(float32), float32, [512, 64], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 64], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, 1024) {
    for (j.outer: int32, 0, 16) {
      @tir.call_extern("gemv_update", @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), C_2, ((i*512) + (j.outer*32)), 32, 2, dtype=handle), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), A_2, (i*64), 64, 1, dtype=handle), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), B_2, (j.outer*2048), 2048, 1, dtype=handle), 32, 64, 64, dtype=int32)
    }
  }
}

1.1.3.3. parallel
import tvm
from tvm import te

n = 1024
m = 1024

A = te.placeholder((n, m), name="A")
l = te.reduce_axis((0, m), name="l")

B = te.compute((n,), lambda i: te.sum(A[i, l], axis=l), name="B")

s = te.create_schedule(B.op)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

s[B].parallel(B.op.reduce_axis[0])
print(tvm.lower(s, [A, B], simple_mode=True))
primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B} {
  for (i: int32, 0, 1024) {
    B_2[i] = 0f32
    for (l: int32, 0, 1024) {
      B_2[i] = ((float32*)B_2[i] + (float32*)A_2[((i*1024) + l)])
    }
  }
}


---------cutting line---------
primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B} {
  for (i: int32, 0, 1024) {
    B_2[i] = 0f32
    for (l: int32, 0, 1024) "parallel" {
      B_2[i] = ((float32*)B_2[i] + (float32*)A_2[((i*1024) + l)])
    }
  }
}

1.1.3.4. vectorize
1.1.3.5. tile
1.1.3.6. reorder
1.1.3.7. compute_at
1.1.3.8. compute_inline
1.1.3.9. compute_root
1.1.3.10. fuse
import tvm
from tvm import te

n = 1024
A = te.placeholder((n,), name="A")
k = te.reduce_axis((0, n), name="k")

B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")

s = te.create_schedule(B.op)

ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)

print("------before------")
print(tvm.lower(s, [A, B], simple_mode=True))

s[B].fuse(ko, ki)
print("------after------")
print(tvm.lower(s, [A, B], simple_mode=True))
------before------
primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [1], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024], [])}
  buffer_map = {A_1: A, B_1: B} {
  B_2[0] = 0f32
  for (k.outer: int32, 0, 32) {
    for (k.inner: int32, 0, 32) {
      B_2[0] = ((float32*)B_2[0] + (float32*)A_2[((k.outer*32) + k.inner)])
    }
  }
}


------after------
primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [1], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024], [])}
  buffer_map = {A_1: A, B_1: B} {
  B_2[0] = 0f32
  for (k.outer.k.inner.fused: int32, 0, 1024) {
    B_2[0] = ((float32*)B_2[0] + (float32*)A_2[k.outer.k.inner.fused])
  }
}

1.1.3.11. bind

1.1.4. Injective Schedule

https://en.wikipedia.org/wiki/Injective_function

An injective function is a one-to-one function. In TVM, a large number of ops that fit the injective definition share a single schedule, schedule_injective, because apart from the concrete computation they perform, their scheduling requirements are essentially the same.

TVM provides several functions for registering schedule_injective as an op's schedule:

  1. register_injective_schedule
  2. register_broadcast_schedule

For example:

register_broadcast_schedule("log")
register_broadcast_schedule("log2")
register_broadcast_schedule("log10")
register_broadcast_schedule("tan")
register_broadcast_schedule("cos")
register_broadcast_schedule("cosh")
register_broadcast_schedule("sin")
register_broadcast_schedule("sinh")
# ...
register_broadcast_schedule("add")
register_broadcast_schedule("subtract")
register_broadcast_schedule("multiply")
register_broadcast_schedule("divide")
register_broadcast_schedule("floor_divide")
register_broadcast_schedule("power")
register_broadcast_schedule("copy")
register_broadcast_schedule("logical_not")
register_broadcast_schedule("logical_and")
register_broadcast_schedule("logical_or")
register_broadcast_schedule("logical_xor")
register_broadcast_schedule("bitwise_not")
register_broadcast_schedule("bitwise_and")
register_broadcast_schedule("bitwise_or")
register_broadcast_schedule("bitwise_xor")
register_broadcast_schedule("negative")
register_broadcast_schedule("mod")
register_broadcast_schedule("floor_mod")
# ...
register_injective_schedule("one_hot")
# ...

register_injective_schedule uses a generic function named schedule_injective; different targets register their own implementations via schedule_injective.register:

@schedule_injective.register("cpu")
def schedule_injective_cpu(attrs, outs, target):
    """schedule injective ops for x86"""
    with target:
        return topi.x86.schedule_injective(outs)


@schedule_injective.register(["arm_cpu", "micro_dev"])
def schedule_injective_arm_cpu(_, outs, target):
    """schedule injective ops for arm cpu"""
    with target:
        return topi.arm_cpu.schedule_injective(outs)

The injective schedule mainly relies on split + vectorize, since vectorization (SIMD) is exactly what accelerates injective (elementwise) operations.
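
As a rough illustration (a sketch, not the actual topi.x86.schedule_injective implementation; the split factor and the extra parallel step are assumptions), an injective schedule on CPU boils down to something like:

import tvm
from tvm import te

n = 4096
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2.0, name="B")  # an injective (elementwise) op

s = te.create_schedule(B.op)
outer, inner = s[B].split(B.op.axis[0], factor=16)  # factor chosen arbitrarily for illustration
s[B].parallel(outer)   # spread the outer loop across CPU threads
s[B].vectorize(inner)  # map the inner loop onto SIMD lanes
print(tvm.lower(s, [A, B], simple_mode=True))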

1.1.5. Reduce Schedule

The reduce schedule is similar to the injective schedule, but it targets reduction ops such as sum and max. A reduction cannot simply be vectorized, so TVM's implementation mainly relies on parallel (a minimal sketch follows the registration code below).

register_reduce_schedule("argmax")
register_reduce_schedule("argmin")
register_reduce_schedule("sum")
register_reduce_schedule("all")
register_reduce_schedule("any")
register_reduce_schedule("max")
register_reduce_schedule("min")
register_reduce_schedule("prod")
register_reduce_schedule("mean")
# ...


@schedule_reduce.register("arm_cpu")
def schedule_reduce_cpu(attrs, outs, target):
    """schedule reduction ops for arm_cpu"""
    with target:
        return topi.x86.schedule_reduce(outs)


@schedule_reduce.register("cpu")
def schedule_reduce_cpu(attrs, outs, target):
    """schedule reduction ops for x86"""
    with target:
        return topi.x86.schedule_reduce(outs)
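
As a rough sketch (an illustration, not topi.x86.schedule_reduce itself), a reduce-style schedule parallelizes the spatial axis and leaves the reduction axis as a serial loop:

import tvm
from tvm import te

n, m = 1024, 1024
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")  # row-wise sum

s = te.create_schedule(B.op)
s[B].parallel(B.op.axis[0])  # parallelize over rows; the reduce axis k stays serial
print(tvm.lower(s, [A, B], simple_mode=True))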

1.2. AutoTVM

tvm/gallery/tutorial/autotvm_matmul_x86.py

1.3. TVM Auto Scheduler

1.3.1. Testing Function

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 2021-08-09 17:16
import os

import numpy as np
import tvm
from tvm import te, auto_scheduler

N = L = M = 1024
target = tvm.target.Target("llvm -mcpu=skylake-avx512")

A = tvm.nd.array(np.random.uniform(size=(N, L)).astype(np.float32))
B = tvm.nd.array(np.random.uniform(size=(L, M)).astype(np.float32))
C = tvm.nd.array(np.random.uniform(size=(N, M)).astype(np.float32))
OUT = tvm.nd.empty((N, M))


def eval(sch, args):
    # print(tvm.lower(sch, args))
    func = tvm.build(sch, args, target)
    evaluator = func.time_evaluator(func.entry_name, tvm.cpu())
    print(
        "Execution time of this operator: %.3f ms"
        % (np.median(evaluator(A, B, C, OUT).results) * 1000)  # results are in seconds
    )


@auto_scheduler.register_workload
def matmul_add(N, L, M, dtype):
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    C = te.placeholder((N, M), name="C", dtype=dtype)

    k = te.reduce_axis((0, L), name="k")
    matmul = te.compute(
        (N, M),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="matmul",
        attrs={"layout_free_placeholders": [B]},
    )
    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")

    return [A, B, C, out]

1.3.2. AutoScheduler Tuning

task = tvm.auto_scheduler.SearchTask(
    func=matmul_add, args=(N, L, M, "float32"), target=target
)
log_file = "/tmp/matmul.json"

print("-------------------- tunning --------------------")
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=1,
)
task.tune(tune_option)
-------------------- tuning --------------------
----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Generate Sketches		#s: 3
Sample Initial Population	#s: 2015	fail_ct: 1	Time elapsed: 0.67
GA Iter: 0	Max score: 0.9998	Min score: 0.9348	#Pop: 128	#M+: 0	#M-: 0
GA Iter: 4	Max score: 1.0000	Min score: 0.9887	#Pop: 128	#M+: 1381	#M-: 72
EvolutionarySearch		#s: 128	Time elapsed: 2.90
----------------------------------------------------------------------
------------------------------  [ Measure ]
----------------------------------------------------------------------
Get 10 programs to measure:
..........*T*T*T*T*T*T*T*T*T*T
Time elapsed for measurement: 101.94 s
----------------------------------------------------------------------
------------------------------  [ Done ]
----------------------------------------------------------------------
No valid state found in this search round. Check if it has traversed all of the search space.

1.3.3. Evaluation

print("-------------------- orig_scheuler--------------------")
args = matmul_add(N, L, M, "float32")
sch = te.create_schedule(args[-1].op)
eval(sch, args)

print("-------------------- auto_scheuler--------------------")
sch, args = task.apply_best(log_file)
eval(sch, args)
-------------------- orig_scheduler --------------------
Execution time of this operator: 287.769 ms
-------------------- auto_scheduler --------------------
Execution time of this operator: 6.305 ms

1.3.4. Auto Scheduler Internals

The problem auto_scheduler solves is:

define a search space of schedule parameters, sample from that space with a measure function, and then use a cost_model fitted to those samples to predict the best schedule parameters.

1.3.4.1. SearchTask
SearchTask.init:
    # SearchTask takes a target argument, which tells the builder which target to generate code for

SearchTask.tune: 
    if search_policy is None:
        # XGBoost is used as the cost model by default
        cost_model = XGBModel()
        search_policy = SketchPolicy(self, cost_model)

    _ffi_api.AutoSchedule(search_policy, tuning_options)

AutoSchedule:
    # tuning_options->builder defaults to LocalBuilder
    # tuning_options->runner defaults to LocalRunner
    ProgramMeasurer measurer =
      ProgramMeasurer(tuning_options->builder, tuning_options->runner,
                      tuning_options->measure_callbacks, tuning_options->verbose);    

    State state =
        search_policy->Search(tuning_options->num_measure_trials, tuning_options->early_stopping,
                            tuning_options->num_measures_per_round, measurer);
    return search_policy->search_task->compute_dag.ApplySteps(state->transform_steps);

SketchPolicy::Search:
    # MeasureInput holds a set of schedule parameters
    # MeasureResult holds the measured result for one set of parameters (cost, ...)
    Array<MeasureInput> inputs;
    Array<MeasureResult> results;
    while (ct < n_trials):
        # train the xgboost model; inputs act as x, results act as y
        program_cost_model->Update(inputs, results);

        # sample new parameter sets
        inputs = PickStatesWithEpsGreedy(best_states, random_states, n_trials - ct);
        # measure these parameter sets
        results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);        

    return measurer->best_state[search_task->workload_key];        
1.3.4.2. Measure
ProgramMeasurerNode::Measure:
  // by default tvm measures several parameter sets per batch
  if (batch_size == -1) {
    // set default batch size
    batch_size = builder->n_parallel * 2;
  }
  for (size_t i = 0; i < inputs.size(); i += batch_size):
    Array<MeasureInput> input_batch(inputs.begin() + i,
                                      inputs.begin() + std::min(i + batch_size, inputs.size()));
    Array<MeasureResult> result_batch;
    SilentMeasure(task, input_batch, &result_batch);
    for (auto& res : result_batch) : 
      results.push_back(res);

  return results;
SilentMeasure:
  Array<BuildResult> build_res_batch = builder->Build(inputs, verbose);
  Array<MeasureResult> result_batch = runner->Run(inputs, build_res_batch, verbose);

  // Store result batch
  for (auto& res : result_batch) {
    results->push_back(res);
  }
1.3.4.2.1. builder.Build

builder is a ProgramBuilder interface, and LocalBuilder is a reference implementation. Note: although the builder specified in TuningOptions is a Python object, it is ultimately invoked from C++ code (SilentMeasure), so there are two ways to implement a custom builder:

  1. follow the LocalBuilder approach: implement it in C++ and call back into Python via the Registry
  2. use LocalBuilder, but supply a custom Python build_func similar to ndk.create_shared or tar.tar (see the sketch after the LocalBuilder code below)
class LocalBuilder(ProgramBuilder):
    def __init__(self, timeout=15, n_parallel=multiprocessing.cpu_count(), build_func="default"):
        if build_func == "default":
            BuildFunc.name = "default"
            BuildFunc.build_func = tar.tar
        elif build_func == "ndk":
            BuildFunc.name = "ndk"
            BuildFunc.build_func = ndk.create_shared
        elif callable(build_func):
            BuildFunc.name = "custom"
            BuildFunc.build_func = build_func
        else:
            raise ValueError("Invalid build_func" + build_func)

        self.__init_handle_by_constructor__(
            _ffi_api.LocalBuilder, timeout, n_parallel, BuildFunc.name
        )
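
For option 2, a custom build_func only needs to look like tar.tar / ndk.create_shared: a callable with an output_format attribute that packages the compiled objects into the final artifact. A hypothetical sketch (my_build_func is not part of TVM; it simply delegates to tvm.contrib.cc):

from tvm import auto_scheduler
from tvm.contrib import cc

def my_build_func(output, objects, **kwargs):
    # link the object files produced by export_library into a shared library
    cc.create_shared(output, objects, **kwargs)

# the builder uses this to name the temporary artifact (tmp_func.<output_format>)
my_build_func.output_format = "so"

builder = auto_scheduler.LocalBuilder(build_func=my_build_func)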

The C++ LocalBuilder calls back into Python's local_builder_build via FFI:

Array<BuildResult> LocalBuilderNode::Build(const Array<MeasureInput>& inputs, int verbose) {
  if (const auto* f = runtime::Registry::Get("auto_scheduler.local_builder.build")) {
    Array<BuildResult> results = (*f)(inputs, timeout, n_parallel, build_func, verbose);
    return results;
  }
}
def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbose=1):
    pool = multiprocessing.pool.ThreadPool(n_parallel)
    tuple_res = pool.map(
        local_build_worker,
        [
            (
                i.serialize(),
                build_func,
                timeout,
                verbose,
            )
            for i in inputs
        ],
    )

    results = []
    for res in tuple_res:
        results.append(BuildResult(*res))

    return results

def local_build_worker(args):
    build_func = BuildFunc.build_func
    res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose))

The final build step:

def _timed_func(inp_serialized, build_func, verbose):
    inp = MeasureInput.deserialize(inp_serialized)
    task = inp.task
    sch, args = task.compute_dag.apply_steps_from_state(
        inp.state, layout_rewrite=task.layout_rewrite_option
    )

    dirname = tempfile.mkdtemp()
    filename = os.path.join(dirname, "tmp_func." + build_func.output_format)

    with transform.PassContext():
        func = build_module.build(sch, args, target=task.target)
    func.export_library(filename, build_func)

    return filename, args, error_no, error_msg, time.time() - tic

1.3.4.2.2. runner.Run

The runner's call path is similar to the builder's (python -> c++ -> python), but when implementing a custom runner there is nothing analogous to build_func; you have to implement it from scratch by following _timed_eval_func. The most important detail is that build_res provides the filename of the artifact generated by builder.Build (a condensed sketch follows the quoted code below).

def _timed_eval_func(
    inp_serialized,
    build_res,
    number,
    repeat,
    min_repeat_ms,
    cooldown_interval,
    enable_cpu_cache_flush,
    verbose,
):
    inp = MeasureInput.deserialize(inp_serialized)
    task_input_names = inp.task.task_input_names

    func = module.load_module(build_res.filename)
    dev = ndarray.device(str(inp.task.target), 0)

    time_f = func.time_evaluator(
        func.entry_name,
        dev,
        number=number,
        repeat=repeat,
        min_repeat_ms=min_repeat_ms,
        f_preproc=f_prepare,
    )

    random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)

    tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
    args = []
    task_inputs_count = 0
    for arg in build_res.args:
        if arg in tensor_input_map:
            tensor_name = tensor_input_map[arg]
            if tensor_name in task_input_names:
                args.append(
                    ndarray.array(
                        get_task_input_buffer(inp.task.workload_key, tensor_name), dev
                    )
                )
                task_inputs_count += 1
        else:
            empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
            random_fill(empty_array)
            args.append(empty_array)
        costs = time_f(*args).results

    return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc
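
The core of a custom runner therefore reduces to something like the sketch below (measure_one is a hypothetical helper, not a TVM API; a real runner also has to fill in the remaining MeasureResult fields such as error_no and timestamps):

import tvm
from tvm.runtime import load_module

def measure_one(build_res, args_np, number=3):
    func = load_module(build_res.filename)  # the artifact produced by builder.Build
    dev = tvm.cpu(0)
    args = [tvm.nd.array(a, dev) for a in args_np]
    timer = func.time_evaluator(func.entry_name, dev, number=number)
    return timer(*args).results  # per-run costs in seconds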

1.3.5. Auto Scheduler and MicroTVM

Implementations that auto_scheduler ships with:

  1. builder
    • LocalBuilder

      multi-process; ultimately calls build_module.build and export_library

  2. runner
    • LocalRunner

      single-process; it runs on the local CPU, so modules compiled for a non-local target cannot be run with LocalRunner

    • RPCRunner

      multi-process; it executes on multiple remote target devices over RPC, so measurements can run in parallel and modules compiled for other targets can be executed. However, the current implementation relies on a central TCP-based tracker to connect the PC with the target devices, and it has nothing to do with MicroTVM's own serial-port-based RPC mechanism

Although MicroTVM has its own Flasher, Transport, and RPC mechanisms, they currently do not work with auto_scheduler. To make auto_scheduler support MicroTVM, you have to implement your own builder and runner.

Author: [email protected]
Date: 2021-08-03 Tue 00:00
Last updated: 2023-12-01 Fri 18:28

Creative Commons License