# Fused-MoE with CK Tile

This example implements a highly optimized fused Mixture-of-Experts (MoE) block using the CK Tile programming model. The design fuses MoE sorting, group-GEMM, activation, and top-k weighting into a single kernel, minimizing memory traffic and maximizing throughput for large language models.


## Algorithm and Math

### MoE Block Structure

Given:

- Input: $X$ of shape $[\text{tokens}, \text{hidden}]$
- TopK indices/weights: $I$, $W$ from gating (shape $[\text{tokens}, \text{topk}]$)
- Expert weights of shape $[\text{experts}, \text{hidden}, \text{hidden}]$

Steps:

1. MoE sorting: Rearrange tokens so that each expert receives its assigned tokens in contiguous blocks (see 13_moe_sorting).
2. Group-GEMM: For each expert $e$, perform a GEMM on its assigned tokens (a NumPy reference for all four steps follows this list):

   $$Y^{(e)} = X^{(e)} W^{(e)}$$

3. Activation + TopK weighting: Apply the activation (e.g., GELU) and multiply by the top-k weights.
4. Scatter/Gather: Write the results back in the original token order.
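The sketch below mirrors these four steps in their naive token-by-token form, purely as a semantic reference; the fused kernel never materializes per-expert copies. The gate/down weight shapes, the GELU approximation, and all sizes are illustrative assumptions, not the kernel's actual configuration.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(0.7978845608 * (x + 0.044715 * x**3)))

def moe_ref(X, W_gate, W_down, topk_ids, topk_w):
    # X:        [tokens, hidden]
    # W_gate:   [experts, hidden, interm]   first GEMM weights
    # W_down:   [experts, interm, hidden]   second GEMM weights
    # topk_ids: [tokens, topk]              expert index per token
    # topk_w:   [tokens, topk]              gating weight per token
    Y = np.zeros_like(X)
    tokens, topk = topk_ids.shape
    for t in range(tokens):
        for k in range(topk):
            e = topk_ids[t, k]
            h = gelu(X[t] @ W_gate[e])              # GEMM 1 + activation
            Y[t] += topk_w[t, k] * (h @ W_down[e])  # GEMM 2, weighted accumulate
    return Y

rng = np.random.default_rng(0)
tokens, hidden, interm, experts, topk = 5, 16, 32, 6, 3
X = rng.standard_normal((tokens, hidden)).astype(np.float32)
W_gate = rng.standard_normal((experts, hidden, interm)).astype(np.float32)
W_down = rng.standard_normal((experts, interm, hidden)).astype(np.float32)
topk_ids = np.stack([rng.choice(experts, topk, replace=False) for _ in range(tokens)])
topk_w = rng.random((tokens, topk)).astype(np.float32)
print(moe_ref(X, W_gate, W_down, topk_ids, topk_w).shape)  # (5, 16)
```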

## Technical Details

- Scatter/gather group-GEMM: Uses indirect indexing to map tokens to experts and back.
- Block partitioning: Tokens are partitioned into per-expert slices, padded for alignment.
- Atomic accumulation: The second GEMM accumulates with atomics, since a token routed to several experts makes output rows overlap (see the sketch after this list).
- Buffer zeroing: The output buffer is zeroed in the sorting step, eliminating an extra kernel.
- Pre-shuffled weights: Expert weights are pre-shuffled for coalesced memory access.
- Micro-kernel pipeline: Uses block inline-asm micro-kernels for peak performance while retaining composability.
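A token routed to topk experts produces topk partial rows that all target the same output row, which is why the second GEMM accumulates into a pre-zeroed buffer. The toy below emulates the unordered atomic adds with `np.add.at`; all names and sizes are illustrative.

```python
import numpy as np

rows = np.array([0, 2, 0, 1, 0])              # destination token row per partial result
partial = np.ones((5, 4), dtype=np.float32)   # per-(token, expert) GEMM-2 outputs
out = np.zeros((3, 4), dtype=np.float32)      # pre-zeroed (done during moe-sorting)
np.add.at(out, rows, partial)                 # unordered accumulation, like atomics
print(out[:, 0])                              # [3. 1. 1.] -- row 0 got 3 contributions
```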

## Build & Run

```bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_fused_moe -j
./bin/tile_example_fused_moe -?
```

## Source Structure


## Technical Notes

This is a scatter/gather group-GEMM based solution, similar to the vLLM MoE, but with more kernel fusion to boost performance.

The benefits of this fused-MoE:

- 1.5~2x performance boost compared with the current vLLM solution
- zero workspace, reducing the memory footprint
- far fewer kernel instances, which makes it easier to maintain

## Implementation and feature support

NOTES:

In the FP16 gate+up case, the accumulator can easily overflow the FP16 maximum (65504) and produce INF. Please use BF16 for the gate+up case; the API does not check for this. The snippet below demonstrates the overflow.
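A quick illustration of the failure mode, assuming NumPy (which has no native BF16 type; BF16 avoids the issue because its exponent range matches FP32, at the cost of mantissa precision):

```python
import numpy as np

print(np.finfo(np.float16).max)              # 65504.0
acc = np.float16(60000) + np.float16(10000)  # a partial sum accumulated in FP16
print(acc)                                   # inf -- overflow, as described above
```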

### moe-sorting

This is a common pre-processing step before the actual moe-gemm. Its purpose is to transform the MoE loop from token-by-token to expert-by-expert, ensuring that every workgroup works on a single expert (B matrix). In addition, this op is extended to zero the output buffer (which is later used as a reduction buffer with atomics).

### moe-gemm

moe-gemm is a group-GEMM based back-to-back GEMM, where the row ID of each input token comes from another buffer. A naive view of fused-MoE is token-by-token (as in the reference sketch above); after moe-sorting, the same algorithm can be viewed expert-by-expert, as sketched below.
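A minimal NumPy sketch of the expert-by-expert view, assuming the sorted buffers produced by moe-sorting (see the indexing walkthrough below). Function and variable names are illustrative, only one GEMM is shown, and padded slots are skipped by masking rather than by predication as in the real kernel.

```python
import numpy as np

def moe_gemm_sorted(X, W, sorted_token_ids, sorted_weights,
                    sorted_expert_ids, block_size, out):
    # one tile of block_size rows per entry of sorted_expert_ids
    for tile, e in enumerate(sorted_expert_ids):
        rows = sorted_token_ids[tile*block_size:(tile+1)*block_size]
        wts  = sorted_weights[tile*block_size:(tile+1)*block_size]
        valid = rows < X.shape[0]                          # mask out padded slots
        y = X[rows[valid]] @ W[e]                          # gather + per-expert GEMM
        np.add.at(out, rows[valid], wts[valid, None] * y)  # weighted scatter-add

tokens, hidden, experts = 5, 8, 6
rng = np.random.default_rng(1)
X = rng.standard_normal((tokens, hidden)).astype(np.float32)
W = rng.standard_normal((experts, hidden, hidden)).astype(np.float32)
out = np.zeros_like(X)
# sorted buffers for a trivial routing: every token to expert 0, weight 1.0
moe_gemm_sorted(X, W, np.array([0, 1, 2, 3, 4, 6, 6, 6]),
                np.ones(8, np.float32), np.array([0, 0]), 4, out)
print(np.allclose(out, X @ W[0]))  # True
```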

### optimization

Summary of the key design points of this fused-MoE operator:

- fuse the 2 group-GEMMs + activation + topk-weight multiply into a single kernel, using atomics for the 2nd GEMM accumulation
- fuse buffer zeroing into moe-sorting, so users no longer need an extra torch.zero() call for the output buffer
- fused scatter/gather for the row index (same as vLLM)
- pre-shuffle the B matrix (weights) to maximize memory throughput, while the input (activation) keeps its original [batch, hidden] layout; a toy illustration follows this list
- an extremely optimized pipeline using block inline-asm (we call it a micro-kernel, or "uk"), without breaking the composable design of CK
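A toy of the pre-shuffling idea, assuming a made-up `kpack` factor; the actual CK Tile weight layout is more elaborate, so this only shows the principle of reordering B once on the host so that the elements a lane consumes back-to-back sit contiguously in memory.

```python
import numpy as np

def preshuffle(W, kpack=4):
    # [K, N] -> [K // kpack, N, kpack]: each group of kpack K-elements per
    # column becomes contiguous in the last axis after the one-time shuffle.
    K, N = W.shape
    return W.reshape(K // kpack, kpack, N).transpose(0, 2, 1).copy()

W = np.arange(8 * 3, dtype=np.float32).reshape(8, 3)
print(preshuffle(W).shape)  # (2, 3, 4)
```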

```cpp
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                            tok-0      tok-1      tok-2      tok-3      tok-4
//           topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
//  (only for reference)    exp-0  exp-1     exp-2   exp-3          exp-4  exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr   : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
//                          |-  exp-0  -|-  exp-1  -|-  exp-2  -|-      exp-3          -|-  exp-4 -|-  exp-5  -|
// sorted_weight_ptr      : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr  : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
//   1) token_id stored in sorted_token_ids_ptr is the actual token_id, not the token_id*top_K expanded id
//   2) need sorted_weight_ptr
//   3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
//  1) sorted_token_ids_ptr [max_num_tokens_padded]
//  2) sorted_weight_ptr
//  3) sorted_expert_ids_ptr
//  4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
//   max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
```
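The bookkeeping above can be reproduced host-side in a few lines. Below is a hedged NumPy sketch, assuming (per the worked example) that an empty expert still emits one fully padded tile and that the pad token id is a sentinel >= the token count (6 in the example); the real kernel computes all of this on the GPU.

```python
import numpy as np

def moe_sort(topk_ids, topk_w, num_experts, block_size, pad_id):
    sorted_ids, sorted_w, expert_ids = [], [], []
    for e in range(num_experts):
        t, k = np.nonzero(topk_ids == e)            # tokens routed to expert e
        n_tiles = max(1, -(-t.size // block_size))  # ceil; empty expert -> 1 pad tile
        pad = n_tiles * block_size - t.size
        sorted_ids += list(t) + [pad_id] * pad
        sorted_w   += list(topk_w[t, k]) + [0.0] * pad
        expert_ids += [e] * n_tiles
    return np.array(sorted_ids), np.array(sorted_w), np.array(expert_ids)

# the worked example above: 5 tokens, 6 experts, topk = 3, M_a = 4
topk_ids = np.array([[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]])
topk_w = np.arange(15, dtype=np.float32).reshape(5, 3)  # stands in for a..o
ids, w, eids = moe_sort(topk_ids, topk_w, num_experts=6, block_size=4, pad_id=6)
print(ids)       # [0 6 6 6 2 3 4 6 1 3 6 6 0 1 2 3 4 6 6 6 6 6 6 6 0 1 2 5]
print(eids)      # [0 1 2 3 3 4 5]
print(len(ids))  # 28 (num_tokens_post_padded); len(eids) == 7 (num_sorted_tiles)
```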

### example

args:
```
          -t    number of input tokens (default: 128).
                If "-local_t" is given, this value indicates the global concurrency across all ranks.
    -local_t    number of local input tokens for the current rank (default: -1).
                This value must be within the range [0, t), or -1 (feature disabled).
                This feature simulates the EP case where each rank has a different token count.
                The value is stored in a GPU buffer, which is friendly for CUDA graphs.
          -e    number of experts (default: 32)
          -k    topk (default: 5)
          -h    hidden_size of the model (default: 8192)
          -i    intermediate_size between the 2 GEMMs of the FFN (default: 8192)
     -stride    stride per row; if -1, equal to hidden_size (default: -1)
         -bm    blocking factor for sorted tokens (default: 32)
         -tp    tensor parallel size (default: 8)
          -v    CPU validation, or not (default: 1)
      -kname    print kernel name, or not (default: 1)
     -prec_i    input precision (default: bf16)
     -prec_w    weight precision (default: bf16)
     -prec_o    output precision (default: bf16)
    -prec_st    token scale data type; auto maps to fp32 (default: auto)
    -prec_sw    weight scale data type; auto maps to fp32 (default: auto)
    -prec_sq    (dynamic) smooth quant data type; auto maps to fp32 (default: auto)
    -prec_kw    topk-weight data type; auto maps to fp32 (default: auto)
     -fquant    fused quant. 0: no, 1: smooth dynamic quant, 2: dynamic quant (default: 0)
  -gate_only    w0 (gate/up) style. 0: gate+up doubles the intermediate size, 1: gate only (default: 1)
        -api    benchmark API set. 0: fused-moe (moe-gemm + moe-sorting), 1: moe-gemm (default: 0)
        -act    activation after the first GEMM. 0: gelu, 1: silu (default: 0)
    -balance    if set to 1, tries to balance the experts in the topk-ids (convenient for testing) (default: 0)
       -init    init method. 0: random stepped float (fast), 1: random uniform [-0.5, 0.5], 2: random normalized [0, 1] (slow) (default: 1)
       -seed    seed used for randomization (default: 11939)
     -warmup    cold iterations (default: 5)
     -repeat    hot iterations (default: 20)
       -json    0: no JSON, 1: dump results in JSON format (default: 0)
   -jsonfile    JSON file name for dumped results (default: fused_moe.json)
```
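For instance, a hedged example invocation (the `-name=value` flag syntax is assumed to follow the other CK Tile examples; run `./bin/tile_example_fused_moe -?` to confirm the exact form):

```bash
# 128 tokens, 32 experts, topk 5, bf16 end to end, SiLU after the first GEMM
./bin/tile_example_fused_moe -t=128 -e=32 -k=5 -h=8192 -i=8192 -act=1 -v=1
```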

For details on tile distribution, see include/ck_tile/tile_program/tile_distribution/.


Back to CK Tile Examples