Tensor Program Abstraction

Credit to: mlc.ai by Tianqi Chen

Before we start, here is a way to pretty-print a script in a notebook (MyModule is defined later in this note):

import IPython

IPython.display.Code(MyModule.script(), language="python")

Primitive Tensor Function

A primitive tensor function is a self-contained unit that takes input data, executes a computation, and produces the expected output. Linear, Add, and ReLU can each be seen as a primitive tensor function, and so can fused functions such as linear_add or add_relu.

A primitive tensor function does not prescribe its implementation. Taking add as an example, we can call torch.add, write a vanilla add in Python, or even write a parallelized add with the aid of OpenMP.
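
As a minimal sketch of that freedom, the two functions below both implement element-wise addition on plain Python lists and write into a caller-provided output. The names add_loop and add_comprehension are made up for this note; any implementation with the same inputs and outputs would count as the same primitive tensor function.

```python
from typing import List


def add_loop(a: List[float], b: List[float], c: List[float]) -> None:
    """Vanilla element-wise add: write a[i] + b[i] into c[i]."""
    for i in range(len(c)):
        c[i] = a[i] + b[i]


def add_comprehension(a: List[float], b: List[float], c: List[float]) -> None:
    """Same contract, different implementation: build the result in one pass."""
    c[:] = [x + y for x, y in zip(a, b)]


a = [float(i) for i in range(8)]
b = [1.0] * 8
c1 = [0.0] * 8
c2 = [0.0] * 8
add_loop(a, b, c1)
add_comprehension(a, b, c2)
print(c1 == c2)  # both implementations fill the output identically
```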

Example of TPA

A typical tensor program abstraction contains several parts: buffers, loop nests, and computation statements.

from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer[128, "float32"],     # (Multi-dimensional) buffers that
         B: T.Buffer[128, "float32"],     # hold the input, output, and
         C: T.Buffer[128, "float32"]):    # intermediate results.

    for i in range(128):                  # Loop nests that drive compute iterations.
        with T.block("C"):                # Blocks can be retrieved and iterated in IRModules.
            vi = T.axis.spatial(128, i)   # Extra information about iteration. (Spatial or Reduction)
            C[vi] = A[vi] + B[vi]         # Computation statements.

Essential classes

A schedule is a set of transformations that change the order of computation but preserve the semantics of computation.

TensorIR

Do2Learn

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"],     # (Multi-dimensional) buffers that
             B: T.Buffer[128, "float32"],     # hold the input, output, and
             C: T.Buffer[128, "float32"]):    # intermediate results.

        for i in range(128):                  # Loop nests that drive compute iterations.
            with T.block("C"):                # Blocks can be retrieved and iterated in IRModules.
                vi = T.axis.spatial(128, i)   # Extra information about iteration. (Spatial or Reduction)
                C[vi] = A[vi] + B[vi]         # Computation statements.

An IRModule has func::script, which returns the decorated functions as TVMScript text.

print(MyModule.script())
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # body
        # with tir.block("root")
        for i in tir.serial(128):
            with tir.block("C"):
                vi = tir.axis.spatial(128, i)
                tir.reads(A[vi], B[vi])
                tir.writes(C[vi])
                C[vi] = A[vi] + B[vi]

Staying in the IRModule is not enough. To apply any transformation to the original code, create a class::Schedule from it.

sch = tvm.tir.Schedule(MyModule)
type(sch)
-------->
tvm.tir.schedule.schedule.Schedule

A Schedule exposes the transformation primitives used to optimize the code. Its public methods are:

print(list(filter(lambda x: not x.startswith("_"), dir(sch))))
-------->
['add_unit_loop', 'annotate', 'bind', 'blockize', 'cache_read', 'cache_write', 'compute_at', 'compute_inline', 'copy', 'decompose_reduction', 'enter_postproc', 'fork_seed', 'fuse', 'get', 'get_block', 'get_child_blocks', 'get_consumers', 'get_loops', 'get_producers', 'get_sref', 'handle', 'mod', 'parallel', 'reindex', 'remove_rv', 'reorder', 'reverse_compute_at', 'reverse_compute_inline', 'rfactor', 'same_as', 'sample_categorical', 'sample_compute_location', 'sample_perfect_tile', 'seed', 'set_axis_separator', 'set_scope', 'show', 'split', 'state', 'storage_align', 'tensorize', 'trace', 'transform_block_layout', 'transform_layout', 'unannotate', 'unroll', 'vectorize']

For example, we can look up parts of the code from sch using func::get_xxx:

block_c = sch.get_block("C")        # get the block named "C"; the second parameter is the prim function name, "main" by default
# sch.get_loops returns a list of loops; writing `i, = ...` unpacks the single-element list
i, = sch.get_loops(block_c)         # the loops of a block are the loops that enclose it
print(i)
-------->
tir.LoopRV(0x1d353c0)

Once we have the loops, we can do many things with them, e.g. split, unroll, vectorize, or parallelize:

i0, i1, i2 = sch.split(i, factors=[None, 4, 4])   # None lets tvm infer that factor: 128 / (4 * 4) = 8
print(sch.mod.script())
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # body
        # with tir.block("root")
        for i_0, i_1, i_2 in tir.grid(8, 4, 4):
            with tir.block("C"):
                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                tir.reads(A[vi], B[vi])
                tir.writes(C[vi])
                C[vi] = A[vi] + B[vi]

As we can see, the schedule holds a copy of the IRModule under the name mod. Every transformation changes the output of sch.mod.script().
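
The index arithmetic in the printed script can be checked without tvm. The sketch below is plain Python, not a tvm API: it rebuilds the three loops produced by the split and confirms that the binding i_0 * 16 + i_1 * 4 + i_2 visits every index of the original loop exactly once and in the original order.

```python
# Plain-Python picture of sch.split(i, factors=[None, 4, 4]) on a loop of extent 128.
extent = 128
inner_factors = (4, 4)
# The factor left as None is inferred from the other two: 128 / (4 * 4) = 8.
outer = extent // (inner_factors[0] * inner_factors[1])

visited = []
for i_0 in range(outer):
    for i_1 in range(inner_factors[0]):
        for i_2 in range(inner_factors[1]):
            vi = i_0 * 16 + i_1 * 4 + i_2  # same binding as in the printed script
            visited.append(vi)

print(outer)
print(visited == list(range(extent)))  # same indices, same order as the original loop
```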

There can be more transformations:

sch.reorder(i2, i1, i0)
sch.parallel(i0)
sch.vectorize(i1)
print(sch.mod.script())
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # body
        # with tir.block("root")
        for i_2 in tir.serial(4):
            for i_1 in tir.vectorized(4):
                for i_0 in tir.parallel(8):
                    with tir.block("C"):
                        vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                        tir.reads(A[vi], B[vi])
                        tir.writes(C[vi])
                        C[vi] = A[vi] + B[vi]

From the example above, we can take away two points:

  1. The computation is enclosed in a block, like C in the example.

  2. Every optimization of the original script is applied through a concrete schedule. (So far I have not found a way to reverse an applied transformation.)
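
To see why reordering is safe for this block, the reordered loop nest printed earlier can be mimicked in plain Python. This is only a sketch of the iteration order: it ignores the parallel and vectorize annotations and uses Python lists in place of buffers.

```python
a = [float(i) for i in range(128)]
b = [1.0] * 128

# Original order: one serial loop over i.
c_ref = [0.0] * 128
for i in range(128):
    c_ref[i] = a[i] + b[i]

# Order after sch.reorder(i2, i1, i0): i_2 outermost, i_0 innermost.
c_new = [0.0] * 128
order = []
for i_2 in range(4):
    for i_1 in range(4):
        for i_0 in range(8):
            vi = i_0 * 16 + i_1 * 4 + i_2
            order.append(vi)
            c_new[vi] = a[vi] + b[vi]

print(order[:4])                          # the visiting order is no longer 0, 1, 2, 3
print(sorted(order) == list(range(128)))  # but every index is still visited once
print(c_new == c_ref)                     # so the result is unchanged
```

Because vi is a spatial axis, each iteration writes a different element of C and does not read anything written by another iteration, which is what makes the change of order harmless.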

Build and run

After making these changes, we can build the module into a runnable function with tvm:

import numpy as np

rt_mod = tvm.build(sch.mod, target="llvm")         # build with a supported target as the backend
func = rt_mod["main"]                              # all prim functions in the IRModule are built; retrieve one by its **global symbol** name
type(func)                                         # tvm.runtime.packed_func.PackedFunc
# Now func is a callable function
a = tvm.nd.array(np.arange(128, dtype="float32"))  # a PackedFunc takes tvm NDArrays rather than numpy arrays
b = tvm.nd.array(np.ones(128, dtype="float32"))
c = tvm.nd.empty([128], dtype="float32")           # allocate an empty array to store the result
# Call func as usual
func(a, b, c)

To evaluate the effect of an optimization, tvm provides a time evaluator:

func_timer = rt_mod.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule is %g sec" % func_timer(a, b, c).mean)  # mean is an attribute of the result, not a method
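
The idea behind such an evaluator can be sketched in plain Python: call the function several times and average the wall-clock time. The helper mean_runtime and the function vector_add below are hypothetical names for this note, and the sketch does not claim to reproduce how tvm implements time_evaluator.

```python
import time


def mean_runtime(fn, *args, number: int = 10) -> float:
    """Call fn(*args) `number` times and return the mean wall-clock seconds per call."""
    start = time.perf_counter()
    for _ in range(number):
        fn(*args)
    return (time.perf_counter() - start) / number


def vector_add(a, b, c):
    """Stand-in workload: element-wise add on Python lists."""
    for i in range(len(c)):
        c[i] = a[i] + b[i]


a = [float(i) for i in range(128)]
b = [1.0] * 128
c = [0.0] * 128
elapsed = mean_runtime(vector_add, a, b, c)
print("Time cost of vector_add is %g sec" % elapsed)
```

Averaging over several calls smooths out noise from a single measurement, which matters when comparing two schedules whose run times are close.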

Case Study on MM_Relu

Low-level numpy implementation:

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

The same computation written as a TensorIR function inside an IRModule:

# `@tvm.script.ir_module` indicates that MyModule is an IRModule.
# An IRModule is the container object that holds a collection of tensor functions in machine learning compilation.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]):
        # `global_symbol` corresponds to the name of the function.
        # `tir.noalias` is an attribute indicating that all the buffer memories do not overlap.
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

The differences between the low-level numpy version and the IRModule are:

  1. variable type: tvm treats every variable as a buffer, whether it is a parameter or a temporary.

  2. for loop: T.grid is syntactic sugar for nested loops.

  3. block: a block is the basic unit of computation in TensorIR. The computation is wrapped in named blocks such as "Y" and "C", which can be retrieved with sch.get_block.

  4. extra information: the more information we give the compiler, the more it can do for us. For example, T.axis.spatial marks an axis whose iterations are independent of each other, and T.axis.reduce marks an axis over which a reduction is performed.

Block Axis properties

for i, j, k in T.grid(128, 128, 128):
    with T.block("Y"):
        vi = T.axis.spatial(128, i)
        vj = T.axis.spatial(128, j)
        vk = T.axis.reduce(128, k)
        with T.init():
            Y[vi, vj] = T.float32(0)
        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

Notably, for fixed values of vi and vj, the computation block produces a point value at a spatial location of Y (Y[vi, vj]) that is independent of the other locations in Y (with different values of vi, vj). We call vi and vj spatial axes because they directly correspond to the beginning of a spatial region of the buffer that the block writes to. The axes involved in the reduction (vk) are called reduce axes.
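
The independence of spatial locations can be illustrated in plain Python. In this sketch (small 4x4 integer matrices, not tvm code) the (vi, vj) points are visited once in natural order and once in a shuffled order, while the reduce axis vk is always accumulated in full for each point. The helper name matmul is an assumption of this note.

```python
import random

N = 4
A = [[i + k for k in range(N)] for i in range(N)]
B = [[k * j + 1 for j in range(N)] for k in range(N)]


def matmul(spatial_points):
    """Compute Y = A @ B, visiting the (vi, vj) points in the given order."""
    Y = [[None] * N for _ in range(N)]
    for vi, vj in spatial_points:
        for vk in range(N):        # reduce axis: accumulates into a single Y[vi][vj]
            if vk == 0:
                Y[vi][vj] = 0      # plays the role of T.init()
            Y[vi][vj] += A[vi][vk] * B[vk][vj]
    return Y


points = [(vi, vj) for vi in range(N) for vj in range(N)]
shuffled = points[:]
random.Random(0).shuffle(shuffled)

print(matmul(points) == matmul(shuffled))  # order of spatial points does not matter
```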

The extra axis information helps validate correctness. For example, if an axis declares an extent of 128 but is bound to a for loop over range(127), an exception is raised. The information also lets the compiler apply more transformations based on the dependencies between iterations.

Sugars

The initialization of axes can be simplified as:

# "SSR" means the axes are, in order, "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])

Function Annotations

T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})

Here are two attributes:

  • global_symbol: the name of the function, unique within this IRModule; it is used to retrieve the function from the built module.

  • tir.noalias: indicating that all the buffer memories do not overlap.

Transformation

@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

Suppose we want to make better use of the cache when reading matrix B, as in this low-level numpy version:

def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j0 in range(32):
            for k in range(128):
                for j1 in range(4):                             # compute 4 contiguous elements of a row of B/Y at a time
                    j = j0 * 4 + j1
                    if k == 0:
                        Y[i, j] = 0
                    Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

In this case, we need to:

  1. split the j loop into two parts

  2. reorder the loops so that the k loop runs between the two parts of j
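
Before applying the two steps with a schedule, their effect can be checked in plain Python on a small problem. This is a sketch with assumed sizes (8 and 4 standing in for 128 and 4) and integer data so that the comparison is exact; mm_relu_v1 and mm_relu_v2 are names chosen for this note.

```python
N, TILE = 8, 4  # small stand-ins for 128 and 4

A = [[(3 * i + k) % 5 - 2 for k in range(N)] for i in range(N)]
B = [[(7 * k + j) % 5 - 2 for j in range(N)] for k in range(N)]


def mm_relu_v1():
    """Original loop order: i, j, k."""
    Y = [[None] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                if k == 0:
                    Y[i][j] = 0
                Y[i][j] += A[i][k] * B[k][j]
    return [[max(y, 0) for y in row] for row in Y]


def mm_relu_v2():
    """Loop order after the two steps: i, j0, k, j1."""
    Y = [[None] * N for _ in range(N)]
    for i in range(N):
        for j0 in range(N // TILE):      # step 1: j is split into j0 and j1
            for k in range(N):           # step 2: k now runs between j0 and j1
                for j1 in range(TILE):
                    j = j0 * TILE + j1
                    if k == 0:
                        Y[i][j] = 0
                    Y[i][j] += A[i][k] * B[k][j]
    return [[max(y, 0) for y in row] for row in Y]


print(mm_relu_v1() == mm_relu_v2())
```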

# Initialize the schedule
sch = tvm.tir.Schedule(MyModuleWithAxisRemapSugar)

# get the block
block_Y = sch.get_block("Y", func_name="mm_relu")
# get the outside loops of Y
i, j, k, = sch.get_loops(block_Y)
# split j into two parts
j0, j1 = sch.split(j, factors=[None, 4])
# reorder
sch.reorder(j0, k, j1)

IPython.display.Code(sch.mod.script(), language="python")
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([128, 128], dtype="float32")
        for i, j_0, k, j_1 in tir.grid(128, 32, 128, 4):
            with tir.block("Y"):
                vi = tir.axis.spatial(128, i)
                vj = tir.axis.spatial(128, j_0 * 4 + j_1)
                vk = tir.axis.reduce(128, k)
                tir.reads(A[vi, vk], B[vk, vj])
                tir.writes(Y[vi, vj])
                with tir.init():
                    Y[vi, vj] = tir.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in tir.grid(128, 128):
            with tir.block("C"):
                vi, vj = tir.axis.remap("SS", [i, j])
                tir.reads(Y[vi, vj])
                tir.writes(C[vi, vj])
                C[vi, vj] = tir.max(Y[vi, vj], tir.float32(0))

What's more, block C and block Y share part of their loop nest, so we can move block C under the loops of Y.

block_C = sch.get_block("C", func_name="mm_relu")
sch.reverse_compute_at(block_C, j0)

IPython.display.Code(sch.mod.script(), language="python")
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([128, 128], dtype="float32")
        for i, j_0 in tir.grid(128, 32):
            for k, j_1 in tir.grid(128, 4):
                with tir.block("Y"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + j_1)
                    vk = tir.axis.reduce(128, k)
                    tir.reads(A[vi, vk], B[vk, vj])
                    tir.writes(Y[vi, vj])
                    with tir.init():
                        Y[vi, vj] = tir.float32(0)
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in tir.serial(4):
                with tir.block("C"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + ax0)
                    tir.reads(Y[vi, vj])
                    tir.writes(C[vi, vj])
                    C[vi, vj] = tir.max(Y[vi, vj], tir.float32(0))
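
The loop structure printed above reads as follows in plain Python. This is a sketch on assumed small sizes (8 and 4 in place of 128 and 4) with integer data, compared against a direct reference; the function names are made up for this note.

```python
N, TILE = 8, 4  # small stand-ins for 128 and 4

A = [[(3 * i + k) % 5 - 2 for k in range(N)] for i in range(N)]
B = [[(7 * k + j) % 5 - 2 for j in range(N)] for k in range(N)]


def mm_relu_reference():
    """ReLU(A @ B) computed directly, one output element at a time."""
    return [[max(sum(A[i][k] * B[k][j] for k in range(N)), 0) for j in range(N)]
            for i in range(N)]


def mm_relu_fused():
    """Loop structure after sch.reverse_compute_at(block_C, j0)."""
    Y = [[None] * N for _ in range(N)]
    C = [[None] * N for _ in range(N)]
    for i in range(N):
        for j_0 in range(N // TILE):
            for k in range(N):
                for j_1 in range(TILE):      # block "Y"
                    vj = j_0 * TILE + j_1
                    if k == 0:
                        Y[i][vj] = 0         # T.init()
                    Y[i][vj] += A[i][k] * B[k][vj]
            for ax0 in range(TILE):          # block "C", now under the j_0 loop
                vj = j_0 * TILE + ax0
                C[i][vj] = max(Y[i][vj], 0)
    return C


print(mm_relu_fused() == mm_relu_reference())
```

Block C only reads the four elements of Y that the preceding k, j_1 loops have just finished, so the ReLU can run while that tile of Y is still likely to be in cache.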

Finally, we can separate the initialization of Y from its update using func::decompose_reduction:

sch.decompose_reduction(block_Y, k)
IPython.display.Code(sch.mod.script(), language="python")
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([128, 128], dtype="float32")
        for i, j_0 in tir.grid(128, 32):
            for j_1_init in tir.serial(4):
                with tir.block("Y_init"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + j_1_init)
                    tir.reads()
                    tir.writes(Y[vi, vj])
                    Y[vi, vj] = tir.float32(0)
            for k, j_1 in tir.grid(128, 4):
                with tir.block("Y_update"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + j_1)
                    vk = tir.axis.reduce(128, k)
                    tir.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
                    tir.writes(Y[vi, vj])
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in tir.serial(4):
                with tir.block("C"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + ax0)
                    tir.reads(Y[vi, vj])
                    tir.writes(C[vi, vj])
                    C[vi, vj] = tir.max(Y[vi, vj], tir.float32(0))
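
In plain Python, the decomposed form corresponds to moving the k == 0 branch out of the update loop into its own loop. The sketch below uses assumed small sizes and integer data; Y starts filled with None so that a missing initialization would fail loudly instead of passing silently.

```python
N, TILE = 8, 4  # small stand-ins for 128 and 4

A = [[(3 * i + k) % 5 - 2 for k in range(N)] for i in range(N)]
B = [[(7 * k + j) % 5 - 2 for j in range(N)] for k in range(N)]


def mm_relu_decomposed():
    """Loop structure after sch.decompose_reduction(block_Y, k)."""
    Y = [[None] * N for _ in range(N)]       # None marks "not initialized yet"
    C = [[None] * N for _ in range(N)]
    for i in range(N):
        for j_0 in range(N // TILE):
            for j_1_init in range(TILE):     # block "Y_init"
                Y[i][j_0 * TILE + j_1_init] = 0
            for k in range(N):
                for j_1 in range(TILE):      # block "Y_update": no k == 0 branch left
                    vj = j_0 * TILE + j_1
                    Y[i][vj] += A[i][k] * B[k][vj]
            for ax0 in range(TILE):          # block "C"
                vj = j_0 * TILE + ax0
                C[i][vj] = max(Y[i][vj], 0)
    return C


expected = [[max(sum(A[i][k] * B[k][j] for k in range(N)), 0) for j in range(N)]
            for i in range(N)]
print(mm_relu_decomposed() == expected)
```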

Another way to create and interact with TensorIR

Intro to Tensor Expression

Tensor expression (te) is a domain-specific language that describes a sequence of computations via an expression-like API.

from tvm import te

A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
k = te.reduce_axis((0, 128), "k")
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")

Here te.compute takes the signature te.compute(output_shape, fcompute). The fcompute function (the lambda in the example) describes how to compute the value of each element Y[i, j] for a given index.
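
The role of fcompute can be pictured with an eager stand-in written in plain Python. The function compute below is hypothetical and only mirrors the (output_shape, fcompute) shape of the call: it evaluates fcompute immediately at every index, whereas te.compute builds a symbolic description of the computation.

```python
from typing import Callable, List, Tuple


def compute(shape: Tuple[int, int],
            fcompute: Callable[[int, int], float]) -> List[List[float]]:
    """Hypothetical eager stand-in for te.compute: call fcompute at every (i, j)."""
    rows, cols = shape
    return [[fcompute(i, j) for j in range(cols)] for i in range(rows)]


N = 4  # small stand-in for 128
A = [[float(i + k) for k in range(N)] for i in range(N)]
B = [[float(k - j) for j in range(N)] for k in range(N)]

# Same structure as the te example: Y is a matmul, C applies ReLU to Y.
Y = compute((N, N), lambda i, j: sum(A[i][k] * B[k][j] for k in range(N)))
C = compute((N, N), lambda i, j: max(Y[i][j], 0.0))
print(C[0])
```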

In this particular case, we want to create a function with two input parameters (A, B) and one output parameter (C).

te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
MyModuleFromTE = tvm.IRModule({"mm_relu": te_func})
IPython.display.Code(MyModuleFromTE.script(), language="python")
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([128, 128], dtype="float32")
        for i0, i1, i2 in tir.grid(128, 128, 128):
            with tir.block("Y"):
                i, j, k = tir.axis.remap("SSR", [i0, i1, i2])
                tir.reads(A[i, k], B[k, j])
                tir.writes(Y[i, j])
                with tir.init():
                    Y[i, j] = tir.float32(0)
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
        for i0, i1 in tir.grid(128, 128):
            with tir.block("C"):
                i, j = tir.axis.remap("SS", [i0, i1])
                tir.reads(Y[i, j])
                tir.writes(C[i, j])
                C[i, j] = tir.max(Y[i, j], tir.float32(0))

The tensor expression API provides a helpful tool to generate TensorIR functions for a given higher-level input.
