mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Ck tile/moe sorting (#1624)
* add moe_sorting & check ok
* fix comments & typo
* Run remod.py under include/ck_tile & example/ck_tile directories
* format codes
* fix output ci check bug
* fix moe sorting readme and error commit file
* use magiv div to accelerate compute
* add an loop unroll for moe lds ops
* add extblocksnel to set zeros for moebufs
* [Ck_tile] moe set zero run ok, add size check and fix ref check
* [Ck_tile]fix moe_sorting fuse set_zero remod
* [Ck_tile] change name style, fix zero buffer size err, change folder
* [Ck_tile] moe_sorting: fix name style
* [Ck_tile] moe_sorting, remove useless params in traits
* [Ck_tile] change outputtile cnt * unit_size; change output buf alloc
---------
Co-authored-by: dummycoderfe <noplydummmycoder@163.com>
Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
[ROCm/composable_kernel commit: bec6fbc65f]
This commit is contained in:
8
example/ck_tile/13_moe_sorting/CMakeLists.txt
Normal file
8
example/ck_tile/13_moe_sorting/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp)
|
||||
target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
|
||||
27
example/ck_tile/13_moe_sorting/README.md
Normal file
27
example/ck_tile/13_moe_sorting/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# moe-sorting
|
||||
|
||||
This folder contains example for moe-sorting kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input&weight is a `token*topk` 2d matrix. The op rearange the input weight ids into different experts and feed into fuse moe gemm kernel.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_sorting -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_moe_sorting`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-pr_i index data type. (currently only fp32 supported now) (default:int32)
|
||||
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
|
||||
-t number of input tokens (default:32)
|
||||
-e number of experts (default:8)
|
||||
-k topk (default:2)
|
||||
-st_i row stride of input, -1 means same as experts (default:-1)
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-kname when set to 1 it will print kernel name (default:0)
|
||||
|
||||
```
|
||||
223
example/ck_tile/13_moe_sorting/moe_sorting.cpp
Normal file
223
example/ck_tile/13_moe_sorting/moe_sorting.cpp
Normal file
@@ -0,0 +1,223 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
|
||||
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
|
||||
.insert("t", "128", "number of input tokens")
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
.insert("moe_buf_size", "0", "moe_buf_size")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename IndexType>
|
||||
void topid_unique_gen(
|
||||
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
|
||||
{
|
||||
size_t total_size = topk * tokens;
|
||||
std::srand(seed);
|
||||
std::set<IndexType> unique_set;
|
||||
IndexType current_v;
|
||||
for(size_t i = 0; i < total_size; i++)
|
||||
{
|
||||
if(i % topk == 0)
|
||||
{
|
||||
unique_set.clear();
|
||||
}
|
||||
current_v = std::rand() % num_expert;
|
||||
while(unique_set.find(current_v) != unique_set.end())
|
||||
{
|
||||
current_v = std::rand() % num_expert;
|
||||
}
|
||||
unique_set.insert(current_v);
|
||||
host_tensor[i] = current_v;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
int tokens = args.get_int("t");
|
||||
int num_experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
int unit_size = args.get_int("unit");
|
||||
int moe_buf_size = args.get_int("moe_buf_size");
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
int max_output_ids =
|
||||
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
|
||||
|
||||
if(seed < 0)
|
||||
{
|
||||
seed = std::time(nullptr);
|
||||
}
|
||||
|
||||
if(topk > num_experts)
|
||||
{
|
||||
printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
|
||||
topk,
|
||||
num_experts);
|
||||
return false;
|
||||
}
|
||||
|
||||
// tokens already considered batch size
|
||||
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
|
||||
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
|
||||
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
|
||||
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
|
||||
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
|
||||
|
||||
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_expert_ids_dev(
|
||||
sorted_expert_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
{
|
||||
moe_buf_dev.ToDevice(moe_buf_host.data());
|
||||
}
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
topk,
|
||||
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float))};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = moe_sorting(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
fflush(stdout);
|
||||
if(ms < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
sorted_ids_dev.FromDevice(sorted_ids_host.data());
|
||||
sorted_weights_dev.FromDevice(sorted_weights_host.data());
|
||||
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
|
||||
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
{
|
||||
moe_buf_dev.FromDevice(moe_buf_host.data());
|
||||
}
|
||||
|
||||
bool rtn = true;
|
||||
if(validate)
|
||||
{
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
|
||||
|
||||
int32_t ref_total_tokens_post_pad = 0;
|
||||
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
|
||||
weights_host,
|
||||
sorted_ids_ref,
|
||||
sorted_weights_ref,
|
||||
sorted_expert_ids_ref,
|
||||
ref_total_tokens_post_pad,
|
||||
num_experts,
|
||||
unit_size);
|
||||
rtn &= ck_tile::check_err(
|
||||
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_weights_host,
|
||||
sorted_weights_ref,
|
||||
std::string("OUT Error: Incorrect w!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_expert_ids_host,
|
||||
sorted_expert_ids_ref,
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
if(moe_buf_size)
|
||||
{
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
rtn &= ck_tile::check_err(
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
|
||||
}
|
||||
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
}
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
73
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
Normal file
73
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
return -1;
|
||||
}
|
||||
if(a.moe_buf_bytes % 16)
|
||||
{
|
||||
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
|
||||
return -1;
|
||||
}
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
|
||||
switch(smem_io_unroll_num)
|
||||
{
|
||||
case(1): {
|
||||
MOE_SORTING_DISPATCH(1);
|
||||
}
|
||||
case(2): {
|
||||
MOE_SORTING_DISPATCH(2);
|
||||
}
|
||||
case(3): {
|
||||
MOE_SORTING_DISPATCH(3);
|
||||
}
|
||||
case(5): {
|
||||
MOE_SORTING_DISPATCH(5);
|
||||
}
|
||||
case(6): {
|
||||
MOE_SORTING_DISPATCH(6);
|
||||
}
|
||||
case(7): {
|
||||
MOE_SORTING_DISPATCH(7);
|
||||
}
|
||||
case(8): {
|
||||
MOE_SORTING_DISPATCH(8);
|
||||
}
|
||||
case(9): {
|
||||
MOE_SORTING_DISPATCH(9);
|
||||
}
|
||||
case(10): {
|
||||
MOE_SORTING_DISPATCH(10);
|
||||
}
|
||||
case(11): {
|
||||
MOE_SORTING_DISPATCH(11);
|
||||
}
|
||||
default: {
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
20
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
Normal file
20
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/moe_sorting.hpp"
|
||||
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
};
|
||||
|
||||
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
19
example/ck_tile/13_moe_sorting/script/smoke_test.sh
Normal file
19
example/ck_tile/13_moe_sorting/script/smoke_test.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
# #!/bin/sh
|
||||
|
||||
EXE=./build/bin/tile_example_moe_sorting
|
||||
|
||||
$EXE -t=80 -e=17 -moe_buf_size=16
|
||||
$EXE -t=111 -e=117 -moe_buf_size=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_size=1024
|
||||
$EXE -t=99 -e=120 -moe_buf_size=10244
|
||||
$EXE -t=175 -e=64 -k=8
|
||||
$EXE -t=65 -e=8 -k=2
|
||||
$EXE -t=1 -e=25
|
||||
$EXE -t=31 -e=19 -k=15
|
||||
$EXE -t=81 -e=37 -k=7
|
||||
$EXE -t=23 -e=1 -k=1
|
||||
$EXE -t=127 -e=99 -k=19
|
||||
$EXE -t=71 -e=11 -k=11
|
||||
$EXE -t=1 -e=1 -k=1
|
||||
$EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
|
||||
add_subdirectory(10_rmsnorm2d)
|
||||
add_subdirectory(11_add_rmsnorm2d_rdquant)
|
||||
add_subdirectory(12_smoothquant)
|
||||
add_subdirectory(13_moe_sorting)
|
||||
|
||||
Reference in New Issue
Block a user