mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
This commit is contained in:
22
example/ck_tile/15_fused_moe/CMakeLists.txt
Normal file
22
example/ck_tile/15_fused_moe/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
set(TILE_EXAMPLE_FUSED_MOE "tile_example_fused_moe")
|
||||
message(DEBUG "adding ${TILE_EXAMPLE_FUSED_MOE}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${TILE_EXAMPLE_FUSED_MOE} main.cpp)
|
||||
target_include_directories(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
set(TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a
|
||||
list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta
|
||||
# list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
|
||||
# list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
target_compile_options(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS})
|
||||
endif()
|
||||
174
example/ck_tile/15_fused_moe/README.md
Normal file
174
example/ck_tile/15_fused_moe/README.md
Normal file
@@ -0,0 +1,174 @@
|
||||
# Fused-MoE with CK Tile
|
||||
|
||||
This example implements a highly optimized fused Mixture-of-Experts (MoE) block using the CK Tile programming model. The design fuses MoE sorting, group-GEMM, activation, and top-k weighting into a single kernel, minimizing memory traffic and maximizing throughput for large language models.
|
||||
|
||||
---
|
||||
|
||||
## Algorithm and Math
|
||||
|
||||
### MoE Block Structure
|
||||
|
||||
Given:
|
||||
- **Input**: $X$ of shape $[\text{tokens}, \text{hidden}]$
|
||||
- **TopK indices/weights**: $I, W$ from gating (shape $[\text{tokens}, \text{topk}]$)
|
||||
- **Expert weights**: $[\text{experts}, \text{hidden}, \text{hidden}]$
|
||||
|
||||
**Steps:**
|
||||
1. **MoE Sorting**: Rearrange tokens so each expert receives its assigned tokens in contiguous blocks (see [13_moe_sorting](../13_moe_sorting/README.md)).
|
||||
2. **Group-GEMM**: For each expert, perform GEMM on its assigned tokens:
|
||||
$$
|
||||
Y^{(e)} = X^{(e)} W^{(e)}
|
||||
$$
|
||||
3. **Activation + TopK Weighting**: Apply activation (e.g., GELU) and multiply by top-k weights.
|
||||
4. **Scatter/Gather**: Write results back to the original token order.
|
||||
|
||||
### Technical Details
|
||||
|
||||
- **Scatter/Gather Group-GEMM**: Uses indirect indexing to map tokens to experts and back.
|
||||
- **Block Partitioning**: Tokens are partitioned into slices per expert, with padding for alignment.
|
||||
- **Atomic Accumulation**: Second GEMM uses atomics for accumulation to support overlapping tokens.
|
||||
- **Buffer Zeroing**: Output buffer is zeroed in the sorting step, eliminating extra kernels.
|
||||
- **Pre-shuffled Weights**: Expert weights are pre-shuffled for coalesced memory access.
|
||||
- **Micro-kernel Pipeline**: Uses block-inline-asm micro-kernels for peak performance, while retaining composability.
|
||||
|
||||
|
||||
## Build & Run
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_fused_moe -j
|
||||
./bin/tile_example_fused_moe -?
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Source Structure
|
||||
|
||||
- **Kernel**: [`fused_moe.hpp`](fused_moe.hpp), [`fused_moegemm.hpp`](fused_moegemm.hpp), [`fused_moesorting.hpp`](fused_moesorting.hpp)
|
||||
- **Executable**: [`main.cpp`](main.cpp)
|
||||
- **Build**: `CMakeLists.txt`, `instances/`, `misc/`
|
||||
|
||||
---
|
||||
|
||||
## Technical Notes
|
||||
|
||||
This is a scatter/gather-group-gemm based solution, similiar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance
|
||||

|
||||
|
||||
The benifit of this fused-moe:
|
||||
* 1.5~2x perf boost compared with current vllm solution
|
||||
* zero workspace to reduce memory footprint
|
||||
* much less kernel instance, easy to maintain
|
||||
|
||||
# Implementation and feature support
|
||||
## NOTES:
|
||||
currently gate+up in fp16 case will very easily cause accumulator overflow the fp16 max(65504), hence result in INF. Please use BF16 for gate+up case, API side will have no check for this.
|
||||
|
||||
## moe-sorting
|
||||
this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop over from token-by-token to expert-by-expert, make sure very workgroup is working for a single expert (B matrix). Besides, we extend this op to do the zeroing of the output buffer(to be used for reduce buffer with atomic)
|
||||
|
||||
## moe-gemm
|
||||
`moe-gemm` is a group-gemm based back-to-back gemm, where the row-id of input token comes from another buffer. Naive understanding of fused-moe is from token-by-token view as below picture:
|
||||

|
||||
After `moe-sorting`, we can view this algorithm as expert-by-expert, as below:
|
||||

|
||||
|
||||
## optimization
|
||||
summary of the key design of this fused-moe operator:
|
||||
* fuse 2 group-gemm + activation + `topk-weight` multiply into single kernel, using atomic for 2nd gemm accumualation
|
||||
* fuse buffer-zeroing in `moe-sorgin`, user no longer need call extra torch.zero() for the out buffer
|
||||
* fused scatter-gather for row index(same as vllm)
|
||||
* pre-shuffle B matric(weight) to maximize memory throughput. input(activation) keep original layout `[batch, hidden]`.
|
||||
* extrem optimized pipeline using block-inline-asm(we call it `micro-kernel` or `uk`), while not breaking the *composable* design of ck
|
||||
|
||||
##
|
||||
```
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
// each slice map to one expert, and one expert can have multiple slices
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
//
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
//
|
||||
// * different from vLLM
|
||||
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
|
||||
// 2)need sorted_weight_ptr
|
||||
// 3) use num_sorted_tiles_ptr, already divided by M_a
|
||||
//
|
||||
// * below used for indexing
|
||||
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
|
||||
// 2) sorted_weight_ptr
|
||||
// 3) sorted_expert_ids_ptr
|
||||
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
|
||||
```
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-t number of input tokens. (default:128)
|
||||
If "local_t" presents, this value indicates global concurrency of all ranks.
|
||||
-local_t Number of local input tokens for curent rank. (default:-1)
|
||||
This value must be within range "[0, t)", or "-1"(no such feature)
|
||||
This feature is to simulate EP case where where each rank has different tokens.
|
||||
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
|
||||
-e num of experts (default:32)
|
||||
-k topk (default:5)
|
||||
-h hidden_size of this model (default:8192)
|
||||
-i intermediate_size between 2 gemms of FFN (default:8192)
|
||||
-stride stride per row, if -1 then equal to hidden_size (default:-1)
|
||||
-bm blocking factor for sorted tokens (default:32)
|
||||
-tp tensor parallel size (default:8)
|
||||
-v cpu validation or not (default:1)
|
||||
-kname print kernel name or not (default:1)
|
||||
-prec_i input precision (default:bf16)
|
||||
-prec_w weight precision (default:bf16)
|
||||
-prec_o output precision (default:bf16)
|
||||
-prec_st token scale data type. auto will set to fp32 (default:auto)
|
||||
-prec_sw weight scale data type. auto will set to fp32 (default:auto)
|
||||
-prec_sq (dynamic) smooth quant data type. auto will set to fp32 (default:auto)
|
||||
-prec_kw topk-weight data type. auto will set to fp32 (default:auto)
|
||||
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
|
||||
-gate_only w0(gate/up) style, 0:gate+up will double interm size, 1:only gate (default:1)
|
||||
-api benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm (default:0)
|
||||
-act activation after first gemm. 0:gelu, 1:silu (default:0)
|
||||
-balance if set to 1, will try balance the expert in topk-ids(convenient for testing) (default:0)
|
||||
-init init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand normalized[0, 1]normalized(slow) (default:1)
|
||||
-seed seed used to do random (default:11939)
|
||||
-warmup cold iter (default:5)
|
||||
-repeat hot iter (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:fused_moe.json)
|
||||
```
|
||||
## Related CK Tile Examples
|
||||
|
||||
- [13_moe_sorting](../13_moe_sorting/README.md): MoE sorting for expert dispatch
|
||||
- [09_topk_softmax](../09_topk_softmax/README.md): TopK-Softmax for MoE gating
|
||||
- [03_gemm](../03_gemm/README.md): GEMM with tiles
|
||||
|
||||
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
|
||||
|
||||
---
|
||||
[Back to CK Tile Examples](../README.md)
|
||||
62
example/ck_tile/15_fused_moe/fused_moe.hpp
Normal file
62
example/ck_tile/15_fused_moe/fused_moe.hpp
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
#include "fused_moegemm.hpp"
|
||||
|
||||
struct fused_moe_args
|
||||
{
|
||||
const void* a_ptr; // [m, k], input token
|
||||
const void* a_scale_ptr; // [m, 1], token scale
|
||||
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
|
||||
const void* d_scale_ptr; // [e, 1, k], down scale
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
|
||||
const void* local_tokens; // [1] if not nullptr, tokens read from here
|
||||
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* ws_ptr; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
// must be cleard before use
|
||||
|
||||
const void* topk_ids_ptr; // [tokens, topk]
|
||||
const void* topk_weight_ptr; // [tokens, topk]
|
||||
void* sorted_token_ids_ptr; // [max_num_tokens_padded]
|
||||
void* sorted_weight_ptr; // [max_num_tokens_padded]
|
||||
void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
|
||||
void* num_sorted_tiles_ptr; // [1]
|
||||
|
||||
ck_tile::index_t block_m; // block_m, used to devide the input
|
||||
ck_tile::index_t hidden_size; // k
|
||||
ck_tile::index_t intermediate_size; // n / TP, for Gate. and Up, Down is also this value
|
||||
ck_tile::index_t num_tokens; // input number of tokens for current iteration
|
||||
ck_tile::index_t num_experts; // number of groups
|
||||
ck_tile::index_t topk; // need this?
|
||||
|
||||
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fused_moe_traits
|
||||
{
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_w; // weight precision
|
||||
std::string prec_o; // output precision
|
||||
std::string prec_st; // token scale data type
|
||||
std::string prec_sw; // weight scale data type
|
||||
std::string prec_sq; // smooth quant scale
|
||||
std::string prec_kw; // topk-weight data type
|
||||
int block_m;
|
||||
int activation; // 0:gelu, 1:silu
|
||||
int gate_only; // 0:g1u0, 1:g1u1
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
};
|
||||
|
||||
// if return zero, no ws needed
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk);
|
||||
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
|
||||
85
example/ck_tile/15_fused_moe/fused_moegemm.hpp
Normal file
85
example/ck_tile/15_fused_moe/fused_moegemm.hpp
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
#include <string>
|
||||
|
||||
// this is only a convenient structure for creating an example
|
||||
// this is not part of the host API
|
||||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig;
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using GDataType = ck_tile::bf16_t;
|
||||
using DDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::fp16_t;
|
||||
using GDataType = ck_tile::fp16_t;
|
||||
using DDataType = ck_tile::fp16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::fp16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using GDataType = ck_tile::int8_t;
|
||||
using DDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fused_moegemm_traits
|
||||
{
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_w; // weight precision
|
||||
std::string prec_o; // output precision
|
||||
std::string prec_st; // token scale data type
|
||||
std::string prec_sw; // weight scale data type
|
||||
std::string prec_sq; // smooth quant scale
|
||||
std::string prec_kw; // topk-weight data type
|
||||
int block_m;
|
||||
int activation; // 0:gelu, 1:silu
|
||||
int gate_only; // 0:g1u0, 1:g1u1
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
|
||||
22
example/ck_tile/15_fused_moe/fused_moesorting.hpp
Normal file
22
example/ck_tile/15_fused_moe/fused_moesorting.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
|
||||
struct fused_moesorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
};
|
||||
|
||||
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk);
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s);
|
||||
97
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
Normal file
97
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "fused_moe.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
|
||||
|
||||
auto o_data_bytes = [&]() {
|
||||
if(t.prec_o == "fp32")
|
||||
return 4;
|
||||
else if(t.prec_o == "fp16" || t.prec_o == "bf16")
|
||||
return 2;
|
||||
else if(t.prec_o == "int8" || t.prec_o == "fp8")
|
||||
return 1;
|
||||
return 1;
|
||||
}();
|
||||
|
||||
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
|
||||
auto a0 = fused_moesorting_args{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
a.stride_token,
|
||||
o_data_bytes,
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
|
||||
o_data_bytes // index_t moe_buf_bytes;
|
||||
#endif
|
||||
};
|
||||
|
||||
auto t1 = fused_moegemm_traits{t.prec_i,
|
||||
t.prec_w,
|
||||
t.prec_o,
|
||||
t.prec_st,
|
||||
t.prec_sw,
|
||||
t.prec_sq,
|
||||
t.prec_kw,
|
||||
t.block_m,
|
||||
t.activation,
|
||||
t.gate_only,
|
||||
t.fused_quant};
|
||||
auto a1 = fused_moegemm_args{
|
||||
a.a_ptr, // const void* a_ptr;
|
||||
a.a_scale_ptr, // const void* a_scale_ptr;
|
||||
a.g_ptr, // const void* g_ptr;
|
||||
a.d_ptr, // const void* d_ptr;
|
||||
a.g_scale_ptr, // const void* g_scale_ptr;
|
||||
a.d_scale_ptr, // const void* d_scale_ptr;
|
||||
a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr;
|
||||
a.o_ptr, // void* o_ptr;
|
||||
a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr;
|
||||
a.sorted_weight_ptr, // const void* sorted_weight_ptr;
|
||||
a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
|
||||
a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr;
|
||||
a.hidden_size, // index_t hidden_size;
|
||||
a.intermediate_size, // index_t intermediate_size;
|
||||
a.num_tokens, // index_t num_tokens;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
a.stride_token // index_t stride_token;
|
||||
};
|
||||
|
||||
float r0 = -1;
|
||||
float r1 = -1;
|
||||
|
||||
float r = ck_tile::launch_kernel(
|
||||
s,
|
||||
[=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); },
|
||||
[=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); });
|
||||
|
||||
// keep unsupported case return negative
|
||||
if(r0 < 0 || r1 < 0)
|
||||
return -1;
|
||||
|
||||
return r;
|
||||
}
|
||||
85
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
Normal file
85
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename Traits_>
|
||||
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template <ck_tile::index_t... Is>
|
||||
using S = ck_tile::sequence<Is...>;
|
||||
|
||||
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
// clang-format off
|
||||
float r = -1;
|
||||
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
// clang-format on
|
||||
return r;
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
#include <iostream>
|
||||
|
||||
template <ck_tile::index_t... Is>
|
||||
using S = ck_tile::sequence<Is...>;
|
||||
|
||||
// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j
|
||||
template <typename Ts_>
|
||||
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
{
|
||||
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
|
||||
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
|
||||
constexpr auto get_activation_ = []() {
|
||||
if constexpr(Ts_::Activation == 0)
|
||||
{
|
||||
return ck_tile::element_wise::FastGeluAsm{};
|
||||
}
|
||||
else
|
||||
return ck_tile::element_wise::Silu{};
|
||||
};
|
||||
using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;
|
||||
|
||||
using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
typename Ts_::DDataType,
|
||||
typename Ts_::AccDataType,
|
||||
typename Ts_::ODataType,
|
||||
typename Ts_::AScaleDataType,
|
||||
typename Ts_::GScaleDataType,
|
||||
typename Ts_::DScaleDataType,
|
||||
typename Ts_::YSmoothScaleDataType,
|
||||
typename Ts_::TopkWeightDataType,
|
||||
typename Ts_::IndexDataType,
|
||||
f_act_, // TODO: hardcoded
|
||||
f_shape,
|
||||
f_traits>;
|
||||
|
||||
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
|
||||
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
|
||||
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
|
||||
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
|
||||
|
||||
const dim3 grids = f_kernel::GridSize(a);
|
||||
const dim3 blocks = f_kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
static int printed = 0;
|
||||
|
||||
auto kargs = f_kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0 && printed == 0)
|
||||
{
|
||||
std::cout << ", " << f_kernel::GetName() << std::flush;
|
||||
printed = 1;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename I,
|
||||
typename W,
|
||||
typename O,
|
||||
typename ST,
|
||||
typename SW,
|
||||
typename SQ,
|
||||
typename KW,
|
||||
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
|
||||
typename WarpPerBlock_,
|
||||
typename WarpTile_, // seq<*,*,*>, used to select mfma
|
||||
ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu
|
||||
ck_tile::index_t GateOnly_ = 0,
|
||||
ck_tile::index_t FusedQuant_ = 0>
|
||||
struct fmoe_ // traits, ugly name, only used for internal
|
||||
{
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
|
||||
using ADataType = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
|
||||
using GDataType = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
|
||||
using DDataType = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
|
||||
using AccDataType = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
|
||||
|
||||
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
|
||||
static constexpr ck_tile::index_t BI_ =
|
||||
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
|
||||
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
|
||||
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
|
||||
|
||||
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
|
||||
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
|
||||
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
|
||||
};
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "fused_moegemm_api_internal.hpp"
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "fused_moegemm_api_internal.hpp"
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
549
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
Normal file
549
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
Normal file
@@ -0,0 +1,549 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
|
||||
#ifndef MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0
|
||||
#endif
|
||||
|
||||
#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
#define MOE_SORTING_SUPPORT_LARGE_TOPK 0
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
|
||||
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr ck_tile::index_t expert_tile = expert_tile_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#else
|
||||
|
||||
#define MOE_SORTING_DISPATCH_( \
|
||||
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
|
||||
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
if(is_local_expert_masking) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
if(a.num_experts <= 8) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
|
||||
} \
|
||||
else if(a.num_experts <= 16) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
|
||||
} \
|
||||
else if(a.num_experts <= 32) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
|
||||
} \
|
||||
else if(a.num_experts <= 64) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
|
||||
}
|
||||
#endif
|
||||
|
||||
float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
fused_moesorting_args a,
|
||||
ck_tile::stream_config s);
|
||||
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
return -1;
|
||||
}
|
||||
if(a.moe_buf_bytes % 16)
|
||||
{
|
||||
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
|
||||
return -1;
|
||||
}
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
|
||||
switch(smem_io_unroll_num)
|
||||
{
|
||||
case(1): {
|
||||
MOE_SORTING_DISPATCH(1);
|
||||
}
|
||||
case(2): {
|
||||
MOE_SORTING_DISPATCH(2);
|
||||
}
|
||||
case(3): {
|
||||
MOE_SORTING_DISPATCH(3);
|
||||
}
|
||||
case(5): {
|
||||
MOE_SORTING_DISPATCH(5);
|
||||
}
|
||||
case(6): {
|
||||
MOE_SORTING_DISPATCH(6);
|
||||
}
|
||||
case(8): {
|
||||
MOE_SORTING_DISPATCH(8);
|
||||
}
|
||||
case(10): {
|
||||
MOE_SORTING_DISPATCH(10);
|
||||
}
|
||||
default: {
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
#else
|
||||
if(fused_moe_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0)
|
||||
{
|
||||
return fused_moesorting_mp(t, a, s);
|
||||
}
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts);
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::kBlockSize>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_size = kernel::GetSmemSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
fused_moesorting_args a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
if(a.tokens < 2048)
|
||||
{
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, true));
|
||||
return ave_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_3(ms_index_t, 1, false));
|
||||
return ave_time;
|
||||
}
|
||||
#else
|
||||
printf("do not support large expert %d\n", a.num_experts);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t mesh_byte_size =
|
||||
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
|
||||
if(mesh_byte_size == 1)
|
||||
{
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
|
||||
}
|
||||
}
|
||||
else if(mesh_byte_size == 2)
|
||||
{
|
||||
#if MOE_SORTING_SUPPORT_LARGE_TOPK
|
||||
if(a.tokens * a.topk % 4 == 0)
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
|
||||
}
|
||||
#else
|
||||
printf("do not support large topk %d\n", a.topk);
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
713
example/ck_tile/15_fused_moe/main.cpp
Normal file
713
example/ck_tile/15_fused_moe/main.cpp
Normal file
@@ -0,0 +1,713 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "fused_moe.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
// TODO: padding?
|
||||
template <typename T>
|
||||
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
|
||||
{
|
||||
assert(t.get_lengths().size() == 3);
|
||||
int b_ = t.get_lengths()[0];
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[2];
|
||||
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename IndexType>
|
||||
void topid_unique_gen(
|
||||
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
|
||||
{
|
||||
size_t total_size = topk * tokens;
|
||||
std::srand(seed);
|
||||
std::set<IndexType> unique_set;
|
||||
IndexType current_v;
|
||||
for(size_t i = 0; i < total_size; i++)
|
||||
{
|
||||
if(i % topk == 0)
|
||||
{
|
||||
unique_set.clear();
|
||||
}
|
||||
current_v = std::rand() % num_expert;
|
||||
while(unique_set.find(current_v) != unique_set.end())
|
||||
{
|
||||
current_v = std::rand() % num_expert;
|
||||
}
|
||||
unique_set.insert(current_v);
|
||||
host_tensor[i] = current_v;
|
||||
}
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser
|
||||
.insert("t",
|
||||
"128",
|
||||
"number of input tokens.\n"
|
||||
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
|
||||
.insert(
|
||||
"local_t",
|
||||
"-1",
|
||||
"Number of local input tokens for curent rank.\n"
|
||||
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
|
||||
"This feature is to simulate EP case where where each rank has different tokens.\n"
|
||||
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
|
||||
.insert("e", "32", "num of experts")
|
||||
.insert("k", "5", "topk")
|
||||
.insert("h", "8192", "hidden_size of this model")
|
||||
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
|
||||
.insert("bm", "32", "blocking factor for sorted tokens")
|
||||
.insert("tp", "8", "tensor parallel size")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec_i", "bf16", "input precision")
|
||||
.insert("prec_w", "bf16", "weight precision")
|
||||
.insert("prec_o", "bf16", "output precision")
|
||||
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
|
||||
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
|
||||
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
|
||||
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert(
|
||||
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
||||
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
|
||||
.insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
|
||||
.insert("balance",
|
||||
"0",
|
||||
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
||||
.insert("init",
|
||||
"1",
|
||||
"init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
|
||||
"normalized[0, 1]"
|
||||
"normalized(slow)")
|
||||
.insert("seed", "11939", "seed used to do random")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "fused_moe.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type,
|
||||
// SQ:smooth-quant-type, KW:topk-weight-type
|
||||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t local_tokens = arg_parser.get_int("local_t");
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
ck_tile::index_t block_m = arg_parser.get_int("bm");
|
||||
ck_tile::index_t activation = arg_parser.get_int("act");
|
||||
if(stride < 0)
|
||||
stride = hidden_size;
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_kw = arg_parser.get_str("prec_kw");
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int gate_only = arg_parser.get_int("gate_only");
|
||||
int api = arg_parser.get_int("api");
|
||||
int balance = arg_parser.get_int("balance");
|
||||
int tp = arg_parser.get_int("tp");
|
||||
int init = arg_parser.get_int("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
bool local_expert_masking = false; // TODO...
|
||||
|
||||
// w0 (Gate+Up or Gate only, N size)
|
||||
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
|
||||
// w1 (Down, N size)
|
||||
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
|
||||
|
||||
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
|
||||
|
||||
if(local_tokens > tokens)
|
||||
{
|
||||
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_w)
|
||||
base_str += "x" + prec_w;
|
||||
if(prec_i != prec_o)
|
||||
base_str += "=" + prec_o;
|
||||
if(fused_quant != 0)
|
||||
{
|
||||
base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
auto api_str = [&]() {
|
||||
if(api == 0)
|
||||
return std::string("fmoe");
|
||||
else if(api == 1)
|
||||
return std::string("moeg");
|
||||
else if(api == 2)
|
||||
return std::string("moes");
|
||||
return std::string("");
|
||||
}();
|
||||
|
||||
auto stride_str = [&]() {
|
||||
if(stride == hidden_size)
|
||||
return std::string("");
|
||||
else
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]" << " t:" << tokens;
|
||||
|
||||
if(is_local_token)
|
||||
{
|
||||
std::cout << "(" << local_tokens << ")";
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< ", e:" << experts << ", k:" << topk << stride_str << ", hidden:" << hidden_size
|
||||
<< ", interm:" << intermediate_size << ", tp:" << tp << ", act:"
|
||||
<< activation
|
||||
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
|
||||
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using GDataType = typename TypeConfig::GDataType;
|
||||
using DDataType = typename TypeConfig::DDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
using AScaleDataType = typename TypeConfig::AScaleDataType;
|
||||
using GScaleDataType = typename TypeConfig::GScaleDataType;
|
||||
using DScaleDataType = typename TypeConfig::DScaleDataType;
|
||||
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
|
||||
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType;
|
||||
using IndexDataType = typename TypeConfig::IndexDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
|
||||
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
|
||||
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
|
||||
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
|
||||
ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
|
||||
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
|
||||
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
|
||||
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
|
||||
ck_tile::HostTensor<IndexDataType> local_expert_mask_host({experts});
|
||||
|
||||
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
|
||||
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
|
||||
{(max_num_tokens_padded + block_m - 1) / block_m});
|
||||
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
|
||||
|
||||
if(init == 0)
|
||||
{
|
||||
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
|
||||
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
|
||||
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
|
||||
ck_tile::FillStepRange<AScaleDataType>{0.f, 1.f, 0.01f}(sa_host);
|
||||
ck_tile::FillStepRange<GScaleDataType>{0.f, 1.f, 0.01f}(sg_host);
|
||||
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
|
||||
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
|
||||
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
|
||||
}
|
||||
else if(init == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed}(a_host);
|
||||
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed}(g_host);
|
||||
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed}(d_host);
|
||||
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed}(sa_host);
|
||||
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed}(sg_host);
|
||||
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed}(sd_host);
|
||||
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed}(sy_host);
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed}(topk_weight_host);
|
||||
}
|
||||
else if(init == 2)
|
||||
{
|
||||
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed}(a_host);
|
||||
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed}(g_host);
|
||||
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed}(d_host);
|
||||
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed}(sa_host);
|
||||
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed}(sg_host);
|
||||
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed}(sd_host);
|
||||
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed}(sy_host);
|
||||
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed}(topk_weight_host);
|
||||
}
|
||||
|
||||
// permute weight
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
|
||||
// do moe sorting
|
||||
if(balance)
|
||||
{
|
||||
int e_cnt = 0;
|
||||
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
|
||||
{
|
||||
topk_ids_host.mData[i] = e_cnt;
|
||||
e_cnt++;
|
||||
if(e_cnt >= experts)
|
||||
e_cnt = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
|
||||
}
|
||||
|
||||
// leave it here for future debug purpose
|
||||
#if 0
|
||||
a_host.loadtxt("../../ater/input_torch.txt");
|
||||
|
||||
topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
|
||||
// topk_ids_host.savetxt("topk_ids_2.txt");
|
||||
topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
|
||||
g_host.loadtxt("../../ater/w1_torch.txt", "float");
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
d_host.loadtxt("../../ater/w2_torch.txt", "float");
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl;
|
||||
std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl;
|
||||
std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl;
|
||||
std::cout << "topk_weight_host:" << topk_weight_host << std::endl;
|
||||
std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl;
|
||||
#endif
|
||||
auto cal_tflops = [&](auto ms) {
|
||||
double flop_gemm_0 =
|
||||
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_0 * hidden_size;
|
||||
double flop_gemm_1 =
|
||||
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_1 * hidden_size;
|
||||
return (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ms) * 1e-3) / 1e12;
|
||||
};
|
||||
|
||||
// TODO: this method we use expert-by-expert view, just for reference
|
||||
auto cal_tbps = [&](auto ms) {
|
||||
double token_bytes =
|
||||
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ADataType);
|
||||
double w0_bytes = static_cast<double>(shared_intermediate_size_0) * experts * hidden_size *
|
||||
sizeof(GDataType);
|
||||
double w1_bytes = static_cast<double>(shared_intermediate_size_1) * experts * hidden_size *
|
||||
sizeof(DDataType);
|
||||
double o_bytes =
|
||||
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ODataType);
|
||||
double topk_weights_bytes = static_cast<double>(tokens) * topk * sizeof(TopkWeightDataType);
|
||||
// ignore index, they are too small
|
||||
|
||||
return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) /
|
||||
(static_cast<double>(ms) * 1e-3) / 1e12;
|
||||
};
|
||||
|
||||
if(api == 0)
|
||||
{
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_perm_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_perm_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem local_expert_mask_buf(local_expert_mask_host);
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
|
||||
ck_tile::DeviceMem topk_weight_buf(topk_weight_host);
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_buf(
|
||||
sorted_token_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_expert_ids_buf(
|
||||
sorted_expert_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(
|
||||
num_sorted_tiles_host.get_element_space_size_in_bytes());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size =
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant,
|
||||
local_expert_masking};
|
||||
|
||||
fused_moe_args args{a_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
||||
g_perm_buf.GetDeviceBuffer(),
|
||||
d_perm_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
topk_weight_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
sorted_weight_buf.GetDeviceBuffer(),
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
block_m,
|
||||
hidden_size,
|
||||
intermediate_size / tp,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
float ave_time = fused_moe(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
#define CPU_FUSED_MOE(act_type_) \
|
||||
ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host, \
|
||||
g_host, \
|
||||
d_host, \
|
||||
sa_host, \
|
||||
sg_host, \
|
||||
sd_host, \
|
||||
sy_host, \
|
||||
o_host, \
|
||||
sorted_token_ids_host, \
|
||||
sorted_weight_host, \
|
||||
sorted_expert_ids_host, \
|
||||
num_sorted_tiles_host, \
|
||||
topk_ids_host, \
|
||||
block_m, \
|
||||
tokens, \
|
||||
experts, \
|
||||
hidden_size, \
|
||||
intermediate_size / tp, \
|
||||
topk, \
|
||||
gate_only)
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
local_expert_mask_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
local_expert_masking);
|
||||
if(activation == 0)
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
|
||||
}
|
||||
else
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
|
||||
}
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
auto [rtol, atol] = get_elimit<ADataType>();
|
||||
pass &= ck_tile::check_err(
|
||||
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_fused_moe_json(arg_parser.get_str("jsonfile"),
|
||||
api_str,
|
||||
prec_str,
|
||||
tokens,
|
||||
is_local_token,
|
||||
local_tokens,
|
||||
experts,
|
||||
topk,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
stride,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant,
|
||||
pass,
|
||||
ave_time,
|
||||
cal_tflops(ave_time),
|
||||
cal_tbps(ave_time));
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
else if(api == 1)
|
||||
{
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
local_expert_mask_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
local_expert_masking);
|
||||
|
||||
// done, preparing GPU buffer
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_perm_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_perm_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host);
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
// manually clear output buffer for atomic
|
||||
o_buf.SetZero();
|
||||
//
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
|
||||
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
|
||||
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
|
||||
|
||||
fused_moegemm_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
||||
g_perm_buf.GetDeviceBuffer(),
|
||||
d_perm_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
sorted_weight_buf.GetDeviceBuffer(),
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
intermediate_size / tp,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
|
||||
float ave_time = fused_moegemm(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
if(activation == 0)
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
|
||||
}
|
||||
else
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
|
||||
}
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
auto [rtol, atol] = get_elimit<ADataType>();
|
||||
pass &= ck_tile::check_err(
|
||||
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_fused_moe_json(arg_parser.get_str("jsonfile"),
|
||||
api_str,
|
||||
prec_str,
|
||||
tokens,
|
||||
is_local_token,
|
||||
local_tokens,
|
||||
experts,
|
||||
topk,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
stride,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant,
|
||||
pass,
|
||||
ave_time,
|
||||
cal_tflops(ave_time),
|
||||
cal_tbps(ave_time));
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_kw = arg_parser.get_str("prec_kw");
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
|
||||
// no dynamic quant case
|
||||
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
|
||||
arg_parser)
|
||||
? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32")
|
||||
{
|
||||
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
|
||||
arg_parser)
|
||||
? 0
|
||||
: -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
BIN
example/ck_tile/15_fused_moe/misc/moe-0.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-0.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 75 KiB |
BIN
example/ck_tile/15_fused_moe/misc/moe-1.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 90 KiB |
BIN
example/ck_tile/15_fused_moe/misc/moe-2.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 124 KiB |
BIN
example/ck_tile/15_fused_moe/misc/moe-3.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 18 KiB |
Reference in New Issue
Block a user