mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
update code
This commit is contained in:
15
example/ck_tile/15_fused_moe/CMakeLists.txt
Normal file
15
example/ck_tile/15_fused_moe/CMakeLists.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
# Build script for the fused-MoE tile example executable.
# NOTE(review): the variable prefix is spelled "EXAPMLE" (sic); it is used consistently
# throughout this script, so it is harmless.
set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe")
|
||||
# add_example_executable() is deliberately not used to add this target, because the
|
||||
# target should not be included in "make all/install/check"
|
||||
message("adding ${TILE_EXAPMLE_FUSED_MOE}")
|
||||
# every .cpp under instances/ is collected and compiled into the executable
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
# EXCLUDE_FROM_ALL keeps the target out of the default build
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
|
||||
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
# start from an empty option list, then append the example-specific flags
set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: -Wno-undefined-func-template is passed so the sources compile without explicitly
# declaring function template specializations
|
||||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
target_compile_options(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS})
|
||||
@@ -16,33 +16,33 @@ struct FusedMoeGemmTypeConfig;
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using GDataType = ck_tile::bf16_t;
|
||||
using DDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using W0ScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using W1ScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using GDataType = ck_tile::bf16_t;
|
||||
using DDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using GDataType = ck_tile::int8_t;
|
||||
using DDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using W0ScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using W1ScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using GDataType = ck_tile::int8_t;
|
||||
using DDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
@@ -53,14 +53,16 @@ struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs
|
||||
// This is the public API, will be generated by script
|
||||
// Run-time description of the kernel instance to dispatch to.
// The member order is part of the interface: callers brace-initialize this aggregate
// positionally, so every member must be declared exactly once and in this order.
struct fused_moegemm_traits
{
    std::string prec_i;  // input precision
    std::string prec_w;  // weight precision
    std::string prec_o;  // output precision
    std::string prec_st; // token scale data type
    std::string prec_sw; // weight scale data type
    std::string prec_sq; // smooth quant scale
    std::string prec_kw; // topk-weight data type
    int block_m;         // token tile size the dispatcher matches on
    int gate_only;       // w0 style, 0: gate+up (doubles the intermediate size), 1: gate only
    int fused_quant;     // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
|
||||
|
||||
// Public entry point: picks the kernel instance matching the traits and launches it with
// the given arguments on the given stream configuration.
// Callers treat a negative return value as "no matching instance"; otherwise the value is
// used as an average run time (NOTE(review): presumably milliseconds -- confirm).
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
|
||||
|
||||
35
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
Normal file
35
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename Traits_>
|
||||
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
template <ck_tile::index_t... Is>
|
||||
using S = ck_tile::sequence<Is...>;
|
||||
float r = -1;
|
||||
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && block_m == 32 &&
|
||||
gate_only == 1)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
S<32, 512, 128, 128>,
|
||||
S<4, 1, 1>,
|
||||
S<32, 32, 16>,
|
||||
1,
|
||||
0>;
|
||||
fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
|
||||
template <typename Ts_>
|
||||
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
{
|
||||
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
|
||||
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts::WarpTile_0>;
|
||||
using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
typename Ts_::DDataType,
|
||||
typename Ts_::AccDataType,
|
||||
typename Ts_::ODataType,
|
||||
typename Ts_::AScaleDataType,
|
||||
typename Ts_::GScaleDataType,
|
||||
typename Ts_::DScaleDataType,
|
||||
typename Ts_::YSmoothScaleDataType,
|
||||
typename Ts_::TopkWeightDataType,
|
||||
typename Ts_::IndexDataType,
|
||||
ck_tile::Gelu, // TODO: hardcoded
|
||||
f_shape,
|
||||
f_traits>
|
||||
|
||||
using f_pipeline = ck_tile::FusedMoeGemmPipeline_Flatmm<f_problem>;
|
||||
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
|
||||
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
|
||||
|
||||
const dim3 grids = f_kernel::GridSize(a);
|
||||
constexpr dim3 blocks = f_kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = f_kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << f_kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename I,
|
||||
typename W,
|
||||
typename O,
|
||||
typename ST,
|
||||
typename SW,
|
||||
typename SQ,
|
||||
typename KW,
|
||||
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
|
||||
typename WarpPerBlock_,
|
||||
typename WarpTile_, // seq<*,*,*>, used to select mfma
|
||||
ck_tile::index_t GateOnly_ = 0,
|
||||
ck_tile::index_t FusedQuant_ = 0>
|
||||
struct fmoe_ // traits, ugly name, only used for internal
|
||||
{
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename TypeConfig::ADataType>;
|
||||
using GDataType = remove_cvref_t<typename TypeConfig::GDataType>;
|
||||
using DDataType = remove_cvref_t<typename TypeConfig::DDataType>;
|
||||
using AccDataType = remove_cvref_t<typename TypeConfig::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename TypeConfig::ODataType>;
|
||||
using AScaleDataType = remove_cvref_t<typename TypeConfig::AScaleDataType>;
|
||||
using GScaleDataType = remove_cvref_t<typename TypeConfig::GScaleDataType>;
|
||||
using DScaleDataType = remove_cvref_t<typename TypeConfig::DScaleDataType>;
|
||||
using YSmoothScaleDataType = remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
|
||||
using TopkWeightDataType = remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
|
||||
using IndexDataType = remove_cvref_t<typename TypeConfig::IndexDataType>;
|
||||
|
||||
static constexpr index_t BT_ = BlockTIle_::at(number<0>{}); // block token
|
||||
static constexpr index_t BI_ = BlockTIle_::at(number<1>{}); // block intermediate
|
||||
static constexpr index_t BH_ = BlockTIle_::at(number<2>{}); // block hidden
|
||||
static constexpr index_t BD_ = BlockTIle_::at(number<3>{}); // block down
|
||||
|
||||
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
|
||||
using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_0 = remove_cvref_t<WarpTile_>;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
|
||||
using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
|
||||
};
|
||||
@@ -1,7 +1,10 @@
|
||||
#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include "fused_moegemm.hpp"
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <set>
#include <stdexcept>
#include <unordered_set>
#include <vector>
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
@@ -20,18 +23,64 @@ auto get_elimit<ck_tile::bf16_t>()
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
template<typename H>
|
||||
auto shuffle_moe_weight(const H& t, std::string mfma_dtype, int mfma_type = 0)
|
||||
// TODO: padding?
|
||||
template <typename T>
|
||||
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
|
||||
{
|
||||
static_assert(t.get_lengths().size() == 3);
|
||||
int b_ = t.get_lengths()[0];
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[2];
|
||||
if ((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) {
|
||||
std::vector<ck_tile::index_t> new_lens {b_, n_/32, 32, k_/16, 2, 8};
|
||||
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
// Fill the first tokens * topk entries of `host_tensor` with pseudo-random expert ids in
// [0, num_expert). Entries are produced in consecutive groups of `topk`; the ids inside one
// group are pairwise distinct.
//
// host_tensor: destination, must already hold at least tokens * topk elements
// tokens:      number of groups to generate
// topk:        ids per group
// num_expert:  exclusive upper bound of the generated ids
// seed:        passed to std::srand, so a given seed reproduces the same sequence
//
// Throws std::invalid_argument when the request cannot be satisfied: negative counts,
// topk > num_expert (a group could never be completed and the retry loop would not
// terminate), or a destination that is too small.
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    if(tokens < 0 || topk < 0)
    {
        throw std::invalid_argument("topid_unique_gen: tokens and topk must be non-negative");
    }

    // widen before multiplying so the product cannot overflow int
    const std::size_t total_size =
        static_cast<std::size_t>(topk) * static_cast<std::size_t>(tokens);

    if(total_size != 0)
    {
        if(topk > num_expert)
        {
            throw std::invalid_argument(
                "topid_unique_gen: topk must not exceed num_expert to draw unique ids");
        }
        if(host_tensor.size() < total_size)
        {
            throw std::invalid_argument(
                "topid_unique_gen: host_tensor holds fewer than tokens * topk elements");
        }
    }

    std::srand(seed);

    const std::size_t group_size = static_cast<std::size_t>(topk);
    std::set<IndexType> used_in_group;
    for(std::size_t i = 0; i < total_size; ++i)
    {
        // a new token starts: forget the ids drawn for the previous one
        if(i % group_size == 0)
        {
            used_in_group.clear();
        }

        // redraw until the id has not been used in this group yet
        IndexType candidate = std::rand() % num_expert;
        while(!used_in_group.insert(candidate).second)
        {
            candidate = std::rand() % num_expert;
        }
        host_tensor[i] = candidate;
    }
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
@@ -55,8 +104,11 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
|
||||
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("gonly", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
||||
.insert("balance", "1", "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
||||
.insert(
|
||||
"gate_only", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
||||
.insert("balance",
|
||||
"1",
|
||||
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
@@ -64,133 +116,178 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// I:input-type, W:weight-type, O:output-type, ST:token-scale-type, SW:weight-scale-type, SQ:smooth-quant-type, KW:topk-weight-type
|
||||
// I:input-type, W:weight-type, O:output-type, ST:token-scale-type, SW:weight-scale-type,
|
||||
// SQ:smooth-quant-type, KW:topk-weight-type
|
||||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
ck_tile::index_t block_m = arg_parser.get_int("bm");
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
ck_tile::index_t block_m = arg_parser.get_int("bm");
|
||||
if(stride < 0)
|
||||
stride = hidden_size;
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_kw = arg_parser.get_str("prec_kw");
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int gonly = arg_parser.get_int("gonly");
|
||||
int balance = arg_parser.get_int("balance");
|
||||
int tp = arg_parser.get_int("tp");
|
||||
ck_tile::index_t shared_intermediate_size = intermediate_size * (gonly ? 1 : 2) / tp;
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int gate_only = arg_parser.get_int("gate_only");
|
||||
int balance = arg_parser.get_int("balance");
|
||||
int tp = arg_parser.get_int("tp");
|
||||
|
||||
ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp;
|
||||
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
using ADataType = typename TypeConfig::ADataType ;
|
||||
using GDataType = typename TypeConfig::GDataType ;
|
||||
using DDataType = typename TypeConfig::DDataType ;
|
||||
using AccDataType = typename TypeConfig::AccDataType ;
|
||||
using ODataType = typename TypeConfig::ODataType ;
|
||||
using AScaleDataType = typename TypeConfig::AScaleDataType ;
|
||||
using W0ScaleDataType = typename TypeConfig::W0ScaleDataType ;
|
||||
using W1ScaleDataType = typename TypeConfig::W1ScaleDataType ;
|
||||
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
|
||||
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType ;
|
||||
using IndexDataType = typename TypeConfig::IndexDataType ;
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using GDataType = typename TypeConfig::GDataType;
|
||||
using DDataType = typename TypeConfig::DDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
using AScaleDataType = typename TypeConfig::AScaleDataType;
|
||||
using GScaleDataType = typename TypeConfig::GScaleDataType;
|
||||
using DScaleDataType = typename TypeConfig::DScaleDataType;
|
||||
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
|
||||
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType;
|
||||
using IndexDataType = typename TypeConfig::IndexDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<ADataType> g_host({e, shared_intermediate_size, hidden_size});
|
||||
ck_tile::HostTensor<ADataType> d_host({e, intermediate_size, hidden_size});
|
||||
ck_tile::HostTensor<GDataType> g_host({e, shared_intermediate_size, hidden_size});
|
||||
ck_tile::HostTensor<DDataType> d_host({e, intermediate_size, hidden_size});
|
||||
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
|
||||
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size});
|
||||
ck_tile::HostTensor<DScaleDataType> sd_host({intermediate_size});
|
||||
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({intermediate_size}); // smooth-quant
|
||||
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
|
||||
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
|
||||
|
||||
int max_num_tokens_padded = topk * tokens + experts * (block_m - 1);
|
||||
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
|
||||
{(max_num_tokens_padded + block_m - 1) / block_m});
|
||||
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
|
||||
|
||||
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
|
||||
|
||||
ck_tile::HostTensor<XScaleDataType> x_scale_host({n});
|
||||
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
|
||||
// permute weight
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w);
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w);
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
|
||||
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_perm_host);
|
||||
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_perm_host);
|
||||
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f}(sa_host);
|
||||
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f}(sg_host);
|
||||
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
|
||||
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes());
|
||||
// do moe sorting
|
||||
if(balance)
|
||||
{
|
||||
int e_cnt = 0 for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
|
||||
{
|
||||
topk_ids_host.mData[i] = e_cnt;
|
||||
e_cnt++;
|
||||
if(e_cnt >= experts)
|
||||
e_cnt = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, 11913);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m);
|
||||
// done, preparing GPU buffer
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_perm_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_perm_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host);
|
||||
|
||||
x_buf.ToDevice(a_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
beta_buf.ToDevice(beta_host.data());
|
||||
x_residual_buf.ToDevice(x_residual_host.data());
|
||||
x_scale_buf.ToDevice(x_scale_host.data());
|
||||
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
|
||||
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
|
||||
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_w)
|
||||
base_str += "x" + prec_w;
|
||||
if(prec_i != prec_o)
|
||||
base_str += "=" + prec_o;
|
||||
if(fused_quant != 0)
|
||||
{
|
||||
base_str += "|" + prec_o;
|
||||
}
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
base_str += std::string("(") + prec_sy + ")";
|
||||
base_str += std::string("(") + prec_sa + "|" + prec_sg + "|" + prec_sq + ")";
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
|
||||
|
||||
layernorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
|
||||
fused_moegemm_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
beta_buf.GetDeviceBuffer(),
|
||||
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
||||
g_buf.GetDeviceBuffer(),
|
||||
d_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0
|
||||
? sg_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0
|
||||
? sd_buf.GetDeviceBuffer(),
|
||||
fused_quant == 1
|
||||
? sy_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
sorted_weight_buf.GetDeviceBuffer(),
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
num_tokens,
|
||||
experts,
|
||||
stride };
|
||||
|
||||
y_buf.GetDeviceBuffer(),
|
||||
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
nullptr, // p_mean, unsupported yet
|
||||
nullptr, // p_invStd, unsupported yet
|
||||
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
stride};
|
||||
|
||||
float ave_time = layernorm2d_fwd(
|
||||
float ave_time = fused_moegemm(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
@@ -199,22 +296,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
#if 0
|
||||
std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n +
|
||||
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
||||
#else
|
||||
std::size_t flop_gemm_0 = 2 * tokens * topk * shared_intermediate_size * hidden_size;
|
||||
std::size_t flop_gemm_1 = 2 * tokens * topk * hidden_size * hidden_size;
|
||||
double tflops = (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ave_time) * 1e-3) / 1e12;
|
||||
|
||||
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << tflops << " tflops" << std::flush;
|
||||
#endif
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
#if 0
|
||||
// reference
|
||||
if(fused_add != 0)
|
||||
{
|
||||
// fused pre_add/pre_add_store
|
||||
// TODO we accumulate directly to a_host for simplcity here...
|
||||
|
||||
std::transform(a_host.mData.cbegin(),
|
||||
a_host.mData.cend(),
|
||||
x_residual_host.mData.cbegin(),
|
||||
@@ -353,6 +458,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
#else
|
||||
std::cout << std::flush << std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
|
||||
#include "ck_tile/host/reference/reference_permute.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <stdint.h>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename T>
|
||||
@@ -36,6 +37,19 @@ struct DeviceMem
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
}
|
||||
template <T>
|
||||
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
|
||||
{
|
||||
if(mMemSize != 0)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
else
|
||||
{
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
ToDevice(t.data());
|
||||
}
|
||||
void Realloc(std::size_t mem_size)
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
@@ -92,6 +106,22 @@ struct DeviceMem
|
||||
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
}
|
||||
|
||||
// construct a host tensor with type T
|
||||
template <typename T>
|
||||
HostTensor<T> ToHost(std::size_t cpySize = mMemSize)
|
||||
{
|
||||
// TODO: host tensor could be slightly larger than the device tensor
|
||||
// we just copy all data from GPU buffer
|
||||
std::size_t host_elements =
|
||||
(cpySize + sizeof(T) - 1) / sizeof(T) HostTensor<T> h_({host_elements});
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
return h_;
|
||||
}
|
||||
|
||||
void SetZero() const
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
|
||||
78
include/ck_tile/host/reference/reference_moe_sorting.hpp
Normal file
78
include/ck_tile/host/reference/reference_moe_sorting.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename WeightType, typename IndexType = index_t>
|
||||
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
const HostTensor<WeightType>& weights,
|
||||
HostTensor<IndexType>& sorted_token_ids,
|
||||
HostTensor<WeightType>& sorted_weight,
|
||||
HostTensor<IndexType>& sorted_expert_ids,
|
||||
index_t& unit_cnt,
|
||||
const index_t experts,
|
||||
const index_t unit_size)
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
std::vector<std::vector<IndexType>> expert_tokens(experts,
|
||||
std::vector<IndexType>(unit_size, num_token));
|
||||
std::vector<std::vector<WeightType>> expert_token_weights(
|
||||
experts, std::vector<WeightType>(unit_size, 0));
|
||||
std::vector<IndexType> expert_slices(experts, 1);
|
||||
std::vector<IndexType> expert_slice_idxs(experts, 0);
|
||||
|
||||
for(index_t t = 0; t < num_token; t++)
|
||||
{
|
||||
for(index_t k = 0; k < topk; k++)
|
||||
{
|
||||
IndexType e = topk_ids(t, k);
|
||||
WeightType w = weights(t, k);
|
||||
index_t idx = expert_slice_idxs[e];
|
||||
if(idx > expert_slices[e] * unit_size - 1)
|
||||
{
|
||||
expert_slices[e]++;
|
||||
index_t new_size = expert_slices[e] * unit_size;
|
||||
expert_tokens[e].resize(new_size);
|
||||
expert_token_weights[e].resize(new_size);
|
||||
for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
|
||||
{
|
||||
expert_tokens[e][i] = num_token;
|
||||
expert_token_weights[e][i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
expert_tokens[e][idx] = t;
|
||||
expert_token_weights[e][idx] = w;
|
||||
expert_slice_idxs[e]++;
|
||||
}
|
||||
}
|
||||
|
||||
IndexType* out_tokens = sorted_token_ids.data();
|
||||
WeightType* out_weights = sorted_weight.data();
|
||||
IndexType* out_expert_id = sorted_expert_ids.data();
|
||||
for(index_t e = 0; e < experts; e++)
|
||||
{
|
||||
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
|
||||
out_tokens += expert_slices[e] * unit_size;
|
||||
memcpy(out_weights,
|
||||
expert_token_weights[e].data(),
|
||||
sizeof(WeightType) * expert_slices[e] * unit_size);
|
||||
out_weights += expert_slices[e] * unit_size;
|
||||
|
||||
for(index_t s = 0; s < expert_slices[e]; s++)
|
||||
{
|
||||
out_expert_id[s] = e;
|
||||
unit_cnt++;
|
||||
}
|
||||
out_expert_id += expert_slices[e];
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -56,11 +56,10 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST auto
|
||||
reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
|
||||
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
|
||||
{
|
||||
auto x_shape = x.get_lengths();
|
||||
ck_tile::index_t rank = perm.size();
|
||||
auto x_shape = x.get_lengths();
|
||||
ck_tile::index_t rank = perm.size();
|
||||
std::vector<ck_tile::index_t> y_shape = [&]() {
|
||||
std::vector<ck_tile::index_t> tmp(rank, 0);
|
||||
for(int i = 0; i < static_cast<int>(rank); i++)
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moe_shape.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -22,17 +22,17 @@
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_tokens_post_padded : top_k * input_tokens + num_experts * (M_a - 1)
|
||||
// max_num_tokens_padded : top_k * input_tokens + num_experts * (M_a - 1)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
// * length is max_tokens_post_padded, actual size is num_tokens_post_padded_ptr
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_tokens_post_padded + block_size - 1) / block_size
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
//
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
@@ -43,11 +43,12 @@
|
||||
// 3) use num_sorted_tiles_ptr, already divided by M_a
|
||||
//
|
||||
// * below used for indexing
|
||||
// 1) sorted_token_ids_ptr
|
||||
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
|
||||
// 2) sorted_weight_ptr
|
||||
// 3) sorted_expert_ids_ptr
|
||||
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
|
||||
//
|
||||
// [indexing implementation-2]
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
@@ -92,15 +93,15 @@ struct FusedMoeGemmHostArgs
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
void* o_ptr; // [m, k], output token
|
||||
|
||||
const void* sorted_token_ids_ptr;
|
||||
const void* sorted_weight_ptr;
|
||||
const void* sorted_expert_ids_ptr;
|
||||
const void* num_sorted_tiles_ptr;
|
||||
const void* sorted_token_ids_ptr; // [max_num_tokens_padded]
|
||||
const void* sorted_weight_ptr; // [max_num_tokens_padded]
|
||||
const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
|
||||
const void* num_sorted_tiles_ptr; // [1]
|
||||
|
||||
index_t hidden_size; // k
|
||||
index_t hidden_size; // k
|
||||
index_t intermediate_size; // n (TP slice this)
|
||||
index_t num_tokens; // input number of tokens for current iteration
|
||||
index_t num_experts; // number of groups
|
||||
index_t num_tokens; // input number of tokens for current iteration
|
||||
index_t num_experts; // number of groups
|
||||
// index_t top_k; // need this?
|
||||
|
||||
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
@@ -134,10 +135,10 @@ struct FusedMoeGemmKernel
|
||||
|
||||
using Traits = typename Pipeline::Problem::Traits;
|
||||
|
||||
static constexpr bool IsGateOnly = Traits::IsGateOnly;
|
||||
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
|
||||
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
|
||||
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
|
||||
static constexpr bool IsGateOnly = Traits::IsGateOnly;
|
||||
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
|
||||
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
|
||||
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
@@ -173,10 +174,10 @@ struct FusedMoeGemmKernel
|
||||
const void* sorted_expert_ids_ptr;
|
||||
const void* num_sorted_tiles_ptr;
|
||||
|
||||
index_t hidden_size; // k
|
||||
index_t hidden_size; // k
|
||||
index_t intermediate_size; // n (TP slice this)
|
||||
index_t num_tokens; // input number of tokens for current iteration
|
||||
index_t num_experts; // number of groups
|
||||
index_t num_tokens; // input number of tokens for current iteration
|
||||
index_t num_experts; // number of groups
|
||||
// index_t top_k; // need this?
|
||||
|
||||
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
@@ -214,7 +215,7 @@ struct FusedMoeGemmKernel
|
||||
|
||||
index_t nr_0 = kargs.intermediate_size / Pipeline::Block_Nr0;
|
||||
index_t kr_0 = kargs.hidden_size / Pipeline::Block_Kr0;
|
||||
index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1; // should be same as kr_0
|
||||
index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1; // should be same as kr_0
|
||||
index_t kr_1 = kargs.intermediate_size / Pipeline::Block_Kr1; // should be same as nr_0
|
||||
|
||||
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
|
||||
@@ -280,11 +281,12 @@ struct FusedMoeGemmKernel
|
||||
make_tuple(kr_0 * BlockShape::Block_W0, number<Pipeline::Block_W0>{}, 1),
|
||||
number<Pipeline::kAlignmentG>{},
|
||||
number<1>{});
|
||||
const auto g_view_1_ = pad_tensor_view(g_view_,
|
||||
make_tuple(number<Pipeline::Block_Nr0>{},
|
||||
number<Pipeline::Block_Kr0>{},
|
||||
number<Pipeline::Block_W0>{}),
|
||||
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
|
||||
const auto g_view_1_ =
|
||||
pad_tensor_view(g_view_,
|
||||
make_tuple(number<Pipeline::Block_Nr0>{},
|
||||
number<Pipeline::Block_Kr0>{},
|
||||
number<Pipeline::Block_W0>{}),
|
||||
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
|
||||
|
||||
const auto g_window_ = make_tile_window(g_view_1_,
|
||||
make_tuple(number<BlockShape::Block_Nr0>{},
|
||||
@@ -308,11 +310,12 @@ struct FusedMoeGemmKernel
|
||||
make_tuple(kr_1 * Pipeline::Block_W1, Pipeline::Block_W1, 1),
|
||||
number<Pipeline::kAlignmentD>{},
|
||||
number<1>{});
|
||||
const auto d_view_1_ = pad_tensor_view(d_view_,
|
||||
make_tuple(number<Pipeline::kBlockNr_1>{},
|
||||
number<Pipeline::kBlockKr_1>{},
|
||||
number<Pipeline::Block_W1>{}),
|
||||
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
|
||||
const auto d_view_1_ =
|
||||
pad_tensor_view(d_view_,
|
||||
make_tuple(number<Pipeline::kBlockNr_1>{},
|
||||
number<Pipeline::kBlockKr_1>{},
|
||||
number<Pipeline::Block_W1>{}),
|
||||
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
|
||||
|
||||
const auto d_window_ = make_tile_window(d_view_1_,
|
||||
make_tuple(number<Pipeline::kBlockNr_1>{},
|
||||
|
||||
@@ -44,10 +44,10 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
|
||||
using Traits = typename Pipeline::Problem::Traits;
|
||||
|
||||
static constexpr bool IsGateOnly = Traits::IsGateOnly;
|
||||
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
|
||||
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
|
||||
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
|
||||
static constexpr bool IsGateOnly = Traits::IsGateOnly;
|
||||
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
|
||||
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
|
||||
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
|
||||
|
||||
static constexpr index_t kAlignmentA = Policy::GetAlignment_A<Problem>();
|
||||
static constexpr index_t kAlignmentG = Policy::GetAlignment_G<Problem>();
|
||||
@@ -133,11 +133,12 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
|
||||
number<kAlignmentG>{},
|
||||
number<1>{});
|
||||
const auto u_view_1_ = pad_tensor_view(u_view_,
|
||||
make_tuple(number<BlockShape::Block_Nr0>{},
|
||||
number<BlockShape::Block_Kr0>{},
|
||||
number<BlockShape::Block_W0>{}),
|
||||
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
|
||||
const auto u_view_1_ =
|
||||
pad_tensor_view(u_view_,
|
||||
make_tuple(number<BlockShape::Block_Nr0>{},
|
||||
number<BlockShape::Block_Kr0>{},
|
||||
number<BlockShape::Block_W0>{}),
|
||||
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
|
||||
return u_view_1_;
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -225,7 +225,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
|
||||
{
|
||||
if constexpr(Problem::Traits::PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
|
||||
if constexpr(Problem::Traits::PermuteEnum ==
|
||||
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
|
||||
{
|
||||
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
|
||||
constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
|
||||
@@ -703,7 +704,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
|
||||
{
|
||||
if constexpr(Problem::Traits::PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
|
||||
if constexpr(Problem::Traits::PermuteEnum ==
|
||||
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
|
||||
{
|
||||
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
|
||||
constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
|
||||
@@ -723,7 +725,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1()
|
||||
{
|
||||
if constexpr(Problem::Traits::PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
|
||||
if constexpr(Problem::Traits::PermuteEnum ==
|
||||
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
|
||||
{
|
||||
using WarpGemm = GetWarpGemm1<Problem>{}; // assume warpgemm0/1 are the same
|
||||
constexpr index_t NPerBlock = Problem::BlockShape::kBlockN_1;
|
||||
|
||||
@@ -14,8 +14,8 @@ template <typename ADataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename AScaleDataType_,
|
||||
typename W0ScaleDataType_,
|
||||
typename W1ScaleDataType_,
|
||||
typename GScaleDataType_,
|
||||
typename DScaleDataType_,
|
||||
typename YSmoothScaleDataType_,
|
||||
typename TopkWeightDataType_,
|
||||
typename IndexDataType_, // data type for all indexing
|
||||
|
||||
@@ -19,14 +19,18 @@ enum class FusedMoeGemmWeightPermuteEnum
|
||||
template <bool IsGateOnly_,
|
||||
bool UseSmoothQuant_,
|
||||
index_t OAtomic_, // 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
|
||||
FusedMoeGemmWeightPermuteEnum PermuteEnum_ = FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten;
|
||||
bool PadHiddenSize_ = false, bool PadIntermediateSize_ = false > struct FusedMoeGemmTraits
|
||||
FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
|
||||
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
|
||||
bool PadHiddenSize_ = false,
|
||||
bool PadIntermediateSize_ = false>
|
||||
struct FusedMoeGemmTraits
|
||||
{
|
||||
// Gate+Up or Gate only
|
||||
static constexpr bool IsGateOnly = IsGateOnly_;
|
||||
static constexpr bool UseSmoothQuant = UseSmoothQuant_;
|
||||
static constexpr index_t OAtomic = OAtomic_;
|
||||
static constexpr bool PadHiddenSize = PadHiddenSize_;
|
||||
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
|
||||
static constexpr bool IsGateOnly = IsGateOnly_;
|
||||
static constexpr bool UseSmoothQuant = UseSmoothQuant_;
|
||||
static constexpr index_t OAtomic = OAtomic_;
|
||||
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
|
||||
static constexpr bool PadHiddenSize = PadHiddenSize_;
|
||||
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user