# fused-moe

Implementing the fused-moe block operator using ck-tile. This is a scatter/gather and group-gemm based solution, similar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance.

![](misc/moe-0.png)

The benefits of this fused-moe:
* 1.5~2x perf boost compared with the current vllm solution
* zero workspace, reducing the memory footprint
* far fewer kernel instances, easy to maintain

# Implementation and feature support

## NOTES
Currently, the gate+up case in fp16 can very easily overflow the fp16 accumulator maximum (65504) and produce INF. Please use bf16 for the gate+up case; the API side has no check for this.

## moe-sorting
This is a common pre-processing step before the actual moe-gemm. Its purpose is to transform the moe loop from token-by-token to expert-by-expert, making sure every workgroup works on a single expert (B matrix). Besides, we extend this op to also zero the output buffer (which is later used as a reduction buffer with atomics).
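For illustration, below is a minimal host-side sketch of this transform (not the actual GPU kernel). Buffer names follow the indexing section further down; the pad sentinel here is simply any out-of-range token id (the worked example below shows `6` for 5 tokens; we use `num_tokens` for simplicity), and the detail that an expert with no tokens still owns one all-padding tile is inferred from that example.

```
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side reference sketch of moe-sorting: group token ids by expert,
// pad each expert's slice to a whole number of M_a-wide tiles, and record
// the expert id of each tile. Illustration only, not the GPU kernel.
struct SortedMoe
{
    std::vector<int32_t> sorted_token_ids;  // gathered row ids, padded per tile
    std::vector<float> sorted_weight;       // topk weights aligned with the ids
    std::vector<int32_t> sorted_expert_ids; // expert id of each M_a-wide tile
    int32_t num_tokens_post_padded;
};

SortedMoe moe_sorting_ref(const std::vector<std::vector<int32_t>>& topk_ids, // [tokens][topk]
                          const std::vector<std::vector<float>>& topk_weight,
                          int32_t num_experts,
                          int32_t unit_size) // M_a
{
    const int32_t num_tokens = static_cast<int32_t>(topk_ids.size());
    SortedMoe s{};
    for(int32_t e = 0; e < num_experts; ++e)
    {
        const std::size_t slice_begin = s.sorted_token_ids.size();
        for(int32_t t = 0; t < num_tokens; ++t)
            for(std::size_t k = 0; k < topk_ids[t].size(); ++k)
                if(topk_ids[t][k] == e)
                {
                    s.sorted_token_ids.push_back(t); // actual token id, not t*topk+k
                    s.sorted_weight.push_back(topk_weight[t][k]);
                }
        // pad the slice to a multiple of unit_size; an empty expert still owns
        // one tile, matching the worked example in the indexing section
        std::size_t count = s.sorted_token_ids.size() - slice_begin;
        std::size_t tiles = (count + unit_size - 1) / unit_size;
        if(tiles == 0)
            tiles = 1;
        while(s.sorted_token_ids.size() - slice_begin < tiles * static_cast<std::size_t>(unit_size))
        {
            s.sorted_token_ids.push_back(num_tokens); // out-of-range pad sentinel
            s.sorted_weight.push_back(0.f);
        }
        for(std::size_t i = 0; i < tiles; ++i)
            s.sorted_expert_ids.push_back(e);
    }
    s.num_tokens_post_padded = static_cast<int32_t>(s.sorted_token_ids.size());
    return s;
}
```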
## moe-gemm
`moe-gemm` is a group-gemm based back-to-back gemm, where the row id of each input token comes from another buffer. The naive understanding of fused-moe is the token-by-token view in the picture below:

![](misc/moe-1.png)

After `moe-sorting`, we can view this algorithm expert-by-expert, as below:

![](misc/moe-2.png)

## optimization
Summary of the key design points of this fused-moe operator:
* fuse the 2 group-gemms + activation + `topk-weight` multiply into a single kernel, using atomics for the 2nd gemm's accumulation
* fuse the buffer zeroing into `moe-sorting`, so the user no longer needs to call an extra torch.zeros() for the output buffer
* fused scatter-gather for the row indices (same as vllm)
* pre-shuffle the B matrix (weight) to maximize memory throughput; the input (activation) keeps its original `[batch, hidden]` layout
* extremely optimized pipeline using block inline asm (we call it `micro-kernel` or `uk`), while not breaking the *composable* design of ck

## indexing
```
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice maps to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk = 3, M_a = 4, input_tokens = 5
//
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                               tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is            : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float numbers)
//
// token_id_per_expert is  : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference)       exp-0   exp-1    exp-2       exp-3      exp-4    exp-5
// weight_id_per_expert is : [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than the actual size, since the actual token count is only known on the GPU
//
// sorted_token_ids_ptr   : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
//                          |-  exp-0  -|-  exp-1  -|-  exp-2  -|-        exp-3        -|-  exp-4  -|-  exp-5  -|
// sorted_weight_ptr      : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr  : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr       : [7]
//
// * differences from vLLM:
// 1) the token_id stored in sorted_token_ids_ptr is the actual token_id, not the token_id*topk expanded id
// 2) we need sorted_weight_ptr
// 3) we use num_sorted_tiles_ptr, which is already divided by M_a
//
// * the buffers below are used for indexing:
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4) num_tokens_post_padded_ptr / num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
```
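To make the indexing concrete, here is a host-side sketch of how each tile of `moe-gemm` consumes these buffers (serial illustration only; on the GPU one workgroup handles one tile and the scatter uses atomic adds). `ffn_row` is a hypothetical stand-in for the fused act(x[row] @ W0[expert]) @ W1[expert] on a single token row.

```
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Reference walk over the sorted buffers produced by moe-sorting.
void moe_gemm_ref(const std::vector<int32_t>& sorted_token_ids,
                  const std::vector<float>& sorted_weight,
                  const std::vector<int32_t>& sorted_expert_ids,
                  int32_t num_sorted_tiles,
                  int32_t unit_size, // M_a
                  int32_t num_tokens,
                  const std::function<std::vector<float>(int32_t, int32_t)>& ffn_row,
                  std::vector<std::vector<float>>& out) // already zeroed by moe-sorting
{
    for(int32_t tile = 0; tile < num_sorted_tiles; ++tile) // one workgroup per tile
    {
        const int32_t expert = sorted_expert_ids[tile]; // a single B matrix per tile
        for(int32_t i = 0; i < unit_size; ++i)
        {
            const int32_t row = sorted_token_ids[tile * unit_size + i];
            if(row >= num_tokens)
                continue; // padded slot, nothing to do
            const float w = sorted_weight[tile * unit_size + i];    // topk weight
            const std::vector<float> acc = ffn_row(expert, row);    // gather + 2 gemms + act
            for(std::size_t h = 0; h < acc.size(); ++h)
                out[row][h] += w * acc[h]; // scatter; atomic add on the GPU
        }
    }
}
```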
## example
```
args:
-t         number of input tokens (default:128)
           if "-local_t" is present, this value indicates the global concurrency of all ranks
-local_t   number of local input tokens for the current rank (default:-1)
           this value must be within the range "[0, t)", or "-1" (feature disabled)
           this feature simulates the EP case, where each rank has a different number of tokens;
           besides, this value is stored in a GPU buffer, which is friendly for CUDA graph
-e         number of experts (default:32)
-k         topk (default:5)
-h         hidden_size of this model (default:8192)
-i         intermediate_size between the 2 gemms of the FFN (default:8192)
-stride    stride per row, if -1 then equal to hidden_size (default:-1)
-bm        blocking factor for sorted tokens (default:32)
-tp        tensor parallel size (default:8)
-v         cpu validation or not (default:1)
-kname     print kernel name or not (default:1)
-prec_i    input precision (default:bf16)
-prec_w    weight precision (default:bf16)
-prec_o    output precision (default:bf16)
-prec_st   token scale data type, auto will be set to fp32 (default:auto)
-prec_sw   weight scale data type, auto will be set to fp32 (default:auto)
-prec_sq   (dynamic) smooth quant data type, auto will be set to fp32 (default:auto)
-prec_kw   topk-weight data type, auto will be set to fp32 (default:auto)
-fquant    fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
-gate_only w0(gate/up) style, 0:gate+up will double the interm size, 1:only gate (default:1)
-api       benchmark api set, 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm (default:0)
-act       activation after the first gemm, 0:gelu, 1:silu (default:0)
-balance   if set to 1, will try to balance the experts in topk-ids (convenient for testing) (default:0)
-init      init method, 0:random stepped float (fast), 1:random uniform [-0.5, 0.5], 2:random normalized [0, 1] (slow) (default:1)
-seed      seed used for randomization (default:11939)
-warmup    cold iterations (default:5)
-repeat    hot iterations (default:20)
-json      0:no json, 1:dump results in json format (default:0)
-jsonfile  json file name for dumping results (default:fused_moe.json)
```
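As a usage sketch, a typical benchmark invocation could look like the following; the binary name `tile_example_fused_moe` is an assumption based on ck-tile's example naming and may differ in your build:

```
# binary name is an assumption; adjust to your build output
./tile_example_fused_moe -t 256 -e 32 -k 5 -h 8192 -i 8192 -prec_i bf16 -prec_w bf16 -gate_only 1 -v 1
```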