[NOTE] Compilers and Programming Models

Covers classic compiler frameworks such as Clang, LLVM, and MLIR, as well as AI optimizing compilers such as Torch, TensorRT, TVM, XLA/HLO, JAX/Pallas, and Triton. Also includes wider topics such as programming models and program-synthesis techniques, e.g. High-Level Synthesis (HLS) and target-specific domain-specific languages (DSLs).


Compiler Basics

import inspect

def matrix_add(a, b):  # example function to introspect
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

src = inspect.getsource(matrix_add)  # recover the function's source text
print(src)
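Once the source text is recovered, it can be handed to a front end. A minimal sketch using Python's ast module to parse a function definition and inspect the resulting tree (the matrix_add body here is a stand-in string, as if returned by inspect.getsource):

```python
import ast

# A stand-in source string; in practice this could come from
# inspect.getsource() on a live function.
src = """
def matrix_add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]
"""

tree = ast.parse(src)          # front-end parse: text -> AST
fn = tree.body[0]              # the single FunctionDef node
print(fn.name)                 # -> matrix_add
print([a.arg for a in fn.args.args])  # -> ['a', 'b']
```

This text-to-AST step is exactly what tracing/staging compilers do under the hood before lowering to their own IR.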

Google JAX/Pallas
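A minimal Pallas sketch (kernel and wrapper names are my own): an element-wise add kernel launched via pl.pallas_call. interpret=True runs the kernel through the Pallas interpreter, so the example works without a TPU or GPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs are views into the kernel's blocks; [...] reads/writes the whole block.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # run in interpreter mode (no accelerator needed)
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))
```

On real hardware one would drop interpret=True and add a grid/BlockSpec to tile the computation; this sketch keeps the single-block default.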


AWS Neuron SDK


PyTorch
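A small torch.fx sketch (function names assumed) showing how PyTorch captures ordinary Python into a graph IR, the same kind of capture that compiler backends build on:

```python
import torch
import torch.fx as fx

def f(x):
    return torch.relu(x) + 1.0

gm = fx.symbolic_trace(f)  # trace ops into an fx.Graph wrapped in a GraphModule
print(gm.graph)            # the op-by-op graph IR
print(gm.code)             # Python code regenerated from the graph

x = torch.randn(4)
assert torch.equal(gm(x), f(x))  # traced module is semantically equivalent
```

The captured graph can then be rewritten (operator fusion, quantization, lowering) before code generation, which is the role TorchInductor plays under torch.compile.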

LLVM

MLIR

Example: a batch-norm-style per-channel scale-and-shift kernel written as a single linalg.generic op over an NCHW memref. The i9/i8 element types illustrate MLIR's support for arbitrary-width integers; the channel dimension (j) indexes the weight and bias.

func.func @batchnorm(
  %input: memref<1x3x224x224xi9>, 
  %weight: memref<3xi8>, 
  %bias: memref<3xi8>
) -> memref<1x3x224x224xi9> {
    %alloc = memref.alloc() {name="output"} : memref<1x3x224x224xi9>
    
    linalg.generic 
      // op attributes
      {
        indexing_maps = [
          affine_map<(i, j, k, l) -> (i, j, k, l)>, 
          affine_map<(i, j, k, l) -> (j)>, 
          affine_map<(i, j, k, l) -> (j)>, 
          affine_map<(i, j, k, l) -> (i, j, k, l)>
        ], 
        iterator_types = [
          "parallel", "parallel", "parallel", "parallel"
        ]
      } 

      // operands, then the body region (^bb0)
      ins(%input, %weight, %bias: 
          memref<1x3x224x224xi9>, memref<3xi8>, memref<3xi8>) 
      outs(%alloc: memref<1x3x224x224xi9>) {
        ^bb0(%in: i9, %w: i8, %b: i8, %a: i9):
            %in_cast = arith.trunci %in : i9 to i8
            %0 = arith.muli %in_cast, %w : i8
            %1 = arith.addi %0, %b : i8
            %2 = arith.extsi %1 : i8 to i9
            linalg.yield %2 : i9
      }

    return %alloc : memref<1x3x224x224xi9>
}

Vitis HLS C++