mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Ck tile/moe sorting (#1624)
* add moe_sorting & check ok
* fix comments & typo
* Run remod.py under include/ck_tile & example/ck_tile directories
* format codes
* fix output ci check bug
* fix moe sorting readme and error commit file
* use magiv div to accelerate compute
* add an loop unroll for moe lds ops
* add extblocksnel to set zeros for moebufs
* [Ck_tile] moe set zero run ok, add size check and fix ref check
* [Ck_tile]fix moe_sorting fuse set_zero remod
* [Ck_tile] change name style, fix zero buffer size err, change folder
* [Ck_tile] moe_sorting: fix name style
* [Ck_tile] moe_sorting, remove useless params in traits
* [Ck_tile] change outputtile cnt * unit_size; change output buf alloc
---------
Co-authored-by: dummycoderfe <noplydummmycoder@163.com>
Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
[ROCm/composable_kernel commit: bec6fbc65f]
This commit is contained in:
8
example/ck_tile/13_moe_sorting/CMakeLists.txt
Normal file
8
example/ck_tile/13_moe_sorting/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp)
|
||||
target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
|
||||
27
example/ck_tile/13_moe_sorting/README.md
Normal file
27
example/ck_tile/13_moe_sorting/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# moe-sorting
|
||||
|
||||
This folder contains example for moe-sorting kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input&weight is a `token*topk` 2d matrix. The op rearange the input weight ids into different experts and feed into fuse moe gemm kernel.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_sorting -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_moe_sorting`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-pr_i index data type. (currently only fp32 supported now) (default:int32)
|
||||
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
|
||||
-t number of input tokens (default:32)
|
||||
-e number of experts (default:8)
|
||||
-k topk (default:2)
|
||||
-st_i row stride of input, -1 means same as experts (default:-1)
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-kname when set to 1 it will print kernel name (default:0)
|
||||
|
||||
```
|
||||
223
example/ck_tile/13_moe_sorting/moe_sorting.cpp
Normal file
223
example/ck_tile/13_moe_sorting/moe_sorting.cpp
Normal file
@@ -0,0 +1,223 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
|
||||
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
|
||||
.insert("t", "128", "number of input tokens")
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
.insert("moe_buf_size", "0", "moe_buf_size")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename IndexType>
|
||||
void topid_unique_gen(
|
||||
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
|
||||
{
|
||||
size_t total_size = topk * tokens;
|
||||
std::srand(seed);
|
||||
std::set<IndexType> unique_set;
|
||||
IndexType current_v;
|
||||
for(size_t i = 0; i < total_size; i++)
|
||||
{
|
||||
if(i % topk == 0)
|
||||
{
|
||||
unique_set.clear();
|
||||
}
|
||||
current_v = std::rand() % num_expert;
|
||||
while(unique_set.find(current_v) != unique_set.end())
|
||||
{
|
||||
current_v = std::rand() % num_expert;
|
||||
}
|
||||
unique_set.insert(current_v);
|
||||
host_tensor[i] = current_v;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
int tokens = args.get_int("t");
|
||||
int num_experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
int unit_size = args.get_int("unit");
|
||||
int moe_buf_size = args.get_int("moe_buf_size");
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
int max_output_ids =
|
||||
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
|
||||
|
||||
if(seed < 0)
|
||||
{
|
||||
seed = std::time(nullptr);
|
||||
}
|
||||
|
||||
if(topk > num_experts)
|
||||
{
|
||||
printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
|
||||
topk,
|
||||
num_experts);
|
||||
return false;
|
||||
}
|
||||
|
||||
// tokens already considered batch size
|
||||
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
|
||||
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
|
||||
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
|
||||
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
|
||||
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
|
||||
|
||||
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_expert_ids_dev(
|
||||
sorted_expert_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
{
|
||||
moe_buf_dev.ToDevice(moe_buf_host.data());
|
||||
}
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
topk,
|
||||
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float))};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = moe_sorting(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
fflush(stdout);
|
||||
if(ms < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
sorted_ids_dev.FromDevice(sorted_ids_host.data());
|
||||
sorted_weights_dev.FromDevice(sorted_weights_host.data());
|
||||
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
|
||||
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
{
|
||||
moe_buf_dev.FromDevice(moe_buf_host.data());
|
||||
}
|
||||
|
||||
bool rtn = true;
|
||||
if(validate)
|
||||
{
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
|
||||
|
||||
int32_t ref_total_tokens_post_pad = 0;
|
||||
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
|
||||
weights_host,
|
||||
sorted_ids_ref,
|
||||
sorted_weights_ref,
|
||||
sorted_expert_ids_ref,
|
||||
ref_total_tokens_post_pad,
|
||||
num_experts,
|
||||
unit_size);
|
||||
rtn &= ck_tile::check_err(
|
||||
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_weights_host,
|
||||
sorted_weights_ref,
|
||||
std::string("OUT Error: Incorrect w!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_expert_ids_host,
|
||||
sorted_expert_ids_ref,
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
if(moe_buf_size)
|
||||
{
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
rtn &= ck_tile::check_err(
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
|
||||
}
|
||||
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
}
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
73
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
Normal file
73
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
return -1;
|
||||
}
|
||||
if(a.moe_buf_bytes % 16)
|
||||
{
|
||||
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
|
||||
return -1;
|
||||
}
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
|
||||
switch(smem_io_unroll_num)
|
||||
{
|
||||
case(1): {
|
||||
MOE_SORTING_DISPATCH(1);
|
||||
}
|
||||
case(2): {
|
||||
MOE_SORTING_DISPATCH(2);
|
||||
}
|
||||
case(3): {
|
||||
MOE_SORTING_DISPATCH(3);
|
||||
}
|
||||
case(5): {
|
||||
MOE_SORTING_DISPATCH(5);
|
||||
}
|
||||
case(6): {
|
||||
MOE_SORTING_DISPATCH(6);
|
||||
}
|
||||
case(7): {
|
||||
MOE_SORTING_DISPATCH(7);
|
||||
}
|
||||
case(8): {
|
||||
MOE_SORTING_DISPATCH(8);
|
||||
}
|
||||
case(9): {
|
||||
MOE_SORTING_DISPATCH(9);
|
||||
}
|
||||
case(10): {
|
||||
MOE_SORTING_DISPATCH(10);
|
||||
}
|
||||
case(11): {
|
||||
MOE_SORTING_DISPATCH(11);
|
||||
}
|
||||
default: {
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
20
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
Normal file
20
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/moe_sorting.hpp"
|
||||
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
};
|
||||
|
||||
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
19
example/ck_tile/13_moe_sorting/script/smoke_test.sh
Normal file
19
example/ck_tile/13_moe_sorting/script/smoke_test.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
# #!/bin/sh
|
||||
|
||||
EXE=./build/bin/tile_example_moe_sorting
|
||||
|
||||
$EXE -t=80 -e=17 -moe_buf_size=16
|
||||
$EXE -t=111 -e=117 -moe_buf_size=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_size=1024
|
||||
$EXE -t=99 -e=120 -moe_buf_size=10244
|
||||
$EXE -t=175 -e=64 -k=8
|
||||
$EXE -t=65 -e=8 -k=2
|
||||
$EXE -t=1 -e=25
|
||||
$EXE -t=31 -e=19 -k=15
|
||||
$EXE -t=81 -e=37 -k=7
|
||||
$EXE -t=23 -e=1 -k=1
|
||||
$EXE -t=127 -e=99 -k=19
|
||||
$EXE -t=71 -e=11 -k=11
|
||||
$EXE -t=1 -e=1 -k=1
|
||||
$EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
|
||||
add_subdirectory(10_rmsnorm2d)
|
||||
add_subdirectory(11_add_rmsnorm2d_rdquant)
|
||||
add_subdirectory(12_smoothquant)
|
||||
add_subdirectory(13_moe_sorting)
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
|
||||
#include "ck_tile/host/reference/reference_permute.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
|
||||
|
||||
78
include/ck_tile/host/reference/reference_moe_sorting.hpp
Normal file
78
include/ck_tile/host/reference/reference_moe_sorting.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename WeightType, typename IndexType = index_t>
|
||||
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
const HostTensor<WeightType>& weights,
|
||||
HostTensor<IndexType>& p_sorted_token_ids,
|
||||
HostTensor<WeightType>& sorted_weight,
|
||||
HostTensor<IndexType>& sorted_expert_ids,
|
||||
index_t& unit_cnt,
|
||||
const index_t experts,
|
||||
const index_t unit_size)
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
std::vector<std::vector<IndexType>> expert_tokens(experts,
|
||||
std::vector<IndexType>(unit_size, num_token));
|
||||
std::vector<std::vector<WeightType>> expert_token_weights(
|
||||
experts, std::vector<WeightType>(unit_size, 0));
|
||||
std::vector<IndexType> expert_slices(experts, 1);
|
||||
std::vector<IndexType> expert_slice_idxs(experts, 0);
|
||||
|
||||
for(index_t t = 0; t < num_token; t++)
|
||||
{
|
||||
for(index_t k = 0; k < topk; k++)
|
||||
{
|
||||
IndexType e = topk_ids(t, k);
|
||||
WeightType w = weights(t, k);
|
||||
index_t idx = expert_slice_idxs[e];
|
||||
if(idx > expert_slices[e] * unit_size - 1)
|
||||
{
|
||||
expert_slices[e]++;
|
||||
index_t new_size = expert_slices[e] * unit_size;
|
||||
expert_tokens[e].resize(new_size);
|
||||
expert_token_weights[e].resize(new_size);
|
||||
for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
|
||||
{
|
||||
expert_tokens[e][i] = num_token;
|
||||
expert_token_weights[e][i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
expert_tokens[e][idx] = t;
|
||||
expert_token_weights[e][idx] = w;
|
||||
expert_slice_idxs[e]++;
|
||||
}
|
||||
}
|
||||
|
||||
IndexType* out_tokens = p_sorted_token_ids.data();
|
||||
WeightType* out_weights = sorted_weight.data();
|
||||
IndexType* out_expert_id = sorted_expert_ids.data();
|
||||
for(index_t e = 0; e < experts; e++)
|
||||
{
|
||||
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
|
||||
out_tokens += expert_slices[e] * unit_size;
|
||||
memcpy(out_weights,
|
||||
expert_token_weights[e].data(),
|
||||
sizeof(WeightType) * expert_slices[e] * unit_size);
|
||||
out_weights += expert_slices[e] * unit_size;
|
||||
|
||||
for(index_t s = 0; s < expert_slices[e]; s++)
|
||||
{
|
||||
out_expert_id[s] = e;
|
||||
unit_cnt++;
|
||||
}
|
||||
out_expert_id += expert_slices[e];
|
||||
}
|
||||
unit_cnt *= unit_size;
|
||||
return;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
232
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
Normal file
232
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
Normal file
@@ -0,0 +1,232 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct MoeSortingHostArgs
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t unit_size;
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t moe_buf_bytes;
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
struct MoeSortingKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
index_t moe_buf_bytes;
|
||||
|
||||
index_t tokens_per_thread;
|
||||
mdiv unit_size_mdiv;
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
// TODO: assume num-experts not too much
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
|
||||
{
|
||||
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
|
||||
}
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
|
||||
{
|
||||
const auto blocks = BlockSize(h);
|
||||
return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
k.p_moe_buf = h.p_moe_buf;
|
||||
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
|
||||
const auto blocks = BlockSize(h);
|
||||
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
|
||||
{
|
||||
return row * total_col + col;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const
|
||||
{
|
||||
const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
|
||||
if(offset < buf_bytes / 16)
|
||||
{
|
||||
buf[offset] = uint8x16_t{0};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
|
||||
const WeightType* __restrict__ weights,
|
||||
index_t* p_sorted_token_ids,
|
||||
WeightType* p_sorted_weights,
|
||||
index_t* p_sorted_expert_ids,
|
||||
index_t* p_total_tokens_post_pad,
|
||||
const index_t num_experts,
|
||||
const index_t tokens_per_thread,
|
||||
const index_t numel,
|
||||
const mdiv unit_size_mdiv,
|
||||
const mdiv topk_mdiv,
|
||||
void* smem) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
const index_t start_idx = tid * tokens_per_thread;
|
||||
|
||||
index_t* shared_mem = reinterpret_cast<index_t*>(smem);
|
||||
|
||||
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
|
||||
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
|
||||
for(int i = 0; i < num_experts; ++i)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
|
||||
}
|
||||
#pragma unroll Problem_::InternalLoadUnroll
|
||||
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
|
||||
{
|
||||
++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if(tid < num_experts)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
|
||||
for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts, i, tid)] +=
|
||||
tokens_cnts[calc_index(num_experts, i - 1, tid)];
|
||||
}
|
||||
}
|
||||
|
||||
// __syncthreads();
|
||||
if(tid == 0)
|
||||
{
|
||||
cumsum[0] = 0;
|
||||
for(int i = 1; i <= num_experts; ++i)
|
||||
{
|
||||
auto current_units = [&]() {
|
||||
index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
|
||||
unit_size_mdiv.divisor - 1;
|
||||
index_t y_ = unit_size_mdiv.div(x_);
|
||||
return max(y_, 1) * unit_size_mdiv.divisor;
|
||||
}();
|
||||
cumsum[i] = cumsum[i - 1] + current_units;
|
||||
}
|
||||
*p_total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
__syncthreads();
|
||||
if(tid < num_experts)
|
||||
{
|
||||
for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
|
||||
{
|
||||
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll Problem_::InternalLoadUnroll
|
||||
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
|
||||
{
|
||||
index_t expert_id = topk_id[i];
|
||||
index_t rank_post_pad =
|
||||
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
|
||||
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
|
||||
}
|
||||
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
if(tid < num_experts)
|
||||
{
|
||||
index_t expert_offset =
|
||||
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
|
||||
while(expert_offset < cumsum[tid + 1])
|
||||
{
|
||||
p_sorted_token_ids[expert_offset] = prefill_token;
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
expert_offset++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
if(kargs.p_moe_buf)
|
||||
{
|
||||
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes);
|
||||
}
|
||||
return;
|
||||
}
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
extern __shared__ char smem[];
|
||||
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
static_cast<IndexType*>(kargs.p_sorted_token_ids),
|
||||
static_cast<WeightType*>(kargs.p_sorted_weights),
|
||||
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
|
||||
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
|
||||
kargs.num_experts,
|
||||
kargs.tokens_per_thread,
|
||||
numel,
|
||||
kargs.unit_size_mdiv,
|
||||
kargs.topk_mdiv,
|
||||
smem);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
|
||||
// struct MoeSortingPipeline
|
||||
// {
|
||||
// // TODO: this kernel only support warp per row
|
||||
// using Problem = remove_cvref_t<Problem_>;
|
||||
// using Policy = remove_cvref_t<Policy_>;
|
||||
// using WeightType = typename Problem::WeightType;
|
||||
|
||||
// template <typename TopkIdWindow, typename WeightWindow>
|
||||
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
|
||||
// const WeightWindow& weight_window,
|
||||
// index_t* p_sorted_token_ids,
|
||||
// WeightType* p_sorted_weights,
|
||||
// index_t* p_sorted_expert_ids,
|
||||
// index_t* p_total_tokens_post_pad,
|
||||
// const index_t num_experts,
|
||||
// const index_t unit_size,
|
||||
// const size_t numel,
|
||||
// const index_t topk)
|
||||
// {
|
||||
// }
|
||||
// };
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/softmax.hpp"
|
||||
#include "ck_tile/ops/topk.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct MoeSortingPolicy
|
||||
{
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
|
||||
struct MoeSortingProblem
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
static constexpr index_t WarpsPerBlock = 1;
|
||||
static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
11
include/ck_tile/ops/moe_sorting.hpp
Normal file
11
include/ck_tile/ops/moe_sorting.hpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
Reference in New Issue
Block a user