mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
[CK_TILE] Fix mock token id, support g1u1/g1u0 through same inline code block (#1808)
* fix mock token id
* prepare host for g1u1
* reformat inline-asm
* restructure uk_0
* restructure gate_up
* done
* change default to init=1
* update readme
* fix a bug in interleave pipeline
* rcp for silu
[ROCm/composable_kernel commit: 1ff50e78c6]
This commit is contained in:
@@ -8,6 +8,9 @@ The benifit of this fused-moe:
|
||||
* much less kernel instance, easy to maintain
|
||||
|
||||
# Implementation and feature support
|
||||
## NOTES:
|
||||
currently gate+up in fp16 case will very easily cause accumulator overflow the fp16 max(65504), hence result in INF. Please use BF16 for gate+up case, API side will have no check for this.
|
||||
|
||||
## moe-sorting
|
||||
this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop over from token-by-token to expert-by-expert, make sure very workgroup is working for a single expert (B matrix). Besides, we extend this op to do the zeroing of the output buffer(to be used for reduce buffer with atomic)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ struct fused_moe_args
|
||||
|
||||
ck_tile::index_t block_m; // block_m, used to devide the input
|
||||
ck_tile::index_t hidden_size; // k
|
||||
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
|
||||
ck_tile::index_t intermediate_size; // n / TP, for Gate. and Up, Down is also this value
|
||||
ck_tile::index_t num_tokens; // input number of tokens for current iteration
|
||||
ck_tile::index_t num_experts; // number of groups
|
||||
ck_tile::index_t topk; // need this?
|
||||
@@ -45,7 +45,8 @@ struct fused_moe_traits
|
||||
std::string prec_sq; // smooth quant scale
|
||||
std::string prec_kw; // topk-weight data type
|
||||
int block_m;
|
||||
int gate_only;
|
||||
int activation; // 0:gelu, 1:silu
|
||||
int gate_only; // 0:g1u0, 1:g1u1
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
|
||||
@@ -77,7 +77,8 @@ struct fused_moegemm_traits
|
||||
std::string prec_sq; // smooth quant scale
|
||||
std::string prec_kw; // topk-weight data type
|
||||
int block_m;
|
||||
int gate_only;
|
||||
int activation; // 0:gelu, 1:silu
|
||||
int gate_only; // 0:g1u0, 1:g1u1
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
t.prec_sq,
|
||||
t.prec_kw,
|
||||
t.block_m,
|
||||
t.activation,
|
||||
t.gate_only,
|
||||
t.fused_quant};
|
||||
auto a1 = fused_moegemm_args{
|
||||
|
||||
@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
|
||||
// clang-format off
|
||||
float r = -1;
|
||||
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
@@ -21,21 +21,31 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
using f_problem =
|
||||
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
typename Ts_::DDataType,
|
||||
typename Ts_::AccDataType,
|
||||
typename Ts_::ODataType,
|
||||
typename Ts_::AScaleDataType,
|
||||
typename Ts_::GScaleDataType,
|
||||
typename Ts_::DScaleDataType,
|
||||
typename Ts_::YSmoothScaleDataType,
|
||||
typename Ts_::TopkWeightDataType,
|
||||
typename Ts_::IndexDataType,
|
||||
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
|
||||
f_shape,
|
||||
f_traits>;
|
||||
|
||||
constexpr auto get_activation_ = []() {
|
||||
if constexpr(Ts_::Activation == 0)
|
||||
{
|
||||
return ck_tile::element_wise::FastGeluAsm{};
|
||||
}
|
||||
else
|
||||
return ck_tile::element_wise::Silu{};
|
||||
};
|
||||
using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;
|
||||
|
||||
using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
typename Ts_::DDataType,
|
||||
typename Ts_::AccDataType,
|
||||
typename Ts_::ODataType,
|
||||
typename Ts_::AScaleDataType,
|
||||
typename Ts_::GScaleDataType,
|
||||
typename Ts_::DScaleDataType,
|
||||
typename Ts_::YSmoothScaleDataType,
|
||||
typename Ts_::TopkWeightDataType,
|
||||
typename Ts_::IndexDataType,
|
||||
f_act_, // TODO: hardcoded
|
||||
f_shape,
|
||||
f_traits>;
|
||||
|
||||
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
|
||||
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
|
||||
|
||||
@@ -15,7 +15,8 @@ template <typename I,
|
||||
typename KW,
|
||||
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
|
||||
typename WarpPerBlock_,
|
||||
typename WarpTile_, // seq<*,*,*>, used to select mfma
|
||||
typename WarpTile_, // seq<*,*,*>, used to select mfma
|
||||
ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu
|
||||
ck_tile::index_t GateOnly_ = 0,
|
||||
ck_tile::index_t FusedQuant_ = 0>
|
||||
struct fmoe_ // traits, ugly name, only used for internal
|
||||
@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal
|
||||
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
|
||||
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
|
||||
};
|
||||
|
||||
@@ -8,7 +8,18 @@
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,7 +8,19 @@
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
|
||||
.insert(
|
||||
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
||||
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
|
||||
.insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
|
||||
.insert("balance",
|
||||
"0",
|
||||
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
||||
.insert("init",
|
||||
"2",
|
||||
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
|
||||
"1",
|
||||
"init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
|
||||
"normalized[0, 1]"
|
||||
"normalized(slow)")
|
||||
.insert("seed", "11939", "seed used to do random")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
ck_tile::index_t block_m = arg_parser.get_int("bm");
|
||||
ck_tile::index_t activation = arg_parser.get_int("act");
|
||||
if(stride < 0)
|
||||
stride = hidden_size;
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
|
||||
std::cout
|
||||
<< "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", act:"
|
||||
<< activation
|
||||
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
|
||||
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
block_m,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
intermediate_size / tp,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
#define CPU_FUSED_MOE(act_type_) \
|
||||
ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host, \
|
||||
g_host, \
|
||||
d_host, \
|
||||
sa_host, \
|
||||
sg_host, \
|
||||
sd_host, \
|
||||
sy_host, \
|
||||
o_host, \
|
||||
sorted_token_ids_host, \
|
||||
sorted_weight_host, \
|
||||
sorted_expert_ids_host, \
|
||||
num_sorted_tiles_host, \
|
||||
topk_ids_host, \
|
||||
block_m, \
|
||||
tokens, \
|
||||
experts, \
|
||||
hidden_size, \
|
||||
intermediate_size / tp, \
|
||||
topk, \
|
||||
gate_only)
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m);
|
||||
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
if(activation == 0)
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
|
||||
}
|
||||
else
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
|
||||
}
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
intermediate_size / tp,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
if(activation == 0)
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
|
||||
}
|
||||
else
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
|
||||
}
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
|
||||
Reference in New Issue
Block a user