[NOTE] Compilers and Programming Models
Covers classic compiler frameworks such as Clang, LLVM, and MLIR, as well as AI optimizing compilers such as Torch, TensorRT, TVM, XLA/HLO, JAX/Pallas, and Triton. Also includes wider topics such as programming models and program synthesis techniques, e.g., High-Level Synthesis (HLS) and target-specific domain-specific languages (DSLs).
Compiler Basics
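- SSA (static single assignment): each variable is assigned exactly once, so every use refers to exactly one definition.
  - SSA makes def-use information explicit and precise, which simplifies program analyses and optimizations such as common subexpression elimination (CSE), dead code elimination (DCE), and register allocation.
  - IMO, SSA reduces the use-def relationship to a much simpler structure: each use points to a single definition, so in straight-line code the chains form a DAG. Without SSA, a single use can be reached by multiple definitions (e.g., when a variable is reassigned inside a loop), which yields a tangled graph with loop-backs. In SSA, loops still introduce cycles, but only through explicit φ (phi) nodes at control-flow merge points.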
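To make the one-definition-per-use point concrete, here is a minimal Python sketch. It assumes straight-line code only (so no φ nodes are needed); the tuple-based IR and the helpers to_ssa and dce are hypothetical. Renaming gives every assignment a fresh name, after which dead code elimination becomes a simple backward walk over use-def links.

```python
from typing import Dict, List, Tuple, Union

Operand = Union[str, int]                # a variable name or an integer literal
Stmt = Tuple[str, str, List[Operand]]    # (target, op, operands)

def to_ssa(stmts: List[Stmt]) -> List[Stmt]:
    """Rename variables in straight-line code so every name is assigned exactly once."""
    version: Dict[str, int] = {}   # latest version number of each original variable
    current: Dict[str, str] = {}   # original name -> SSA name of its reaching definition
    out: List[Stmt] = []
    for target, op, operands in stmts:
        # Rewrite uses first: each use now names the single definition that reaches it.
        uses = [current.get(o, o) if isinstance(o, str) else o for o in operands]
        version[target] = version.get(target, -1) + 1
        current[target] = f"{target}{version[target]}"
        out.append((current[target], op, uses))
    return out

def dce(ssa: List[Stmt], live_out: List[str]) -> List[Stmt]:
    """Dead code elimination on SSA: keep only definitions that a live value depends on."""
    live = set(live_out)
    kept: List[Stmt] = []
    for name, op, operands in reversed(ssa):   # walk backwards from the live outputs
        if name in live:
            kept.append((name, op, operands))
            live.update(o for o in operands if isinstance(o, str))
    return kept[::-1]

program: List[Stmt] = [
    ("x", "const", [1]),
    ("y", "add", ["x", 2]),
    ("w", "add", ["x", "x"]),   # never used afterwards
    ("x", "mul", ["y", "y"]),   # x is reassigned, so it becomes x1
    ("z", "add", ["x", "y"]),
]
ssa = to_ssa(program)
optimized = dce(ssa, live_out=["z0"])
for stmt in optimized:
    print(stmt)
```

After renaming, the reassignment of x no longer aliases the earlier x0, so the pass can drop w0 just by checking whether any live definition names it.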
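- Reflection: the ability of a program to inspect, analyze, and modify its own structure, data, or behavior at runtime. Languages expose reflective APIs for this; for example, Python's inspect module can retrieve the source code of a function at runtime:
import inspect

def matrix_add(a, b):  # example function to inspect
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]
src = inspect.getsource(matrix_add)  # reads the source text from the defining file
print(src)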
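Reflection can also be used to modify program structure. A minimal sketch with Python's standard ast module: parse source text into a syntax tree, rewrite the tree, and compile it back into a live function. Strictly speaking this is source-level metaprogramming rather than runtime reflection on an existing object, and the function scale_add and the AddToMul transformer are hypothetical; in a file, the source string could instead come from inspect.getsource.

```python
import ast

source = """
def scale_add(a, b):
    return a + b
"""

class AddToMul(ast.NodeTransformer):
    """Toy transformation: rewrite every binary '+' into '*'."""
    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        self.generic_visit(node)          # transform nested expressions first
        if isinstance(node.op, ast.Add):
            node.op = ast.Mult()
        return node

tree = ast.parse(source)                  # inspect: source text -> syntax tree
tree = ast.fix_missing_locations(AddToMul().visit(tree))  # modify the tree
namespace: dict = {}
exec(compile(tree, filename="<ast>", mode="exec"), namespace)  # rebuild the function
print(ast.unparse(tree))                  # shows 'return a * b' (Python 3.9+)
print(namespace["scale_add"](3, 4))       # 12 rather than 7
```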
Google JAX/Pallas
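- JIT + Autograd + XLA:
  - JAX provides a NumPy-compatible API (jax.numpy) that runs on CPU, GPU, and TPU.
  - Supports automatic differentiation in both reverse mode (jax.grad, jax.vjp) and forward mode (jax.jvp).
  - JIT compilation (jax.jit) is backed by XLA.
  - JAX dispatches operations asynchronously: a call returns before the device finishes computing, and block_until_ready() waits for the result.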
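A short sketch of the pieces above (jit, grad/jvp, asynchronous dispatch), assuming a recent JAX install; the loss function is a hypothetical example. For w = [1, 2] and x = [3, 4], the loss is 73, the reverse-mode gradient is [18, 64], and the forward-mode derivative along a tangent of ones is their sum, 82.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # NumPy-style code written against jax.numpy
    return jnp.sum((w * x) ** 2)

loss_jit = jax.jit(loss)            # traced on first call, compiled by XLA, then cached
grad_fn = jax.jit(jax.grad(loss))   # reverse mode: gradient w.r.t. the first argument

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

value = loss_jit(w, x)              # returns without waiting for the device (async dispatch)
value.block_until_ready()           # explicitly wait for the computation to finish
g = grad_fn(w, x)                   # d/dw sum((w*x)^2) = 2 * w * x**2

# Forward mode: jvp returns f(w) and the directional derivative along the tangent
_, tangent_out = jax.jvp(lambda w_: loss(w_, x), (w,), (jnp.ones_like(w),))
print(float(value), g, float(tangent_out))
```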
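- Pallas: a JAX extension for writing custom low-level kernels.
  - Uses a programming model similar to Triton's: a kernel operates on blocks of data, with explicit control over memory accesses, compute, and the schedule across blocks.
  - On GPUs, Pallas lowers to Triton; on TPUs, it lowers to Mosaic.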
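A minimal Pallas kernel, modeled on the vector-add example from the Pallas quickstart. Assumptions: a recent JAX that ships jax.experimental.pallas; the names add_kernel and add are illustrative. The whole array is a single block here; real kernels typically add a grid and BlockSpecs to tile the work, which is where the Triton-like block model shows up.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point to memory buffers; [...] loads/stores the whole block
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),  # shape/dtype of o_ref
        interpret=True,  # run via the Pallas interpreter, so this also works on CPU
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
print(add(x, y))
```

Dropping interpret=True makes the same kernel go through the real lowering path (Triton on GPUs, Mosaic on TPUs).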
AWS Neuron SDK
- Compilation flow:
  - Target-agnostic: PyTorch/XLA (torch-xla) => XLA HLO
  - Hardware-specific: Penguin IR (loop transformations, mapping to hardware intrinsics) => BIR/Walrus => executable (NEFF, Neuron Executable File Format)
- Tensor layout restrictions:
  - Free dimension: lies within a single partition (par_dim) and supports random access.
  - Partition dimension: spread across multiple partitions (par_dims) and supports only sequential access.
PyTorch
Computational graph acquisition:
- Eager execution: ops run immediately as the Python code executes, so it does not offer the compiler-based optimizations (e.g., operator fusion, constant folding) that a static-graph approach enables. However, eager execution is more flexible for prototyping and easier to debug.
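- Script mode (TorchScript): production-oriented; a TorchScript module can run without the Python interpreter, avoiding the Python GIL and Python runtime dependencies.
  - torch.jit.trace / torch.jit.script: trace (record the ops executed on example inputs) or script (parse and compile the Python source of) eager modules into a TorchScript module, i.e., a static graph in the TorchScript IR. Tracing mode does not record control flow or Python data structures; only the ops that actually ran on the example inputs are captured.
  - JIT optimizer: the captured IR is further optimized by the JIT, e.g., with operator fusion, constant propagation, etc.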
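The tracing limitation is easiest to see with data-dependent control flow. A minimal sketch, assuming a PyTorch build that still includes torch.jit; the function f is hypothetical. Tracing bakes in whichever branch ran for the example input, while scripting keeps the if/else.

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: which branch runs depends on the input values
    if x.sum() > 0:
        return x * 2
    return x + 10

pos = torch.ones(3)
neg = -torch.ones(3)

traced = torch.jit.trace(f, (pos,))   # records only the ops run for `pos` (x * 2); emits a TracerWarning
scripted = torch.jit.script(f)        # compiles the Python source, so the if/else is kept

print(traced(neg))     # tensor([-2., -2., -2.]): the traced graph always multiplies by 2
print(scripted(neg))   # tensor([9., 9., 9.]): the else-branch runs
print(scripted.graph)  # the static graph in TorchScript IR (contains a prim::If)
```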
- LTC (Lazy Tensor Core): a lazy tracing system introduced in PyTorch/XLA. Operations performed on XLA tensors are recorded into a graph instead of being executed immediately. When a barrier is reached (e.g., xm.mark_step(), or when a tensor's value is needed on the host), the pending graph is compiled and dispatched for asynchronous device execution, and the tensors are materialized.
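A toy Python model of the lazy-tensor idea, not the PyTorch/XLA API (LazyTensor, materialize, and the executed log are hypothetical): arithmetic only appends nodes to a graph, and materialize() plays the role of the barrier. Real LTC additionally compiles the whole pending graph with XLA, caches compiled graphs, and runs them asynchronously on the device.

```python
import operator
from typing import Callable, Dict, List, Optional, Sequence

_OPS: Dict[str, Callable[[float, float], float]] = {"add": operator.add, "mul": operator.mul}
executed: List[str] = []   # log of ops in the order they actually run

class LazyTensor:
    """Toy lazy tensor: arithmetic only records graph nodes; nothing runs until a barrier."""

    def __init__(self, op: str, inputs: Sequence["LazyTensor"] = (), value: Optional[float] = None):
        self.op, self.inputs, self.value = op, list(inputs), value

    def __add__(self, other: "LazyTensor") -> "LazyTensor":
        return LazyTensor("add", (self, other))   # record the op, do not compute

    def __mul__(self, other: "LazyTensor") -> "LazyTensor":
        return LazyTensor("mul", (self, other))

    def materialize(self) -> float:
        """Barrier: evaluate the recorded graph (stands in for compile + dispatch)."""
        if self.value is None:
            args = [t.materialize() for t in self.inputs]
            executed.append(self.op)
            self.value = _OPS[self.op](*args)     # cache the materialized result
        return self.value

a, b = LazyTensor("const", value=2.0), LazyTensor("const", value=3.0)
c = (a + b) * b          # builds mul(add(a, b), b); nothing has executed yet
print(executed)          # []
print(c.materialize())   # 15.0
print(executed)          # ['add', 'mul']
```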
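- Torch FX (torch.fx): intended mainly for Python-to-Python transformation of PyTorch modules: a module is symbolically traced into an FX graph, the graph is edited, and Python code is regenerated from it. It can be used alongside eager mode, and an FX-transformed module can be passed to TorchScript for further deployment.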
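A minimal sketch of an FX Py-to-Py transformation, assuming a PyTorch version with torch.fx (1.8+); the module AddThenRelu and the add-to-mul rewrite are hypothetical. It symbolically traces a module, edits the graph, regenerates the Python code, and then hands the result to TorchScript.

```python
import operator
import torch
import torch.fx as fx

class AddThenRelu(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)

gm = fx.symbolic_trace(AddThenRelu())   # Python module -> FX graph + generated Python code
for node in gm.graph.nodes:
    # `x + y` is captured as a call_function node whose target is operator.add
    if node.op == "call_function" and node.target == operator.add:
        node.target = operator.mul      # Py-to-Py rewrite: add -> mul
gm.recompile()                          # regenerate forward() from the edited graph
print(gm.code)                          # the transformed module is still plain Python

x, y = torch.tensor([1.0, -2.0]), torch.tensor([3.0, 4.0])
print(gm(x, y))                         # relu(x * y) = tensor([3., 0.])

scripted = torch.jit.script(gm)         # hand the FX-transformed module to TorchScript
```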
Torch 2 new features:
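- TorchDynamo (used by torch.compile): a JIT graph-capture mechanism that hooks into CPython frame evaluation, analyzes the Python bytecode, and converts it into torch FX graphs on the fly. A backend then generates optimized code from each FX graph (e.g., Triton kernels from the Inductor backend). Code that Dynamo cannot capture causes a graph break and runs as regular Python.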
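Dynamo's capture step can be observed by passing a custom backend to torch.compile. A minimal sketch, assuming PyTorch 2.x; inspect_backend and f are hypothetical names. This backend just records and prints the FX graph and runs it unoptimized, whereas the default Inductor backend would compile the same graph into optimized kernels (e.g., Triton on GPUs).

```python
from typing import Callable, List
import torch
import torch.fx

captured: List[torch.fx.GraphModule] = []

def inspect_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable:
    """Custom Dynamo backend: keep the captured FX graph, then run it as-is (no optimization)."""
    captured.append(gm)
    print(gm.code)       # Python code regenerated from the FX graph Dynamo extracted
    return gm.forward    # a backend returns a callable that replaces the captured code

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x) * 2 + 1

compiled = torch.compile(f, backend=inspect_backend)
x = torch.randn(4)
out = compiled(x)        # first call: Dynamo analyzes f's bytecode and calls the backend
```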
LLVM
MLIR
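- Example of the MLIR textual assembly format using the memref and linalg dialects to implement a simplified batch norm function (inference-style, i.e., a per-channel scale and shift: out = in * weight[c] + bias[c]):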
func.func @batchnorm(
    %input: memref<1x3x224x224xi9>,
    %weight: memref<3xi8>,
    %bias: memref<3xi8>
) -> memref<1x3x224x224xi9> {
  %alloc = memref.alloc() {name = "output"} : memref<1x3x224x224xi9>
  linalg.generic
    // op attributes: one indexing map per operand, one iterator type per loop
    {
      indexing_maps = [
        affine_map<(i, j, k, l) -> (i, j, k, l)>,  // input
        affine_map<(i, j, k, l) -> (j)>,           // weight, indexed by channel
        affine_map<(i, j, k, l) -> (j)>,           // bias, indexed by channel
        affine_map<(i, j, k, l) -> (i, j, k, l)>   // output
      ],
      iterator_types = [
        "parallel", "parallel", "parallel", "parallel"
      ]
    }
    // operands (ins) and output buffer (outs)
    ins(%input, %weight, %bias :
        memref<1x3x224x224xi9>, memref<3xi8>, memref<3xi8>)
    outs(%alloc : memref<1x3x224x224xi9>) {
  // region: a single basic block executed once per element
  ^bb0(%in: i9, %w: i8, %b: i8, %a: i9):
    %in_cast = arith.trunci %in : i9 to i8
    %0 = arith.muli %in_cast, %w : i8
    %1 = arith.addi %0, %b : i8
    %2 = arith.extsi %1 : i8 to i9
    linalg.yield %2 : i9
  }
  return %alloc : memref<1x3x224x224xi9>
}
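To spell out what the linalg.generic computes, here is a plain-Python reference loop nest (a sketch; batchnorm_ref and wrap are hypothetical helpers, and the input is shrunk to 1x3x2x2). The indexing maps become array subscripts: input and output use all four loop indices, while weight and bias use only the channel index j. Because MLIR integers are fixed-width, the sketch wraps every result to 8 bits as arith.trunci, arith.muli, and arith.addi do; arith.extsi then sign-extends to i9, which preserves the signed value.

```python
from typing import List

Tensor4D = List[List[List[List[int]]]]

def wrap(v: int, bits: int) -> int:
    """Keep the low `bits` bits and read them as a signed two's-complement value."""
    v &= (1 << bits) - 1
    return v - (1 << bits) if v >= 1 << (bits - 1) else v

def batchnorm_ref(inp: Tensor4D, weight: List[int], bias: List[int]) -> Tensor4D:
    n, c, h, w = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    out = [[[[0] * w for _ in range(h)] for _ in range(c)] for _ in range(n)]
    # Four "parallel" loops, one per dimension of the iteration space (i, j, k, l)
    for i in range(n):
        for j in range(c):
            for k in range(h):
                for l in range(w):
                    x = wrap(inp[i][j][k][l], 8)   # (i, j, k, l) -> (i, j, k, l); trunci i9 -> i8
                    y = wrap(x * weight[j], 8)     # (i, j, k, l) -> (j); muli wraps at 8 bits
                    y = wrap(y + bias[j], 8)       # (i, j, k, l) -> (j); addi wraps at 8 bits
                    out[i][j][k][l] = y            # extsi i8 -> i9 keeps the signed value
    return out

# A 1x3x2x2 input instead of 1x3x224x224, to keep the example small
inp = [[[[5, 255], [0, 1]], [[100, 0], [0, 0]], [[-3, 0], [0, 0]]]]
result = batchnorm_ref(inp, weight=[1, 2, 3], bias=[0, 1, -1])
print(result)
```

Note how the i9 value 255 becomes -1 after truncation to i8, and 100 * 2 overflows to -56 before the bias is added; the MLIR function has the same behavior.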