10. Tiled Execution IR
We have assembled the necessary tools to execute tensor operations efficiently. This chapter introduces the Tiled Execution Intermediate Representation (TEIR), which describes the execution of tensor operations through primitives that operate on subtensors, known as tiles. The IR guides the backend implementation and controls when and where primitives are executed. Conceptually, the IR is closely related to tile-based programming models such as Triton, CUDA Tile IR, Pallas, and TileIR.
10.1. Specification
Domains and Notation
Operation-Type Constraints
The operation type is determined by \(\mathrm{prim\_main}\). Unary operations use \(\mathrm{prim\_main}\in\{\mathrm{Copy},\mathrm{ReLU}\}\) and require \(T=2\) (two tensors). Binary contractions use \(\mathrm{prim\_main}\in\{\mathrm{GEMM},\mathrm{BRGEMM}\}\) and require \(T=3\) (three tensors).
Unary operations:
\(T=2\) (tensors \(\mathrm{in0}\), \(\mathrm{out}\)).
All axes MUST have type \(\mathrm{C}\): \(t_i=\mathrm{C}\; \forall i\).
\(\mathrm{prim\_first}=\mathrm{None}\).
\(\mathrm{prim\_main}\in\{\mathrm{Copy},\mathrm{ReLU}\}\).
\(\mathrm{prim\_last}=\mathrm{None}\).
Binary contractions:
\(T=3\) (tensors \(\mathrm{in0}\), \(\mathrm{in1}\), \(\mathrm{out}\)).
\(t_i \in \{\mathrm{C},\mathrm{M},\mathrm{N},\mathrm{K}\}\).
\(\mathrm{prim\_first}\in\{\mathrm{None},\mathrm{Zero}\}\).
\(\mathrm{prim\_main}\in\{\mathrm{GEMM},\mathrm{BRGEMM}\}\).
\(\mathrm{prim\_last}\in\{\mathrm{None},\mathrm{ReLU}\}\).
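The operation-type rules above can be checked mechanically. The following minimal Python sketch does this. The function `check_op_type` is hypothetical and not part of TEIR, and so is the convention of spelling primitives as strings, with `"None"` for an absent primitive.

```python
# Hypothetical validator for the operation-type constraints (not part of TEIR).
# Primitives are plain strings; "None" denotes an absent primitive.
UNARY_MAIN = {"Copy", "ReLU"}
BINARY_MAIN = {"GEMM", "BRGEMM"}


def check_op_type(num_tensors, dim_types, prim_first, prim_main, prim_last):
    """Return a list of violated constraints; an empty list means the combination is valid."""
    errors = []
    if prim_main in UNARY_MAIN:
        if num_tensors != 2:
            errors.append("unary operations require T=2")
        if any(t != "C" for t in dim_types):
            errors.append("unary operations require every axis to be C")
        if prim_first != "None" or prim_last != "None":
            errors.append("unary operations require prim_first = prim_last = None")
    elif prim_main in BINARY_MAIN:
        if num_tensors != 3:
            errors.append("binary contractions require T=3")
        if any(t not in {"C", "M", "N", "K"} for t in dim_types):
            errors.append("axis types must be drawn from {C, M, N, K}")
        if prim_first not in {"None", "Zero"}:
            errors.append("prim_first must be None or Zero")
        if prim_last not in {"None", "ReLU"}:
            errors.append("prim_last must be None or ReLU")
    else:
        errors.append(f"unsupported prim_main: {prim_main}")
    return errors


print(check_op_type(2, ["C", "C", "C", "C"], "None", "Copy", "None"))  # []
print(check_op_type(3, ["K", "M", "N", "C"], "Zero", "GEMM", "Zero"))
# ['prim_last must be None or ReLU']
```

Collecting messages instead of raising on the first violation reports every problem in a configuration at once.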
Records
Axis Roles
The tensor index \(j \in \{0,\ldots,T-1\}\) indexes the tensors in a configuration. The axis role mapping \(\mathbf m\) determines which tensors participate in a given axis: \(\mathbf m(t)[j]=1\) if tensor \(j\) participates in an axis of type \(t\), and \(0\) otherwise. For unary operations (\(T=2\)), only \(\mathrm{C}\) is permitted as an axis type; every axis touches both tensors, so \(\mathbf m(\mathrm{C})=(1,1)\). For binary contraction operations (\(T=3\), tensor order \(\mathrm{in0}, \mathrm{in1}, \mathrm{out}\)), the full mapping is:
\[
\mathbf m(\mathrm{C})=(1,1,1),\qquad
\mathbf m(\mathrm{M})=(1,0,1),\qquad
\mathbf m(\mathrm{N})=(0,1,1),\qquad
\mathbf m(\mathrm{K})=(1,1,0).
\]
Well-Formedness
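All schedule vectors (dim_types, exec_types, and dim_sizes) have length \(D\), every entry of dim_sizes is positive, and the stride tensor strides has shape \(T \times D\). In addition: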
\(\mathbf{strides}[j][i]\) MUST be \(0\) whenever tensor \(j\) does not participate in axis \(i\), that is, whenever \(\mathbf m(t_i)[j] = 0\).
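A minimal Python sketch of these well-formedness checks follows. It assumes the participation masks implied by the axis roles: C touches all tensors, M touches in0 and out, N touches in1 and out, and K touches in0 and in1. The tensor order is (in0, out) for T=2 and (in0, in1, out) for T=3. The names `ROLE_MASKS` and `check_well_formed` are hypothetical. The example strides are those of the scalar batched GEMM in Section 10.3 with |a|, |b|, |c|, |d| = 2, 3, 4, 5.

```python
# Hypothetical participation masks m(t), keyed by the number of tensors T.
# Tensor order: (in0, out) for T=2, (in0, in1, out) for T=3.
ROLE_MASKS = {
    2: {"C": (1, 1)},
    3: {"C": (1, 1, 1), "M": (1, 0, 1), "N": (0, 1, 1), "K": (1, 1, 0)},
}


def check_well_formed(dim_types, exec_types, dim_sizes, strides):
    """Return a list of well-formedness violations for a TEIR-Schedule (empty if valid)."""
    D, T = len(dim_types), len(strides)
    if T not in ROLE_MASKS:
        return [f"unsupported number of tensors T={T}"]
    if len(exec_types) != D or len(dim_sizes) != D or any(len(row) != D for row in strides):
        return ["all schedule vectors must have length D and strides must be T x D"]
    errors = []
    if any(e not in {"seq", "parallel", "prim"} for e in exec_types):
        errors.append("exec_types must be seq, parallel, or prim")
    if any(n <= 0 for n in dim_sizes):
        errors.append("dim_sizes must be positive")
    masks = ROLE_MASKS[T]
    for i, t in enumerate(dim_types):
        if t not in masks:
            errors.append(f"axis {i}: type {t} not allowed for T={T}")
            continue
        for j in range(T):
            # A tensor that does not participate in axis i must not move along it.
            if masks[t][j] == 0 and strides[j][i] != 0:
                errors.append(f"strides[{j}][{i}] must be 0")
    return errors


# Scalar batched GEMM, axes (a, b, c, d) = (K, M, N, C) with extents 2, 3, 4, 5.
gemm_strides = [[1, 2, 0, 6],    # in0
                [4, 0, 1, 8],    # in1
                [0, 4, 1, 12]]   # out
print(check_well_formed(["K", "M", "N", "C"], ["seq"] * 4, [2, 3, 4, 5], gemm_strides))  # []
```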
Primitive-Specific Requirements
Execution Semantics
- Execution Types
If \(e_i=\mathrm{prim}\), axis \(i\) is consumed inside the primitive(s). Values \(\mathrm{seq}\) and \(\mathrm{parallel}\) denote sequential and parallel traversal of axis \(i\) in the schedule. The overall schedule order is determined by the order of all axes with \(e_i \neq \mathrm{prim}\) as they appear in the TEIR-Schedule. Traversal proceeds from the first such axis (outermost) to the last (innermost). No ordering guarantees are imposed between multiple axes marked as \(\mathrm{parallel}\).
- First/Last-Access Primitives
Primitives \(\mathrm{prim\_first}\) and \(\mathrm{prim\_last}\) define initialization and finalization steps applied to output tiles:
\(\mathrm{prim\_first}\) is applied the first time an output tile is accessed in a given schedule.
\(\mathrm{prim\_last}\) is applied the last time an output tile is accessed.
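A minimal reference executor in Python makes the traversal order and the first/last-access rules concrete. It handles a binary contraction whose axes are all seq. The tensors are flat Python lists, and the executor assumes that distinct values of the non-K axes map to distinct output elements. Under that assumption, the output offset changes only along non-K axes. An element's first access is therefore the iteration in which every K index is 0, and its last access is the one in which every K index is maximal, whatever the loop order. The name `run_scalar_contraction` is hypothetical.

```python
from itertools import product


def run_scalar_contraction(dim_types, dim_sizes, strides, in0, in1, out,
                           prim_first="None", prim_last="None"):
    """Hypothetical reference executor for a binary contraction with all axes seq.

    strides = [in0_strides, in1_strides, out_strides]; tensors are flat lists.
    itertools.product varies its last argument fastest, so axis 0 is the outermost loop.
    """
    k_axes = [i for i, t in enumerate(dim_types) if t == "K"]
    for idx in product(*(range(n) for n in dim_sizes)):
        o0, o1, oo = (sum(s * x for s, x in zip(row, idx)) for row in strides)
        # Output offsets ignore K axes, so first/last access is decided by the K indices.
        first = all(idx[i] == 0 for i in k_axes)
        last = all(idx[i] == dim_sizes[i] - 1 for i in k_axes)
        if first and prim_first == "Zero":
            out[oo] = 0.0
        out[oo] += in0[o0] * in1[o1]          # main primitive on a 1x1 tile
        if last and prim_last == "ReLU":
            out[oo] = max(out[oo], 0.0)
    return out


# 2x2 matrix product, axes (m, n, k) = (M, N, K); all tensors row-major.
A = [1.0, 0.0, 0.0, -1.0]
B = [5.0, 6.0, 7.0, 8.0]
strides = [[2, 0, 1], [0, 1, 2], [2, 1, 0]]
C = run_scalar_contraction(["M", "N", "K"], [2, 2, 2], strides, A, B, [9.0] * 4,
                           prim_first="Zero", prim_last="ReLU")
print(C)  # [5.0, 6.0, 0.0, 0.0]
```

The stale values 9.0 in the output are discarded by Zero on first access. ReLU on last access clamps the two negative results of the product.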
10.2. Tensor Operation Configuration
Section 10.1 contains the formal specification of the Tiled Execution Intermediate Representation (TEIR). TEIR comprises two records: TEIR-Primitives, which specifies the primitives to be executed, and TEIR-Schedule, which defines how these primitives are applied to tiles of the tensors. This section describes the IR from the perspective of a user who configures a tensor operation using TEIR.
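| Field | Meaning | Domain & constraints |
|---|---|---|
| dim_types | Axis roles across tensors (D axes) | {C, M, N, K}; only C for unary operations |
| exec_types | Execution policy per axis (D axes) | {seq, parallel, prim} |
| dim_sizes | Positive extent per axis | integers > 0 |
| strides | Per-tensor strides (T×D stride tensor) | strides[j][i] = 0 whenever tensor j does not participate in axis i |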
Table 10.2.1 provides a concise informal form of TEIR-Schedule.
For binary contractions, the field dim_types describes whether an axis spans both inputs and the output (C), only the first input and the output (M), only the second input and the output (N), or only the two inputs (K); unary operations use C exclusively.
The field exec_types specifies the execution type of each axis.
Setting seq results in sequential execution of an axis.
Parallelization is achieved with parallel.
Axes with type prim are consumed inside the primitives.
The remaining fields describe the sizes of the axes in field dim_sizes and the data layout of each tensor in the stride tensor strides, where strides[j][i] gives the stride of tensor j along axis i.
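For experimentation, TEIR-Schedule maps naturally onto a small record type. The sketch below uses a Python dataclass named `TeirSchedule`, a hypothetical name, with the four fields of Table 10.2.1. It adds a helper that computes an element offset as the dot product of a tensor's strides with the per-axis indices. The example encodes the scalar permutation abcd -> cdab from Section 10.3 with extents 2, 3, 4, 5.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TeirSchedule:
    """Hypothetical in-memory form of TEIR-Schedule (field names follow Table 10.2.1)."""
    dim_types: List[str]       # one of "C", "M", "N", "K" per axis
    exec_types: List[str]      # one of "seq", "parallel", "prim" per axis
    dim_sizes: List[int]       # positive extent per axis
    strides: List[List[int]]   # strides[j][i]: stride of tensor j along axis i

    def offset(self, j, idx):
        """Element offset of tensor j for the per-axis indices idx."""
        return sum(s * x for s, x in zip(self.strides[j], idx))


# Scalar permutation abcd -> cdab with |a|, |b|, |c|, |d| = 2, 3, 4, 5.
perm = TeirSchedule(
    dim_types=["C"] * 4,
    exec_types=["seq"] * 4,
    dim_sizes=[2, 3, 4, 5],
    strides=[[60, 20, 5, 1],   # in0: row-major abcd
             [3, 1, 30, 6]],   # out: row-major cdab
)
print(perm.offset(0, (1, 0, 0, 0)), perm.offset(1, (1, 0, 0, 0)))  # 60 3
```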
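| Field | Meaning | Allowed values |
|---|---|---|
| data_type | Data type of inputs and output | {…} |
| prim_first | First-access primitive | {None, Zero} |
| prim_main | Main primitive | {Copy, ReLU} (unary); {GEMM, BRGEMM} (binary) |
| prim_last | Last-access primitive | {None, ReLU} |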
Table 10.2.2 provides a short form of TEIR-Primitives.
TEIR-Primitives contains four fields.
data_type determines the data type of the input and output tensors.
In addition, up to three primitives can be used in TEIR.
The first-access primitive (prim_first) is applied to a tile of the output tensor when it is accessed for the first time.
Similarly, the last-access primitive (prim_last) is applied to a tile of the output tensor when it is accessed for the last time.
The types available to prim_first and prim_last are listed below; prim_first accepts None and Zero, and prim_last accepts None and ReLU (see Table 10.2.2):
- None: no primitive is executed.
- Zero: zero the output tile.
- ReLU: apply ReLU to the output tile’s values.
By contrast, the main primitive (prim_main) is executed for every valid combination of input and output tiles in the TEIR schedule.
The main primitive can have one of the following types (Copy and ReLU for unary operations, GEMM and BRGEMM for binary contractions):
- Copy: copy the input tile’s values to the output tile.
- ReLU: apply ReLU to the input tile’s values and write the result to the output tile.
- GEMM: General Matrix Multiply. Multiply two 2D input tiles and accumulate the result into the 2D output tile.
- BRGEMM: Batch-Reduce GEMM. Multiply a batch of 2D tile pairs, given as two 3D input tiles, and accumulate the sum of all products into the 2D output tile.
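A matching Python sketch for TEIR-Primitives follows, again with hypothetical names (`TeirPrimitives`, `apply_access_primitive`). It checks each field against the per-field sets from Section 10.1. It does not enforce cross-field rules, such as the requirement that unary operations use prim_first = prim_last = None. The data type string `"fp32"` is a placeholder, because this section does not list the supported data types.

```python
from dataclasses import dataclass

# Allowed values per field, following the constraints in Section 10.1.
ALLOWED = {
    "prim_first": {"None", "Zero"},
    "prim_main": {"Copy", "ReLU", "GEMM", "BRGEMM"},
    "prim_last": {"None", "ReLU"},
}


@dataclass(frozen=True)
class TeirPrimitives:
    """Hypothetical in-memory form of TEIR-Primitives (per-field checks only)."""
    data_type: str            # placeholder string, e.g. "fp32" (assumed, not from the spec)
    prim_first: str = "None"
    prim_main: str = "Copy"
    prim_last: str = "None"

    def __post_init__(self):
        for name, allowed in ALLOWED.items():
            value = getattr(self, name)
            if value not in allowed:
                raise ValueError(f"{name}={value!r}; expected one of {sorted(allowed)}")


def apply_access_primitive(kind, tile):
    """Apply a first- or last-access primitive to an output tile given as a flat list."""
    if kind == "None":
        return list(tile)
    if kind == "Zero":
        return [0.0] * len(tile)
    if kind == "ReLU":
        return [max(v, 0.0) for v in tile]
    raise ValueError(f"unknown access primitive {kind!r}")


prims = TeirPrimitives("fp32", prim_first="Zero", prim_main="GEMM", prim_last="ReLU")
print(apply_access_primitive(prims.prim_last, [-1.5, 2.0]))  # [0.0, 2.0]
```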
10.3. Example Configurations
This section illustrates TEIR through example configurations for two types of tensor operations:
- Permutation
Permute the axes of the input tensor to obtain the output tensor: \(T_\text{out} = \operatorname{permute} \left( T_\text{in0} \right)\).
- Binary Tensor Contraction
Contract two input tensors to obtain an output tensor: \(T_\text{out} = \operatorname{contract} \left( T_\text{in0}, T_\text{in1} \right)\).
Each example provides a TEIR schedule table specifying the dimension types, execution types, sizes, and per-tensor strides, followed by pseudocode showing the corresponding execution.
10.3.1. Scalar Permutation
A sequential element-wise permutation of a 4D tensor: the input has shape \(|a| \times |b| \times |c| \times |d|\) stored in row-major order, and the output stores the permuted result abcd -> cdab in row-major order.
All axes are sequential and have type C (unary operation).
The main primitive is Copy; first- and last-access primitives are None.
| Dimension ID | a | b | c | d |
|---|---|---|---|---|
| dim_sizes | \|a\| | \|b\| | \|c\| | \|d\| |
| dim_types | C | C | C | C |
| exec_types | seq | seq | seq | seq |
| strides[in0] | \|b\| × \|c\| × \|d\| | \|c\| × \|d\| | \|d\| | 1 |
| strides[out] | \|b\| | 1 | \|d\| × \|a\| × \|b\| | \|a\| × \|b\| |
```
for a in |a|
  for b in |b|
    for c in |c|
      for d in |d|
        out[c][d][a][b] = in0[a][b][c][d]
```
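The pseudocode translates directly into Python if the tensors are flat lists addressed through the strides in the table. This is a sketch for checking the strides, not a backend implementation. The name `scalar_permute` is hypothetical, and the extents 2, 3, 4, 5 are arbitrary.

```python
def scalar_permute(in0, sizes, in_strides, out_strides):
    """Hypothetical scalar permutation: every axis is seq, Copy acts on single elements."""
    A, B, C, D = sizes
    out = [None] * (A * B * C * D)
    for a in range(A):
        for b in range(B):
            for c in range(C):
                for d in range(D):
                    idx = (a, b, c, d)
                    src = sum(s * x for s, x in zip(in_strides, idx))
                    dst = sum(s * x for s, x in zip(out_strides, idx))
                    out[dst] = in0[src]
    return out


A, B, C, D = 2, 3, 4, 5
in_strides = [B * C * D, C * D, D, 1]     # strides[in0]: row-major abcd
out_strides = [B, 1, D * A * B, A * B]    # strides[out]: row-major cdab
in0 = list(range(A * B * C * D))
out = scalar_permute(in0, (A, B, C, D), in_strides, out_strides)
print(out[:4])  # [0, 20, 40, 60]
```

Filling the input with its own offsets makes the result easy to read: each output element holds the input offset it was copied from.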
10.3.2. Tiled Permutation
Same permutation as the scalar permutation, but axes d and b are consumed by a Copy primitive that copies a 2D tile at once.
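The outer axes a and c are traversed sequentially, and each iteration copies a tile of shape \(|d| \times |b|\). Axis d has unit stride in in0, whereas b has unit stride in out, so each Copy also changes the in-memory layout of the tile.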
| Dimension ID | a | c | d | b |
|---|---|---|---|---|
| dim_sizes | \|a\| | \|c\| | \|d\| | \|b\| |
| dim_types | C | C | C | C |
| exec_types | seq | seq | prim | prim |
| strides[in0] | \|b\| × \|c\| × \|d\| | \|d\| | 1 | \|c\| × \|d\| |
| strides[out] | \|b\| | \|d\| × \|a\| × \|b\| | \|a\| × \|b\| | 1 |
```
for a in |a|
  for c in |c|
    Copy( in = in0[a][0][c],
          out = out[c][0][a],
          m = |d|,
          n = |b|,
          ldI = |c| * |d|,
          ldO = |a| * |b| )
```
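In the Python sketch below, the hypothetical `copy_tile` stands in for the Copy primitive. The sketch does not guess how a backend interprets `ldI` and `ldO`. Instead, `copy_tile` takes the (d, b) element strides of each tile explicitly, read from the schedule table: (1, |c| × |d|) for in0 and (|a| × |b|, 1) for out. The result must match the scalar permutation.

```python
def copy_tile(src, src_off, src_strides, dst, dst_off, dst_strides, m, n):
    """Hypothetical strided Copy primitive: copy an m x n tile element by element.

    src_strides and dst_strides are the (m-axis, n-axis) element strides of each tile.
    """
    for i in range(m):
        for j in range(n):
            value = src[src_off + i * src_strides[0] + j * src_strides[1]]
            dst[dst_off + i * dst_strides[0] + j * dst_strides[1]] = value


def tiled_permute(in0, A, B, C, D):
    """abcd -> cdab with seq axes (a, c) and one Copy per (d, b) tile."""
    out = [None] * (A * B * C * D)
    for a in range(A):
        for c in range(C):
            copy_tile(in0, a * B * C * D + c * D, (1, C * D),   # in0[a][0][c]; d, b strides
                      out, c * D * A * B + a * B, (A * B, 1),   # out[c][0][a]; d, b strides
                      m=D, n=B)
    return out


A, B, C, D = 2, 3, 4, 5
in0 = list(range(A * B * C * D))
out = tiled_permute(in0, A, B, C, D)
print(out[:4])  # [0, 20, 40, 60]
```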
10.3.3. Scalar Batched GEMM
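A batched matrix multiplication in which d (C) is the batch axis, b (M) and c (N) index the output, and a (K) is contracted. All dimensions are sequential, so each iteration performs a single scalar multiply-accumulate.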
| Dimension ID | a | b | c | d |
|---|---|---|---|---|
| dim_sizes | \|a\| | \|b\| | \|c\| | \|d\| |
| dim_types | K | M | N | C |
| exec_types | seq | seq | seq | seq |
| strides[in0] | 1 | \|a\| | 0 | \|b\| × \|a\| |
| strides[in1] | \|c\| | 0 | 1 | \|a\| × \|c\| |
| strides[out] | 0 | \|c\| | 1 | \|b\| × \|c\| |
```
for a in |a|
  for b in |b|
    for c in |c|
      for d in |d|
        out[d][b][c] += in0[d][b][a] * in1[d][a][c]
```
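A Python sketch of the same loop nest, with the strides taken column by column from the table, follows. The function name `scalar_batched_gemm` is hypothetical. The usage example runs a single 2×2 by 2×2 batch so that the result can be checked by hand.

```python
def scalar_batched_gemm(in0, in1, out, A, B, C, D):
    """Hypothetical scalar batched GEMM; axes (a, b, c, d) = (K, M, N, C), a outermost."""
    s_in0 = (1, A, 0, B * A)     # in0 laid out as [d][b][a]
    s_in1 = (C, 0, 1, A * C)     # in1 laid out as [d][a][c]
    s_out = (0, C, 1, B * C)     # out laid out as [d][b][c]
    for a in range(A):
        for b in range(B):
            for c in range(C):
                for d in range(D):
                    idx = (a, b, c, d)
                    o0 = sum(s * x for s, x in zip(s_in0, idx))
                    o1 = sum(s * x for s, x in zip(s_in1, idx))
                    oo = sum(s * x for s, x in zip(s_out, idx))
                    out[oo] += in0[o0] * in1[o1]   # one scalar multiply-accumulate
    return out


# One batch (|d| = 1): [[1, 2], [3, 4]] times [[5, 6], [7, 8]].
print(scalar_batched_gemm([1, 2, 3, 4], [5, 6, 7, 8], [0] * 4, A=2, B=2, C=2, D=1))
# [19, 22, 43, 50]
```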
10.3.4. Scalar Batched GEMM — Reordered
Same operation as the scalar batched GEMM, but with the contraction axis K moved to the innermost position.
This changes the traversal order without affecting the result.
| Dimension ID | b | c | d | a |
|---|---|---|---|---|
| dim_sizes | \|b\| | \|c\| | \|d\| | \|a\| |
| dim_types | M | N | C | K |
| exec_types | seq | seq | seq | seq |
| strides[in0] | \|a\| | 0 | \|b\| × \|a\| | 1 |
| strides[in1] | 0 | 1 | \|a\| × \|c\| | \|c\| |
| strides[out] | \|c\| | 1 | \|b\| × \|c\| | 0 |
```
for b in |b|
  for c in |c|
    for d in |d|
      for a in |a|
        out[d][b][c] += in0[d][b][a] * in1[d][a][c]
```
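Reordering only permutes the columns of the schedule table, so a single generic executor can run both schedules and compare the results. In the Python sketch below, `run_seq_schedule` is a hypothetical helper that treats axis 0 as the outermost loop. For any fixed output element, both schedules visit a in ascending order, so the accumulated values are identical.

```python
from itertools import product


def run_seq_schedule(sizes, s_in0, s_in1, s_out, in0, in1, out):
    """Hypothetical executor for an all-seq binary contraction; axis 0 is the outermost loop."""
    for idx in product(*(range(n) for n in sizes)):
        o0 = sum(s * x for s, x in zip(s_in0, idx))
        o1 = sum(s * x for s, x in zip(s_in1, idx))
        oo = sum(s * x for s, x in zip(s_out, idx))
        out[oo] += in0[o0] * in1[o1]
    return out


A, B, C, D = 3, 2, 4, 2
in0 = [float(v) for v in range(D * B * A)]   # [d][b][a]
in1 = [float(v) for v in range(D * A * C)]   # [d][a][c]

# Original order (a, b, c, d) = (K, M, N, C).
k_outer = run_seq_schedule((A, B, C, D), (1, A, 0, B * A), (C, 0, 1, A * C),
                           (0, C, 1, B * C), in0, in1, [0.0] * (D * B * C))
# Reordered (b, c, d, a) = (M, N, C, K): the same columns, permuted.
k_inner = run_seq_schedule((B, C, D, A), (A, 0, B * A, 1), (0, 1, A * C, C),
                           (C, 1, B * C, 0), in0, in1, [0.0] * (D * B * C))
print(k_outer == k_inner)  # True
```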
10.3.5. Scalar Tensor Contraction
A general tensor contraction trus,pqtu->pqrs with six sequential axes and scalar execution.
Each of the M, N, and K roles spans two axes.
| Dimension ID | p | q | r | s | t | u |
|---|---|---|---|---|---|---|
| dim_sizes | \|p\| | \|q\| | \|r\| | \|s\| | \|t\| | \|u\| |
| dim_types | N | N | M | M | K | K |
| exec_types | seq | seq | seq | seq | seq | seq |
| strides[in0] | 0 | 0 | \|u\| × \|s\| | 1 | \|r\| × \|u\| × \|s\| | \|s\| |
| strides[in1] | \|q\| × \|t\| × \|u\| | \|t\| × \|u\| | 0 | 0 | \|u\| | 1 |
| strides[out] | \|q\| × \|r\| × \|s\| | \|r\| × \|s\| | \|s\| | 1 | 0 | 0 |
```
for p in |p|
  for q in |q|
    for r in |r|
      for s in |s|
        for t in |t|
          for u in |u|
            out[p][q][r][s] += in0[t][r][u][s] * in1[p][q][t][u]
```
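A Python sketch of the six-loop contraction follows. The per-axis strides are copied from the table, and the tensors are stored as flat lists in the layouts trus, pqtu, and pqrs. The function name `contract_trus_pqtu` and the small extents are hypothetical. The inputs hold small integer values, so every floating-point sum is exact and can be compared with a direct evaluation of the index formula.

```python
from itertools import product


def contract_trus_pqtu(in0, in1, P, Q, R, S, T, U):
    """Hypothetical scalar contraction out[p][q][r][s] += in0[t][r][u][s] * in1[p][q][t][u]."""
    # Per-axis strides for the axis order (p, q, r, s, t, u), copied from the schedule table.
    strides = (
        (0, 0, U * S, 1, R * U * S, S),   # in0 (layout trus)
        (Q * T * U, T * U, 0, 0, U, 1),   # in1 (layout pqtu)
        (Q * R * S, R * S, S, 1, 0, 0),   # out (layout pqrs)
    )
    out = [0.0] * (P * Q * R * S)
    # product() varies its last argument fastest: p is the outermost loop, u the innermost.
    for idx in product(range(P), range(Q), range(R), range(S), range(T), range(U)):
        o0, o1, oo = (sum(st * x for st, x in zip(row, idx)) for row in strides)
        out[oo] += in0[o0] * in1[o1]
    return out


P, Q, R, S, T, U = 2, 2, 3, 2, 2, 3
in0 = [float(v % 7) for v in range(T * R * U * S)]   # layout trus
in1 = [float(v % 5) for v in range(P * Q * T * U)]   # layout pqtu
out = contract_trus_pqtu(in0, in1, P, Q, R, S, T, U)  # layout pqrs
```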
10.3.6. Tensor Contraction with GEMM
Same contraction as the scalar tensor contraction, but axes s, q, u are consumed by a GEMM primitive.
The remaining axes p, r, t are traversed sequentially.
| Dimension ID | p | r | t | s | q | u |
|---|---|---|---|---|---|---|
| dim_sizes | \|p\| | \|r\| | \|t\| | \|s\| | \|q\| | \|u\| |
| dim_types | N | M | K | M | N | K |
| exec_types | seq | seq | seq | prim | prim | prim |
| strides[in0] | 0 | \|u\| × \|s\| | \|r\| × \|u\| × \|s\| | 1 | 0 | \|s\| |
| strides[in1] | \|q\| × \|t\| × \|u\| | 0 | \|u\| | 0 | \|t\| × \|u\| | 1 |
| strides[out] | \|q\| × \|r\| × \|s\| | \|s\| | 0 | 1 | \|r\| × \|s\| | 0 |
```
for p in |p|
  for r in |r|
    for t in |t|
      GEMM( A = in0[t][r],
            B = in1[p][0][t],
            C = out[p][0][r],
            m = |s|,
            n = |q|,
            k = |u|,
            ldA = |s|,
            ldB = |t| * |u|,
            ldC = |r| * |s| )
```
Fig. 10.3.1 and Fig. 10.3.2 illustrate this configuration for a concrete instance with |r|=3, |t|=2, and |p|=4.
Each tensor is drawn as a grid of tiles.
The seq axes (p, r, t) index tiles across each grid; the prim axes span the interior of each tile according to the axis roles: s and u for in0, u and q for in1, and s and q for out.
Spatial alignment of the grids reflects the dimension roles: r aligns the rows of in0 and out (M), p aligns the columns of in1 and out (N), and t connects the columns of in0 to the rows of in1 (K, contracted).
Fig. 10.3.2 shows the memory layout within and across tiles.
Fig. 10.3.1 Illustration of the dimensions for the binary tensor contraction trus,pqtu->pqrs.
Fig. 10.3.2 Illustration of the memory layout for the binary tensor contraction trus,pqtu->pqrs.
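The GEMM-based schedule can be checked with a Python sketch that models the primitive as a column-major GEMM. In that model, element (i, j) of a matrix with leading dimension ld sits at offset + i + j × ld. This reading of m, n, k, ldA, ldB, and ldC is an assumption that matches the strides in the table: s and u are the unit-stride axes of the A and B tiles, and s is the unit-stride axis of the C tile. The names `gemm` and `contract_with_gemm` are hypothetical.

```python
def gemm(a_buf, off_a, ld_a, b_buf, off_b, ld_b, c_buf, off_c, ld_c, m, n, k):
    """Hypothetical column-major GEMM: C[i, j] += sum over kk of A[i, kk] * B[kk, j]."""
    for j in range(n):
        for i in range(m):
            acc = 0.0
            for kk in range(k):
                acc += a_buf[off_a + i + kk * ld_a] * b_buf[off_b + kk + j * ld_b]
            c_buf[off_c + i + j * ld_c] += acc


def contract_with_gemm(in0, in1, P, Q, R, S, T, U):
    """trus,pqtu->pqrs with seq axes (p, r, t) and one GEMM over (s, q, u) per iteration."""
    out = [0.0] * (P * Q * R * S)
    for p in range(P):
        for r in range(R):
            for t in range(T):
                gemm(in0, t * R * U * S + r * U * S, S,     # A = in0[t][r], ldA = |s|
                     in1, p * Q * T * U + t * U, T * U,     # B = in1[p][0][t], ldB = |t|*|u|
                     out, p * Q * R * S + r * S, R * S,     # C = out[p][0][r], ldC = |r|*|s|
                     m=S, n=Q, k=U)
    return out


P, Q, R, S, T, U = 2, 2, 3, 2, 2, 3
in0 = [float(v % 7) for v in range(T * R * U * S)]   # layout trus
in1 = [float(v % 5) for v in range(P * Q * T * U)]   # layout pqtu
out = contract_with_gemm(in0, in1, P, Q, R, S, T, U)  # layout pqrs
```

The inputs hold small integers, so the sums are exact and the result must equal the scalar contraction.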
10.3.7. Parallel Tensor Contraction with BRGEMM
Same contraction as the scalar and GEMM-based examples.
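The contraction axis t is absorbed into a BRGEMM primitive as its batch-reduce dimension (brSize = |t|), alongside s, q, u.
The outer axes p and r are traversed in parallel; each (p, r) pair addresses a distinct output tile, so the iterations are independent.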
| Dimension ID | p | r | t | s | q | u |
|---|---|---|---|---|---|---|
| dim_sizes | \|p\| | \|r\| | \|t\| | \|s\| | \|q\| | \|u\| |
| dim_types | N | M | K | M | N | K |
| exec_types | parallel | parallel | prim | prim | prim | prim |
| strides[in0] | 0 | \|u\| × \|s\| | \|r\| × \|u\| × \|s\| | 1 | 0 | \|s\| |
| strides[in1] | \|q\| × \|t\| × \|u\| | 0 | \|u\| | 0 | \|t\| × \|u\| | 1 |
| strides[out] | \|q\| × \|r\| × \|s\| | \|s\| | 0 | 1 | \|r\| × \|s\| | 0 |
```
#pragma omp parallel for collapse(2)
for p in |p|
  for r in |r|
    BRGEMM( A = in0[0][r],
            B = in1[p][0],
            C = out[p][0][r],
            m = |s|,
            n = |q|,
            k = |u|,
            ldA = |s|,
            ldB = |t| * |u|,
            ldC = |r| * |s|,
            brSize = |t|,
            brStrA = |r| * |u| * |s|,
            brStrB = |u| )
```
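A Python sketch of this schedule follows. It models BRGEMM as a column-major GEMM summed over a batch, whose entries start brStrA and brStrB elements apart in A and B. That interpretation of the parameters is an assumption that matches the strides of t in the table. The names `brgemm` and `contract_with_brgemm` are hypothetical. A thread pool stands in for the OpenMP loop: under CPython's GIL it shows that the (p, r) tasks are independent, but it is not a performance model.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product


def brgemm(a_buf, off_a, ld_a, br_str_a, b_buf, off_b, ld_b, br_str_b,
           c_buf, off_c, ld_c, m, n, k, br_size):
    """Hypothetical column-major BRGEMM: C += sum over batch entries of A_bb @ B_bb."""
    for j in range(n):
        for i in range(m):
            acc = 0.0
            for bb in range(br_size):
                a0 = off_a + bb * br_str_a   # start of batch entry bb in A
                b0 = off_b + bb * br_str_b   # start of batch entry bb in B
                for kk in range(k):
                    acc += a_buf[a0 + i + kk * ld_a] * b_buf[b0 + kk + j * ld_b]
            c_buf[off_c + i + j * ld_c] += acc


def contract_with_brgemm(in0, in1, P, Q, R, S, T, U, workers=4):
    """trus,pqtu->pqrs with parallel axes (p, r) and one BRGEMM over (t, s, q, u) per task."""
    out = [0.0] * (P * Q * R * S)

    def tile_task(pr):
        p, r = pr
        brgemm(in0, r * U * S, S, R * U * S,        # A = in0[0][r], ldA, brStrA
               in1, p * Q * T * U, T * U, U,        # B = in1[p][0], ldB, brStrB
               out, p * Q * R * S + r * S, R * S,   # C = out[p][0][r], ldC
               m=S, n=Q, k=U, br_size=T)

    # Each (p, r) pair owns a distinct output tile, so no two tasks write the same element.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        list(pool.map(tile_task, product(range(P), range(R))))
    return out


P, Q, R, S, T, U = 2, 2, 3, 2, 2, 3
in0 = [float(v % 7) for v in range(T * R * U * S)]      # layout trus
in1 = [float(v % 5) for v in range(P * Q * T * U)]      # layout pqtu
out = contract_with_brgemm(in0, in1, P, Q, R, S, T, U)  # layout pqrs
```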