
Fused-MoE with CK Tile

This example implements a highly optimized fused Mixture-of-Experts (MoE) block using the CK Tile programming model. A MoE sorting step runs first and also zeroes the output buffer. A single kernel then fuses both group-GEMMs, the activation, and the top-k weighting. This minimizes memory traffic and maximizes throughput for large language models.


Algorithm and Math

MoE Block Structure

Given:

  • Input: X of shape [\text{tokens}, \text{hidden}]
  • TopK indices/weights: I, W from gating (shape [\text{tokens}, \text{topk}])
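  • Expert weights (the "B matrices"). These are given as logical GEMM shapes; the stored layout is pre-shuffled.
      B_0, of shape [\text{experts}, \text{hidden}, \text{intermediate}], for the first GEMM.
      B_1, of shape [\text{experts}, \text{intermediate}, \text{hidden}], for the second GEMM.
    With gate+up (-gate_only 0), the intermediate dimension of B_0 is doubled.

Steps:

  1. MoE Sorting: Rearrange tokens so each expert receives its assigned tokens in contiguous blocks (see 13_moe_sorting).
  2. First Group-GEMM: For each expert e, multiply its gathered tokens X^{(e)} by that expert's first weight:
    
    H^{(e)} = X^{(e)} B_0^{(e)}
    
  3. Activation + TopK Weighting: Apply the activation (GELU or SiLU, selected by -act) and scale each row by the top-k weight of its (token, expert) pair. The weight is a per-row scalar, so it can equally be applied after the second GEMM.
  4. Second Group-GEMM + Scatter: Multiply by B_1^{(e)} and accumulate each row back into its original token position. For the gate-only case, the output for token t is:
    
    O_t = \sum_{j=1}^{\text{topk}} W_{t,j} \, \mathrm{act}\big(x_t B_0^{(I_{t,j})}\big) \, B_1^{(I_{t,j})}
    
    where x_t is row t of X.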
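The steps above amount to a per-token weighted sum over the top-k experts of a two-GEMM FFN. Below is a minimal scalar C++ sketch of that math in the naive token-by-token order, with no sorting. It is a correctness reference, not a fast path. It assumes the gate-only FFN, row-major dense weights (B0[e] is hidden x intermediate, B1[e] is intermediate x hidden), and a caller-supplied activation. Matrix, gelu (tanh approximation, an illustrative choice) and moe_reference are hypothetical names, not CK Tile APIs.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major dense matrix: element (r, c) lives at data[r * cols + c].
struct Matrix {
    std::size_t rows, cols;
    std::vector<float> data;
    Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0f) {}
    float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// GELU, tanh approximation (illustrative; any float(float) activation works).
inline float gelu(float v) {
    const float k = 0.7978845608f;  // sqrt(2 / pi)
    return 0.5f * v * (1.0f + std::tanh(k * (v + 0.044715f * v * v * v)));
}

// Token-by-token reference:
//   out[t] = sum_j w[t][j] * act(x[t] * B0[e]) * B1[e], with e = ids[t][j]
Matrix moe_reference(const Matrix& x,
                     const std::vector<std::vector<int>>& ids,
                     const std::vector<std::vector<float>>& w,
                     const std::vector<Matrix>& B0,
                     const std::vector<Matrix>& B1,
                     float (*act)(float)) {
    const std::size_t hidden = x.cols;
    Matrix out(x.rows, hidden);
    for (std::size_t t = 0; t < x.rows; ++t) {
        for (std::size_t j = 0; j < ids[t].size(); ++j) {
            const Matrix& b0 = B0[ids[t][j]];
            const Matrix& b1 = B1[ids[t][j]];
            // First GEMM on one row, then the activation.
            std::vector<float> h(b0.cols, 0.0f);
            for (std::size_t n = 0; n < b0.cols; ++n) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < hidden; ++k) acc += x.at(t, k) * b0.at(k, n);
                h[n] = act(acc);
            }
            // Second GEMM, scaled by the top-k weight and summed into out[t].
            for (std::size_t n = 0; n < hidden; ++n) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < b1.rows; ++k) acc += h[k] * b1.at(k, n);
                out.at(t, n) += w[t][j] * acc;
            }
        }
    }
    return out;
}
```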

Technical Details

  • Scatter/Gather Group-GEMM: Uses indirect indexing to map tokens to experts and back.
  • Block Partitioning: Tokens are partitioned into slices per expert, with padding for alignment.
  • Atomic Accumulation: The second GEMM accumulates into the output with atomics. Each token is routed to top-k experts, so several workgroups add into the same output row.
  • Buffer Zeroing: Output buffer is zeroed in the sorting step, eliminating extra kernels.
  • Pre-shuffled Weights: Expert weights are pre-shuffled for coalesced memory access.
  • Micro-kernel Pipeline: Uses block-inline-asm micro-kernels for peak performance, while retaining composability.
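On the GPU these adds are hardware atomics. As a host-side illustration only, the C++ sketch below shows why they are needed. Each thread stands in for one expert's workgroup and adds its weighted partial results into shared output rows. The adds use a compare-and-swap loop on std::atomic<float>, which works from C++11; C++20 also offers fetch_add for floating-point atomics. atomic_add and accumulate_experts are hypothetical names. The zero-initialized output mirrors the buffer zeroing done in the sorting step.

```cpp
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

// Atomically add `value` to `target` with a compare-and-swap loop.
inline void atomic_add(std::atomic<float>& target, float value) {
    float expected = target.load(std::memory_order_relaxed);
    while (!target.compare_exchange_weak(expected, expected + value,
                                         std::memory_order_relaxed)) {
        // On failure, `expected` was reloaded with the current value; retry.
    }
}

// contrib[e][t] is expert e's weighted partial result for token t
// (0 where token t is not routed to e). One thread per expert; rows that
// several experts touch receive concurrent adds.
void accumulate_experts(std::vector<std::atomic<float>>& out,
                        const std::vector<std::vector<float>>& contrib) {
    std::vector<std::thread> workers;
    for (const std::vector<float>& rows : contrib) {
        const std::vector<float>* rows_ptr = &rows;
        workers.emplace_back([&out, rows_ptr] {
            for (std::size_t t = 0; t < rows_ptr->size(); ++t)
                if ((*rows_ptr)[t] != 0.0f) atomic_add(out[t], (*rows_ptr)[t]);
        });
    }
    for (std::thread& worker : workers) worker.join();
}
```

The order of concurrent adds is not fixed, so floating-point results can differ in the last bits between runs. The values in a quick check should be chosen to be exactly representable.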

Build & Run

mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_fused_moe -j
./bin/tile_example_fused_moe -?

Source Structure


Technical Notes

This is a scatter/gather group-GEMM based solution, similar to the vLLM MoE implementation, but with more kernel fusion to boost performance.

The benefits of this fused-moe:

  • 1.5x to 2x performance improvement over the current vLLM solution
  • Zero workspace, reducing the memory footprint
  • Far fewer kernel instances, making it easier to maintain

Implementation and feature support

NOTES:

Currently, in the FP16 gate+up case the accumulated values very easily exceed the FP16 maximum (65504), which results in INF. Use BF16 for the gate+up case; the API does not check for this.
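The 65504 limit follows from the FP16 format (5 exponent bits, 10 mantissa bits): the largest finite value is (2 - 2^-10) x 2^15. A small C++ sketch of the resulting check follows, assuming round-to-nearest-even conversion. overflows_fp16 is a hypothetical helper, not part of the API, which, as noted, performs no such check.

```cpp
#include <cmath>

// Largest finite IEEE-754 binary16 value: (2 - 2^-10) * 2^15 = 65504.
constexpr float kFp16Max = 65504.0f;

// Adjacent FP16 values near the top are 32 apart, so 65520 is halfway between
// 65504 and 65536 (which is already INF). Under round-to-nearest-even, any
// magnitude >= 65520 converts to +/-INF.
constexpr float kFp16InfThreshold = 65520.0f;

inline bool overflows_fp16(float v) { return std::fabs(v) >= kFp16InfThreshold; }
```

BF16 keeps FP32's 8-bit exponent. A magnitude such as 70000 therefore overflows FP16 but stays finite in BF16, which is why BF16 is recommended for gate+up.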

moe-sorting

This is a common pre-processing step before the actual moe-gemm. It changes the MoE loop from token-by-token to expert-by-expert, so that every workgroup works on a single expert (one B matrix). We also extend this op to zero the output buffer, which moe-gemm then uses as the reduction buffer for its atomic accumulation.

moe-gemm

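moe-gemm is a group-GEMM based back-to-back GEMM, where the row ID of each input token comes from another buffer (the sorted token IDs).

The naive view of fused-moe is token-by-token: each token is multiplied by the weights of each of its top-k experts. After moe-sorting, the same algorithm can be viewed expert-by-expert. Each workgroup takes a tile of sorted rows that belong to one expert, gathers those token rows as input, and scatters the results back to the original token positions.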
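The expert-by-expert loop can be sketched in scalar C++ over the moe-sorting outputs: sorted token ids, sorted weights, sorted expert ids, and num_tokens_post_padded, as described in the indexing notes below. Each outer iteration stands in for one workgroup processing a block_m tile for a single expert. Rows are gathered from x by token id, padding slots (here, any id >= tokens) are skipped, and results are scattered back to out. This is an illustration with assumed row-major layouts and hypothetical names, not the CK kernel. The plain += marks where the GPU kernel uses atomic adds.

```cpp
#include <cstddef>
#include <vector>

// Layouts (row-major, assumed for this sketch):
//   x: [tokens][hidden], b0[e]: [hidden][inter], b1[e]: [inter][hidden], out: [tokens][hidden]
void moe_gemm_sorted(const std::vector<float>& x, std::size_t tokens,
                     std::size_t hidden, std::size_t inter,
                     const std::vector<int>& sorted_token_ids,
                     const std::vector<float>& sorted_weights,
                     const std::vector<int>& sorted_expert_ids,
                     std::size_t block_m, std::size_t num_tokens_post_padded,
                     const std::vector<std::vector<float>>& b0,
                     const std::vector<std::vector<float>>& b1,
                     float (*act)(float), std::vector<float>& out) {
    const std::size_t num_tiles = num_tokens_post_padded / block_m;
    for (std::size_t tile = 0; tile < num_tiles; ++tile) {  // one workgroup per tile
        const std::vector<float>& w0 = b0[sorted_expert_ids[tile]];  // one B matrix per tile
        const std::vector<float>& w1 = b1[sorted_expert_ids[tile]];
        for (std::size_t r = tile * block_m; r < (tile + 1) * block_m; ++r) {
            const std::size_t t = static_cast<std::size_t>(sorted_token_ids[r]);
            if (t >= tokens) continue;  // padding slot
            std::vector<float> h(inter, 0.0f);
            for (std::size_t n = 0; n < inter; ++n) {  // GEMM-1 on the gathered row + act
                float acc = 0.0f;
                for (std::size_t k = 0; k < hidden; ++k) acc += x[t * hidden + k] * w0[k * inter + n];
                h[n] = act(acc);
            }
            for (std::size_t n = 0; n < hidden; ++n) {  // GEMM-2, top-k weight, scatter
                float acc = 0.0f;
                for (std::size_t k = 0; k < inter; ++k) acc += h[k] * w1[k * hidden + n];
                out[t * hidden + n] += sorted_weights[r] * acc;  // atomic add on the GPU
            }
        }
    }
}
```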

optimization

Summary of the key design of this fused-moe operator:

  • Fuse 2 group-GEMMs + activation + top-k weight multiply into a single kernel, using atomics for the 2nd GEMM accumulation
  • Fuse buffer zeroing into moe-sorting, so the user no longer needs an extra torch zeroing call (e.g. zero_()) on the output buffer
  • Fused scatter-gather for row indexing (same as vLLM)
  • Pre-shuffle the B matrix (weights) to maximize memory throughput; the input (activation) keeps its original layout [batch, hidden]
  • Extremely optimized pipeline using block-inline-asm (we call it micro-kernel or uk), without breaking the composable design of CK
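CK's actual pre-shuffled weight layout is specific to its micro-kernel and is not documented here. With that caveat, the generic idea can be sketched as follows. A row-major [N][K] weight is reordered into contiguous NB x KB tiles, so the block of B that one workgroup consumes can be read with contiguous loads. preshuffle_tiles, NB, and KB are hypothetical, and this is not CK's layout.

```cpp
#include <cstddef>
#include <vector>

// Reorder a row-major [N][K] matrix into [N/NB][K/KB][NB][KB] tiles.
// Assumes N % NB == 0 and K % KB == 0. After the reorder, the NB*KB
// elements of each tile are contiguous in memory.
std::vector<float> preshuffle_tiles(const std::vector<float>& w,
                                    std::size_t N, std::size_t K,
                                    std::size_t NB, std::size_t KB) {
    std::vector<float> out(w.size());
    std::size_t dst = 0;
    for (std::size_t nt = 0; nt < N / NB; ++nt)          // tile row
        for (std::size_t kt = 0; kt < K / KB; ++kt)      // tile column
            for (std::size_t n = 0; n < NB; ++n)         // row inside tile
                for (std::size_t k = 0; k < KB; ++k)     // column inside tile
                    out[dst++] = w[(nt * NB + n) * K + (kt * KB + k)];
    return out;
}
```

The input activations are not reordered, matching the note above that they keep their original [batch, hidden] layout.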

// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                            tok-0      tok-1      tok-2      tok-3      tok-4
//           topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
//  (only for reference)    exp-0  exp-1     exp-2   exp-3          exp-4  exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr   : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
//                          |-  exp-0  -|-  exp-1  -|-  exp-2  -|-      exp-3          -|-  exp-4 -|-  exp-5  -|
// sorted_weight_ptr      : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr  : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
//   1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
//   2) need sorted_weight_ptr
//   3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
//  1) sorted_token_ids_ptr [max_num_tokens_padded]
//  2) sorted_weight_ptr
//  3) sorted_expert_ids_ptr
//  4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
//   max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
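Below is a sequential, host-side C++ sketch of the indexing scheme above; it is not the GPU moe-sorting kernel. It groups (token, weight) pairs by expert in token order and pads each expert's slice to a multiple of block_m with pad_id. As exp-4 shows in the example, an expert with no tokens still gets one padded tile. pad_id is a parameter (6 reproduces the example), and MoeSortResult and moe_sort are hypothetical names. On the example's topk_ids, expert 5 receives tokens 0, 1, 2 and 4, the only tokens whose topk_ids contain 5.

```cpp
#include <cstddef>
#include <vector>

struct MoeSortResult {
    std::vector<int> sorted_token_ids;   // padded slots hold pad_id
    std::vector<float> sorted_weights;   // padded slots hold 0
    std::vector<int> sorted_expert_ids;  // one entry per block_m tile
    int num_tokens_post_padded = 0;
};

MoeSortResult moe_sort(const std::vector<std::vector<int>>& topk_ids,
                       const std::vector<std::vector<float>>& topk_weights,
                       int num_experts, int block_m, int pad_id) {
    MoeSortResult r;
    for (int e = 0; e < num_experts; ++e) {
        int count = 0;
        // Collect this expert's (token, weight) pairs in token order.
        for (std::size_t t = 0; t < topk_ids.size(); ++t) {
            for (std::size_t j = 0; j < topk_ids[t].size(); ++j) {
                if (topk_ids[t][j] != e) continue;
                r.sorted_token_ids.push_back(static_cast<int>(t));  // real token id, not t*topk+j
                r.sorted_weights.push_back(topk_weights[t][j]);
                ++count;
            }
        }
        // Round up to whole tiles; an empty expert still gets one tile.
        int tiles = (count + block_m - 1) / block_m;
        if (tiles == 0) tiles = 1;
        for (int p = count; p < tiles * block_m; ++p) {
            r.sorted_token_ids.push_back(pad_id);
            r.sorted_weights.push_back(0.0f);
        }
        for (int i = 0; i < tiles; ++i) r.sorted_expert_ids.push_back(e);
        r.num_tokens_post_padded += tiles * block_m;
    }
    return r;
}
```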

example

args:
          -t    number of input tokens. (default:128)
                If "local_t" presents, this value indicates global concurrency of all ranks.
    -local_t    Number of local input tokens for the current rank. (default:-1)
                This value must be within range "[0, t)", or "-1"(no such feature)
                This feature is to simulate the EP case where each rank has different tokens.
                Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
          -e    num of experts (default:32)
          -k    topk (default:5)
          -h    hidden_size of this model (default:8192)
          -i    intermediate_size between 2 gemms of FFN (default:8192)
     -stride    stride per row, if -1 then equal to hidden_size (default:-1)
         -bm    blocking factor for sorted tokens (default:32)
         -tp    tensor parallel size (default:8)
          -v    cpu validation or not (default:1)
      -kname    print kernel name or not (default:1)
     -prec_i    input precision (default:bf16)
     -prec_w    weight precision (default:bf16)
     -prec_o    output precision (default:bf16)
    -prec_st    token scale data type. auto will set to fp32 (default:auto)
    -prec_sw    weight scale data type. auto will set to fp32 (default:auto)
    -prec_sq    (dynamic) smooth quant data type. auto will set to fp32 (default:auto)
    -prec_kw    topk-weight data type. auto will set to fp32 (default:auto)
     -fquant    fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
  -gate_only    w0(gate/up) style, 0:gate+up will double interm size, 1:only gate (default:1)
        -api    benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm (default:0)
        -act    activation after first gemm. 0:gelu, 1:silu (default:0)
    -balance    if set to 1, will try balance the expert in topk-ids(convenient for testing) (default:0)
       -init    init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2: random normalized [0, 1] (slow) (default:1)
       -seed    seed used to do random (default:11939)
     -warmup    cold iter (default:5)
     -repeat    hot iter (default:20)
       -json    0: No Json, 1: Dump Results in Json format (default:0)
   -jsonfile    json file name to dump results (default:fused_moe.json)
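The sorting buffers for a given command line can be sized directly from the bounds in the indexing notes: max_num_tokens_padded = topk * tokens + num_experts * M_a - topk, plus its tile count. Below is a minimal C++ sketch with hypothetical names. It is evaluated at the default arguments -t 128 -e 32 -k 5 -bm 32 and does not model -local_t or -tp.

```cpp
#include <cstdint>

// Upper bounds for the moe-sorting output buffers.
struct MoeSortingSizes {
    std::int64_t max_num_tokens_padded;  // length of sorted_token_ids / sorted_weight
    std::int64_t max_num_tiles;          // length of sorted_expert_ids
};

constexpr MoeSortingSizes moe_sorting_sizes(std::int64_t tokens, std::int64_t experts,
                                            std::int64_t topk, std::int64_t block_m) {
    return {topk * tokens + experts * block_m - topk,
            (topk * tokens + experts * block_m - topk + block_m - 1) / block_m};
}

// Defaults: 5*128 + 32*32 - 5 = 1659 padded rows, ceil(1659 / 32) = 52 tiles.
static_assert(moe_sorting_sizes(128, 32, 5, 32).max_num_tokens_padded == 1659, "padded rows");
static_assert(moe_sorting_sizes(128, 32, 5, 32).max_num_tiles == 52, "tiles");
```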

For details on tile distribution, see include/ck_tile/tile_program/tile_distribution/.

