mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Fix mock token id, support g1u1/g1u0 through same inline code block (#1808)
* fix mock token id
* prepare host for g1u1
* reformat inline-asm
* restructure uk_0
* restructure gate_up
* done
* change default to init=1
* update readme
* fix a bug in interleave pipeline
* rcp for silu
[ROCm/composable_kernel commit: 1ff50e78c6]
This commit is contained in:
@@ -8,6 +8,9 @@ The benifit of this fused-moe:
|
||||
* much less kernel instance, easy to maintain
|
||||
|
||||
# Implementation and feature support
|
||||
## NOTES:
|
||||
currently gate+up in fp16 case will very easily cause accumulator overflow the fp16 max(65504), hence result in INF. Please use BF16 for gate+up case, API side will have no check for this.
|
||||
|
||||
## moe-sorting
|
||||
this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop over from token-by-token to expert-by-expert, make sure very workgroup is working for a single expert (B matrix). Besides, we extend this op to do the zeroing of the output buffer(to be used for reduce buffer with atomic)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ struct fused_moe_args
|
||||
|
||||
ck_tile::index_t block_m; // block_m, used to devide the input
|
||||
ck_tile::index_t hidden_size; // k
|
||||
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
|
||||
ck_tile::index_t intermediate_size; // n / TP, for Gate. and Up, Down is also this value
|
||||
ck_tile::index_t num_tokens; // input number of tokens for current iteration
|
||||
ck_tile::index_t num_experts; // number of groups
|
||||
ck_tile::index_t topk; // need this?
|
||||
@@ -45,7 +45,8 @@ struct fused_moe_traits
|
||||
std::string prec_sq; // smooth quant scale
|
||||
std::string prec_kw; // topk-weight data type
|
||||
int block_m;
|
||||
int gate_only;
|
||||
int activation; // 0:gelu, 1:silu
|
||||
int gate_only; // 0:g1u0, 1:g1u1
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
|
||||
@@ -77,7 +77,8 @@ struct fused_moegemm_traits
|
||||
std::string prec_sq; // smooth quant scale
|
||||
std::string prec_kw; // topk-weight data type
|
||||
int block_m;
|
||||
int gate_only;
|
||||
int activation; // 0:gelu, 1:silu
|
||||
int gate_only; // 0:g1u0, 1:g1u1
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
t.prec_sq,
|
||||
t.prec_kw,
|
||||
t.block_m,
|
||||
t.activation,
|
||||
t.gate_only,
|
||||
t.fused_quant};
|
||||
auto a1 = fused_moegemm_args{
|
||||
|
||||
@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
|
||||
// clang-format off
|
||||
float r = -1;
|
||||
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 0;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 1;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
|
||||
{
|
||||
constexpr ck_tile::index_t act_ = 1;
|
||||
constexpr ck_tile::index_t go_ = 0;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
@@ -21,21 +21,31 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
using f_problem =
|
||||
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
typename Ts_::DDataType,
|
||||
typename Ts_::AccDataType,
|
||||
typename Ts_::ODataType,
|
||||
typename Ts_::AScaleDataType,
|
||||
typename Ts_::GScaleDataType,
|
||||
typename Ts_::DScaleDataType,
|
||||
typename Ts_::YSmoothScaleDataType,
|
||||
typename Ts_::TopkWeightDataType,
|
||||
typename Ts_::IndexDataType,
|
||||
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
|
||||
f_shape,
|
||||
f_traits>;
|
||||
|
||||
constexpr auto get_activation_ = []() {
|
||||
if constexpr(Ts_::Activation == 0)
|
||||
{
|
||||
return ck_tile::element_wise::FastGeluAsm{};
|
||||
}
|
||||
else
|
||||
return ck_tile::element_wise::Silu{};
|
||||
};
|
||||
using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;
|
||||
|
||||
using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
typename Ts_::DDataType,
|
||||
typename Ts_::AccDataType,
|
||||
typename Ts_::ODataType,
|
||||
typename Ts_::AScaleDataType,
|
||||
typename Ts_::GScaleDataType,
|
||||
typename Ts_::DScaleDataType,
|
||||
typename Ts_::YSmoothScaleDataType,
|
||||
typename Ts_::TopkWeightDataType,
|
||||
typename Ts_::IndexDataType,
|
||||
f_act_, // TODO: hardcoded
|
||||
f_shape,
|
||||
f_traits>;
|
||||
|
||||
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
|
||||
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
|
||||
|
||||
@@ -15,7 +15,8 @@ template <typename I,
|
||||
typename KW,
|
||||
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
|
||||
typename WarpPerBlock_,
|
||||
typename WarpTile_, // seq<*,*,*>, used to select mfma
|
||||
typename WarpTile_, // seq<*,*,*>, used to select mfma
|
||||
ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu
|
||||
ck_tile::index_t GateOnly_ = 0,
|
||||
ck_tile::index_t FusedQuant_ = 0>
|
||||
struct fmoe_ // traits, ugly name, only used for internal
|
||||
@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal
|
||||
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
|
||||
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
|
||||
};
|
||||
|
||||
@@ -8,7 +8,18 @@
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,7 +8,19 @@
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
|
||||
.insert(
|
||||
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
||||
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
|
||||
.insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
|
||||
.insert("balance",
|
||||
"0",
|
||||
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
||||
.insert("init",
|
||||
"2",
|
||||
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
|
||||
"1",
|
||||
"init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
|
||||
"normalized[0, 1]"
|
||||
"normalized(slow)")
|
||||
.insert("seed", "11939", "seed used to do random")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
ck_tile::index_t block_m = arg_parser.get_int("bm");
|
||||
ck_tile::index_t activation = arg_parser.get_int("act");
|
||||
if(stride < 0)
|
||||
stride = hidden_size;
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
|
||||
std::cout
|
||||
<< "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", act:"
|
||||
<< activation
|
||||
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
|
||||
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
block_m,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
intermediate_size / tp,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
#define CPU_FUSED_MOE(act_type_) \
|
||||
ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host, \
|
||||
g_host, \
|
||||
d_host, \
|
||||
sa_host, \
|
||||
sg_host, \
|
||||
sd_host, \
|
||||
sy_host, \
|
||||
o_host, \
|
||||
sorted_token_ids_host, \
|
||||
sorted_weight_host, \
|
||||
sorted_expert_ids_host, \
|
||||
num_sorted_tiles_host, \
|
||||
topk_ids_host, \
|
||||
block_m, \
|
||||
tokens, \
|
||||
experts, \
|
||||
hidden_size, \
|
||||
intermediate_size / tp, \
|
||||
topk, \
|
||||
gate_only)
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m);
|
||||
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
if(activation == 0)
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
|
||||
}
|
||||
else
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
|
||||
}
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
intermediate_size / tp,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
if(activation == 0)
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
|
||||
}
|
||||
else
|
||||
{
|
||||
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
|
||||
}
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
|
||||
@@ -73,7 +73,7 @@ void reference_fused_moe(
|
||||
ck_tile::index_t tokens,
|
||||
ck_tile::index_t experts,
|
||||
ck_tile::index_t hidden_size,
|
||||
ck_tile::index_t intermediate_size, // this size is for gate/up
|
||||
ck_tile::index_t intermediate_size, // this size is for gate/up/down
|
||||
ck_tile::index_t topk,
|
||||
ck_tile::index_t gate_only)
|
||||
{
|
||||
@@ -82,19 +82,8 @@ void reference_fused_moe(
|
||||
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
|
||||
assert(num_sorted_tiles_host.get_element_size() == 1);
|
||||
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
|
||||
ck_tile::index_t intermediate_size_0 = intermediate_size;
|
||||
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
|
||||
|
||||
// TODO: better remove this in the future, or modify the token_id value
|
||||
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
|
||||
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
|
||||
{
|
||||
if(token_ids_host(token_id_, i_) == expert_id_)
|
||||
return i_;
|
||||
}
|
||||
throw std::runtime_error("not correct token/expert pair\n");
|
||||
return -1; // TODO: not correct!!
|
||||
};
|
||||
ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
|
||||
ck_tile::index_t intermediate_size_1 = intermediate_size;
|
||||
|
||||
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
|
||||
|
||||
@@ -105,11 +94,31 @@ void reference_fused_moe(
|
||||
if(i_tile >= num_sorted_tiles)
|
||||
return;
|
||||
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
|
||||
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
|
||||
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
|
||||
ck_tile::index_t i_topk = i_token >> 24;
|
||||
i_token &= 0xffffff;
|
||||
if(i_token >= tokens)
|
||||
return;
|
||||
(void)token_ids_host;
|
||||
#else
|
||||
// TODO: better remove this in the future, or modify the token_id value
|
||||
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
|
||||
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
|
||||
{
|
||||
if(token_ids_host(token_id_, i_) == expert_id_)
|
||||
return i_;
|
||||
}
|
||||
throw std::runtime_error("not correct token/expert pair\n");
|
||||
return -1; // TODO: not correct!!
|
||||
};
|
||||
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
|
||||
if(i_token >= tokens)
|
||||
return;
|
||||
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
|
||||
auto weight = sorted_weight_host.mData[i_flatten];
|
||||
#endif
|
||||
auto weight = sorted_weight_host.mData[i_flatten];
|
||||
|
||||
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
|
||||
// first gemm
|
||||
|
||||
@@ -719,8 +719,83 @@ struct Silu
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = x * (one / (one + ck_tile::exp(-x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
|
||||
{
|
||||
constexpr auto one = type_convert<float>(1);
|
||||
y[0] = x[0] * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x[0]));
|
||||
y[1] = x[1] * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x[1]));
|
||||
};
|
||||
};
|
||||
|
||||
#if 0
|
||||
// Silu, the formular is not so good to do inline asm (dependency)
|
||||
// we put the code here purposely if in the future ppl want to try
|
||||
struct SiluAsm
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST void operator()(T& y, T& x) const
|
||||
{
|
||||
static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = x * (one / (one + ck_tile::exp(-x)));
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& y, T& x) const
|
||||
{
|
||||
static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
|
||||
|
||||
const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
|
||||
|
||||
// NOTE: x/y can't be same register before inline asm
|
||||
// "+v" as y, "v" as x is not enought, x/y stil maybe put to same register
|
||||
T tmp = x;
|
||||
asm volatile("v_mul_f32 %[v_y], %[s_log2e], %[v_x]\n"
|
||||
"v_exp_f32 %[v_y], %[v_y]\n"
|
||||
"s_nop 0 ; hazard for exp\n"
|
||||
"v_add_f32 %[v_y], %[v_y], 1.0\n"
|
||||
"v_rcp_f32 %[v_y], %[v_y]\n"
|
||||
"s_nop 0 ; hazard for rcp\n"
|
||||
"v_mul_f32 %[v_y], %[v_x], %[v_y]\n"
|
||||
: [v_y] "+v"(y), [v_x] "+v"(tmp)
|
||||
: [s_log2e] "s"(log2e_neg_)
|
||||
:);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
|
||||
{
|
||||
constexpr auto one = type_convert<float>(1);
|
||||
y[0] = x[0] * (one / (one + ck_tile::exp(-x[0])));
|
||||
y[1] = x[1] * (one / (one + ck_tile::exp(-x[1])));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
|
||||
{
|
||||
const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
|
||||
|
||||
// NOTE: x/y can't be same register before inline asm
|
||||
// float tmp0 = x[0], tmp1 = x[1];
|
||||
asm volatile("v_mul_f32 %[v_y0], %[s_log2e], %[v_x0]\n"
|
||||
"v_mul_f32 %[v_y1], %[s_log2e], %[v_x1]\n"
|
||||
"v_exp_f32 %[v_y0], %[v_y0]\n"
|
||||
"v_exp_f32 %[v_y1], %[v_y1]\n"
|
||||
"v_add_f32 %[v_y0], %[v_y0], 1.0\n"
|
||||
"v_add_f32 %[v_y1], %[v_y1], 1.0\n"
|
||||
"v_rcp_f32 %[v_y0], %[v_y0]\n"
|
||||
"v_rcp_f32 %[v_y1], %[v_y1]\n"
|
||||
"v_mul_f32 %[v_y0], %[v_x0], %[v_y0]\n"
|
||||
"v_mul_f32 %[v_y1], %[v_x1], %[v_y1]\n"
|
||||
: [v_y0] "+v"(y[0]), [v_y1] "+v"(y[1]), [v_x0] "+v"(x[0]), [v_x1] "+v"(x[1])
|
||||
: [s_log2e] "s"(log2e_neg_)
|
||||
:);
|
||||
};
|
||||
};
|
||||
#endif
|
||||
|
||||
struct TanH
|
||||
{
|
||||
template <typename T>
|
||||
|
||||
@@ -234,10 +234,153 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return 32 * (128 + 8) * sizeof(bf16_t);
|
||||
// return 32 * (128 + 8) * sizeof(bf16_t);
|
||||
return MakeLdsLoadDesc_A().get_element_space_size() * sizeof(bf16_t) * 2; // 2 lds buffers
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC \
|
||||
[s_loop_cnt]"+s"(loop_cnt), \
|
||||
[v_acc_0]"+v"(v_acc[0]), \
|
||||
[v_acc_1]"+v"(v_acc[1]), \
|
||||
[v_acc_2]"+v"(v_acc[2]), \
|
||||
[v_acc_3]"+v"(v_acc[3]), \
|
||||
[v_acc_4]"+v"(v_acc[4]), \
|
||||
[v_acc_5]"+v"(v_acc[5]), \
|
||||
[v_acc_6]"+v"(v_acc[6]), \
|
||||
[v_acc_7]"+v"(v_acc[7]), \
|
||||
[v_acc_8]"+v"(v_acc[8]), \
|
||||
[v_acc_9]"+v"(v_acc[9]), \
|
||||
[v_acc_10]"+v"(v_acc[10]), \
|
||||
[v_acc_11]"+v"(v_acc[11]), \
|
||||
[v_acc_12]"+v"(v_acc[12]), \
|
||||
[v_acc_13]"+v"(v_acc[13]), \
|
||||
[v_acc_14]"+v"(v_acc[14]), \
|
||||
[v_acc_15]"+v"(v_acc[15]), \
|
||||
[s_mem_]"+r"(smem)
|
||||
|
||||
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC \
|
||||
[s_loop_cnt]"+s"(loop_cnt), \
|
||||
[v_acc_0]"+v"(v_acc[0]), \
|
||||
[v_acc_1]"+v"(v_acc[1]), \
|
||||
[v_acc_2]"+v"(v_acc[2]), \
|
||||
[v_acc_3]"+v"(v_acc[3]), \
|
||||
[v_acc_4]"+v"(v_acc[4]), \
|
||||
[v_acc_5]"+v"(v_acc[5]), \
|
||||
[v_acc_6]"+v"(v_acc[6]), \
|
||||
[v_acc_7]"+v"(v_acc[7]), \
|
||||
[v_acc_8]"+v"(v_acc[8]), \
|
||||
[v_acc_9]"+v"(v_acc[9]), \
|
||||
[v_acc_10]"+v"(v_acc[10]), \
|
||||
[v_acc_11]"+v"(v_acc[11]), \
|
||||
[v_acc_12]"+v"(v_acc[12]), \
|
||||
[v_acc_13]"+v"(v_acc[13]), \
|
||||
[v_acc_14]"+v"(v_acc[14]), \
|
||||
[v_acc_15]"+v"(v_acc[15]), \
|
||||
[v_acc_16]"+v"(v_acc[16]), \
|
||||
[v_acc_17]"+v"(v_acc[17]), \
|
||||
[v_acc_18]"+v"(v_acc[18]), \
|
||||
[v_acc_19]"+v"(v_acc[19]), \
|
||||
[v_acc_20]"+v"(v_acc[20]), \
|
||||
[v_acc_21]"+v"(v_acc[21]), \
|
||||
[v_acc_22]"+v"(v_acc[22]), \
|
||||
[v_acc_23]"+v"(v_acc[23]), \
|
||||
[v_acc_24]"+v"(v_acc[24]), \
|
||||
[v_acc_25]"+v"(v_acc[25]), \
|
||||
[v_acc_26]"+v"(v_acc[26]), \
|
||||
[v_acc_27]"+v"(v_acc[27]), \
|
||||
[v_acc_28]"+v"(v_acc[28]), \
|
||||
[v_acc_29]"+v"(v_acc[29]), \
|
||||
[v_acc_30]"+v"(v_acc[30]), \
|
||||
[v_acc_31]"+v"(v_acc[31]), \
|
||||
[s_mem_]"+r"(smem)
|
||||
|
||||
#define _EXPAND_ASM_ARGS_IN \
|
||||
[s_res_a0]"s"(res_a[0]), \
|
||||
[s_res_a1]"s"(res_a[1]), \
|
||||
[s_res_a2]"s"(res_a[2]), \
|
||||
[s_res_a3]"s"(res_a[3]), \
|
||||
[s_res_b0]"s"(res_b[0]), \
|
||||
[s_res_b1]"s"(res_b[1]), \
|
||||
[s_res_b2]"s"(res_b[2]), \
|
||||
[s_res_b3]"s"(res_b[3]), \
|
||||
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))), \
|
||||
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))), \
|
||||
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))), \
|
||||
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))), \
|
||||
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))), \
|
||||
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))), \
|
||||
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))), \
|
||||
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))), \
|
||||
\
|
||||
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))), \
|
||||
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), \
|
||||
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), \
|
||||
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), \
|
||||
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), \
|
||||
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), \
|
||||
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), \
|
||||
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), \
|
||||
\
|
||||
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
|
||||
[s_m0_init]"s"(m0_init_value), \
|
||||
[s_size_per_issue]"s"(size_per_issue), \
|
||||
[smem_sz]"n"(smem_buf_size), \
|
||||
[sld_os_0]"n"(sld_os[number<0>{}].value), \
|
||||
[sld_os_1]"n"(sld_os[number<1>{}].value), \
|
||||
[sld_os_2]"n"(sld_os[number<2>{}].value), \
|
||||
[sld_os_3]"n"(sld_os[number<3>{}].value), \
|
||||
[sld_os_4]"n"(sld_os[number<4>{}].value), \
|
||||
[sld_os_5]"n"(sld_os[number<5>{}].value), \
|
||||
[sld_os_6]"n"(sld_os[number<6>{}].value), \
|
||||
[sld_os_7]"n"(sld_os[number<7>{}].value), \
|
||||
[s_tile_os_a]"s"(tile_offset_a_bytes), \
|
||||
[s_tile_os_b]"s"(tile_offset_b_bytes)
|
||||
|
||||
#define _EXPAND_ASM_ARGS_CLOBBER \
|
||||
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", \
|
||||
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", \
|
||||
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", \
|
||||
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", \
|
||||
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", \
|
||||
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", \
|
||||
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", \
|
||||
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", \
|
||||
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", \
|
||||
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", \
|
||||
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", \
|
||||
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", \
|
||||
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", \
|
||||
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", \
|
||||
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", \
|
||||
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", \
|
||||
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", \
|
||||
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", \
|
||||
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", \
|
||||
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", \
|
||||
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", \
|
||||
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", \
|
||||
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", \
|
||||
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", \
|
||||
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", \
|
||||
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", \
|
||||
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", \
|
||||
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", \
|
||||
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", \
|
||||
"a252", "a253", "a254", "a255", \
|
||||
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", \
|
||||
"s86", \
|
||||
"v64", "v65", "v66", "v67", "v68", "v69", \
|
||||
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", \
|
||||
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", \
|
||||
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", \
|
||||
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", \
|
||||
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", \
|
||||
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", \
|
||||
"v124", "v125", "v126", "v127"
|
||||
// clang-format on
|
||||
|
||||
struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
|
||||
{
|
||||
using ADataType = bf16_t;
|
||||
@@ -245,7 +388,9 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16
|
||||
|
||||
// TODO: need paired with tile_window_linear!
|
||||
// TODO: need call init_raw() before call this function!
|
||||
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
|
||||
// Is2B: originally for B matrix we have 2 prefetch buffers. If set this to true
|
||||
// we can support A matric serve 2 B matrix, B0/B1, each B0/B1 still have same tile size
|
||||
template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const ARes& res_a,
|
||||
const ACoords& cached_coords_a,
|
||||
@@ -254,7 +399,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16
|
||||
CK_TILE_LDS_ADDR void* smem,
|
||||
index_t k,
|
||||
index_t tile_offset_a, // for each tile, the offset to move for each unroll
|
||||
index_t tile_offset_b) // for each tile, the offset to move for each unroll
|
||||
index_t tile_offset_b,
|
||||
bool_constant<Is2B> = {}) // for each tile, the offset to move for each unroll
|
||||
{
|
||||
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
|
||||
static_assert(BCoords::size() == Repeat_N);
|
||||
@@ -299,129 +445,78 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16
|
||||
|
||||
index_t loop_cnt = k / Block_K;
|
||||
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[16]{.0f};
|
||||
if constexpr(Is2B)
|
||||
{
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[32]{.0f};
|
||||
|
||||
// B nr->kr
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#define CK_TILE_FLATMM_UK_2B 1
|
||||
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
: [s_loop_cnt]"+s"(loop_cnt),
|
||||
[v_acc_0]"+v"(v_acc[0]),
|
||||
[v_acc_1]"+v"(v_acc[1]),
|
||||
[v_acc_2]"+v"(v_acc[2]),
|
||||
[v_acc_3]"+v"(v_acc[3]),
|
||||
[v_acc_4]"+v"(v_acc[4]),
|
||||
[v_acc_5]"+v"(v_acc[5]),
|
||||
[v_acc_6]"+v"(v_acc[6]),
|
||||
[v_acc_7]"+v"(v_acc[7]),
|
||||
[v_acc_8]"+v"(v_acc[8]),
|
||||
[v_acc_9]"+v"(v_acc[9]),
|
||||
[v_acc_10]"+v"(v_acc[10]),
|
||||
[v_acc_11]"+v"(v_acc[11]),
|
||||
[v_acc_12]"+v"(v_acc[12]),
|
||||
[v_acc_13]"+v"(v_acc[13]),
|
||||
[v_acc_14]"+v"(v_acc[14]),
|
||||
[v_acc_15]"+v"(v_acc[15]),
|
||||
[s_mem_]"+r"(smem)
|
||||
: [s_res_a0]"s"(res_a[0]),
|
||||
[s_res_a1]"s"(res_a[1]),
|
||||
[s_res_a2]"s"(res_a[2]),
|
||||
[s_res_a3]"s"(res_a[3]),
|
||||
[s_res_b0]"s"(res_b[0]),
|
||||
[s_res_b1]"s"(res_b[1]),
|
||||
[s_res_b2]"s"(res_b[2]),
|
||||
[s_res_b3]"s"(res_b[3]),
|
||||
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
|
||||
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
|
||||
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
|
||||
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
|
||||
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
|
||||
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
|
||||
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
|
||||
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
|
||||
|
||||
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
|
||||
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
|
||||
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
|
||||
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
|
||||
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
|
||||
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
|
||||
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
|
||||
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
|
||||
|
||||
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
|
||||
[s_m0_init]"s"(m0_init_value),
|
||||
[s_size_per_issue]"s"(size_per_issue),
|
||||
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
|
||||
[sld_os_0]"n"(sld_os[number<0>{}].value),
|
||||
[sld_os_1]"n"(sld_os[number<1>{}].value),
|
||||
[sld_os_2]"n"(sld_os[number<2>{}].value),
|
||||
[sld_os_3]"n"(sld_os[number<3>{}].value),
|
||||
[sld_os_4]"n"(sld_os[number<4>{}].value),
|
||||
[sld_os_5]"n"(sld_os[number<5>{}].value),
|
||||
[sld_os_6]"n"(sld_os[number<6>{}].value),
|
||||
[sld_os_7]"n"(sld_os[number<7>{}].value),
|
||||
[s_tile_os_a]"s"(tile_offset_a_bytes),
|
||||
[s_tile_os_b]"s"(tile_offset_b_bytes)
|
||||
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
|
||||
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
|
||||
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
|
||||
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
|
||||
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
|
||||
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
|
||||
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
|
||||
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
|
||||
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
|
||||
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
|
||||
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
|
||||
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
|
||||
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
|
||||
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
|
||||
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
|
||||
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
|
||||
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
|
||||
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
|
||||
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
|
||||
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
|
||||
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
|
||||
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
|
||||
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
|
||||
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
|
||||
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
|
||||
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
|
||||
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
|
||||
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
|
||||
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
|
||||
"a252", "a253", "a254", "a255",
|
||||
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
|
||||
"s86", // s86 as tmp
|
||||
"v64", "v65", "v66", "v67", "v68", "v69",
|
||||
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
|
||||
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
|
||||
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
|
||||
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
|
||||
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
|
||||
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
|
||||
"v124", "v125", "v126", "v127"
|
||||
);
|
||||
// clang-format on
|
||||
: _EXPAND_ASM_ARGS_OUT_TWO_ACC
|
||||
: _EXPAND_ASM_ARGS_IN,
|
||||
[s_res_b4]"s"(res_b[4]),
|
||||
[s_res_b5]"s"(res_b[5]),
|
||||
[s_res_b6]"s"(res_b[6]),
|
||||
[s_res_b7]"s"(res_b[7])
|
||||
: _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
|
||||
);
|
||||
// clang-format on
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// return local scratch
|
||||
auto c = MakeCBlockTile();
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
// return local scratch
|
||||
auto c = make_tuple(MakeCBlockTile(), MakeCBlockTile());
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
}
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
else
|
||||
{
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[16]{.0f};
|
||||
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
|
||||
: _EXPAND_ASM_ARGS_OUT_ONE_ACC
|
||||
: _EXPAND_ASM_ARGS_IN
|
||||
: _EXPAND_ASM_ARGS_CLOBBER
|
||||
);
|
||||
// clang-format on
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// return local scratch
|
||||
auto c = MakeCBlockTile();
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -432,7 +527,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16
|
||||
|
||||
// TODO: need paired with tile_window_linear!
|
||||
// TODO: need call init_raw() before call this function!
|
||||
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
|
||||
template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const ARes& res_a,
|
||||
const ACoords& cached_coords_a,
|
||||
@@ -441,7 +536,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16
|
||||
CK_TILE_LDS_ADDR void* smem,
|
||||
index_t k,
|
||||
index_t tile_offset_a, // for each tile, the offset to move for each unroll
|
||||
index_t tile_offset_b) // for each tile, the offset to move for each unroll
|
||||
index_t tile_offset_b, // for each tile, the offset to move for each unroll
|
||||
bool_constant<Is2B> = {})
|
||||
{
|
||||
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
|
||||
static_assert(BCoords::size() == Repeat_N);
|
||||
@@ -486,130 +582,82 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16
|
||||
|
||||
index_t loop_cnt = k / Block_K;
|
||||
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[16]{.0f};
|
||||
if constexpr(Is2B)
|
||||
{
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[32]{.0f};
|
||||
|
||||
// B nr->kr
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
|
||||
#define CK_TILE_FLATMM_UK_2B 1
|
||||
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
: [s_loop_cnt]"+s"(loop_cnt),
|
||||
[v_acc_0]"+v"(v_acc[0]),
|
||||
[v_acc_1]"+v"(v_acc[1]),
|
||||
[v_acc_2]"+v"(v_acc[2]),
|
||||
[v_acc_3]"+v"(v_acc[3]),
|
||||
[v_acc_4]"+v"(v_acc[4]),
|
||||
[v_acc_5]"+v"(v_acc[5]),
|
||||
[v_acc_6]"+v"(v_acc[6]),
|
||||
[v_acc_7]"+v"(v_acc[7]),
|
||||
[v_acc_8]"+v"(v_acc[8]),
|
||||
[v_acc_9]"+v"(v_acc[9]),
|
||||
[v_acc_10]"+v"(v_acc[10]),
|
||||
[v_acc_11]"+v"(v_acc[11]),
|
||||
[v_acc_12]"+v"(v_acc[12]),
|
||||
[v_acc_13]"+v"(v_acc[13]),
|
||||
[v_acc_14]"+v"(v_acc[14]),
|
||||
[v_acc_15]"+v"(v_acc[15]),
|
||||
[s_mem_]"+r"(smem)
|
||||
: [s_res_a0]"s"(res_a[0]),
|
||||
[s_res_a1]"s"(res_a[1]),
|
||||
[s_res_a2]"s"(res_a[2]),
|
||||
[s_res_a3]"s"(res_a[3]),
|
||||
[s_res_b0]"s"(res_b[0]),
|
||||
[s_res_b1]"s"(res_b[1]),
|
||||
[s_res_b2]"s"(res_b[2]),
|
||||
[s_res_b3]"s"(res_b[3]),
|
||||
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
|
||||
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
|
||||
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
|
||||
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
|
||||
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
|
||||
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
|
||||
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
|
||||
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
|
||||
|
||||
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
|
||||
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
|
||||
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
|
||||
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
|
||||
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
|
||||
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
|
||||
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
|
||||
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
|
||||
|
||||
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
|
||||
[s_m0_init]"s"(m0_init_value),
|
||||
[s_size_per_issue]"s"(size_per_issue),
|
||||
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
|
||||
[sld_os_0]"n"(sld_os[number<0>{}].value),
|
||||
[sld_os_1]"n"(sld_os[number<1>{}].value),
|
||||
[sld_os_2]"n"(sld_os[number<2>{}].value),
|
||||
[sld_os_3]"n"(sld_os[number<3>{}].value),
|
||||
[sld_os_4]"n"(sld_os[number<4>{}].value),
|
||||
[sld_os_5]"n"(sld_os[number<5>{}].value),
|
||||
[sld_os_6]"n"(sld_os[number<6>{}].value),
|
||||
[sld_os_7]"n"(sld_os[number<7>{}].value),
|
||||
[s_tile_os_a]"s"(tile_offset_a_bytes),
|
||||
[s_tile_os_b]"s"(tile_offset_b_bytes)
|
||||
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
|
||||
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
|
||||
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
|
||||
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
|
||||
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
|
||||
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
|
||||
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
|
||||
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
|
||||
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
|
||||
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
|
||||
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
|
||||
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
|
||||
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
|
||||
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
|
||||
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
|
||||
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
|
||||
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
|
||||
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
|
||||
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
|
||||
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
|
||||
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
|
||||
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
|
||||
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
|
||||
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
|
||||
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
|
||||
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
|
||||
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
|
||||
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
|
||||
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
|
||||
"a252", "a253", "a254", "a255",
|
||||
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
|
||||
"s86", // s86 as tmp
|
||||
"v64", "v65", "v66", "v67", "v68", "v69",
|
||||
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
|
||||
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
|
||||
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
|
||||
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
|
||||
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
|
||||
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
|
||||
"v124", "v125", "v126", "v127"
|
||||
);
|
||||
// clang-format on
|
||||
: _EXPAND_ASM_ARGS_OUT_TWO_ACC
|
||||
: _EXPAND_ASM_ARGS_IN,
|
||||
[s_res_b4]"s"(res_b[4]),
|
||||
[s_res_b5]"s"(res_b[5]),
|
||||
[s_res_b6]"s"(res_b[6]),
|
||||
[s_res_b7]"s"(res_b[7])
|
||||
: _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
|
||||
);
|
||||
// clang-format on
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// return local scratch
|
||||
auto c = MakeCBlockTile();
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
// return local scratch
|
||||
auto c = make_tuple(MakeCBlockTile(), MakeCBlockTile());
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
}
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
else
|
||||
{
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[16]{.0f};
|
||||
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
|
||||
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
|
||||
: _EXPAND_ASM_ARGS_OUT_ONE_ACC
|
||||
: _EXPAND_ASM_ARGS_IN
|
||||
: _EXPAND_ASM_ARGS_CLOBBER
|
||||
);
|
||||
// clang-format on
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// return local scratch
|
||||
auto c = MakeCBlockTile();
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
#undef _EXPAND_ASM_ARGS_OUT_ONE_ACC
|
||||
#undef _EXPAND_ASM_ARGS_OUT_TWO_ACC
|
||||
#undef _EXPAND_ASM_ARGS_IN
|
||||
#undef _EXPAND_ASM_ARGS_CLOBBER
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -65,7 +65,8 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
|
||||
// in LDS we need store as
|
||||
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
|
||||
// y y wave-id lid/16 lid%16 v
|
||||
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
|
||||
constexpr index_t nbufs = 2;
|
||||
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t) * nbufs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -173,7 +174,6 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16 : public FlatmmSn_32x128x512_1x4x
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
:[smem_]"+r"(smem),
|
||||
[s_loop_cnt]"+s"(loop_cnt),
|
||||
[c0]"+v" (v_c0),
|
||||
@@ -418,7 +418,6 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16 : public FlatmmSn_32x128x512_1x4x
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
|
||||
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
:[smem_]"+r"(smem),
|
||||
[s_loop_cnt]"+s"(loop_cnt),
|
||||
[c0]"+v" (v_c0),
|
||||
|
||||
@@ -477,7 +477,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
|
||||
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
|
||||
"a252", "a253", "a254", "a255",
|
||||
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
|
||||
"s36", "s37","s59","s80",
|
||||
"s36", "s37", "s56", "s59", "s60", "s80",
|
||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
|
||||
"v50", "v54", "v55",
|
||||
"v64","v65","v66","v67","v68","v69","v70","v71",
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// clang-format off
|
||||
|
||||
// define the CK_TILE_** macro before include this file to change kernel variation
|
||||
// we will undef everything defined in this file
|
||||
|
||||
#ifndef CK_TILE_FLATMM_UK_MFMA
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#endif
|
||||
@@ -816,3 +823,5 @@
|
||||
#undef _UK_MFMA_
|
||||
#undef _UK_PK_CVT_
|
||||
#undef _UK_ATOMIC_ADD_
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
// clang-format on
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -111,7 +111,7 @@ struct FusedMoeGemmHostArgs
|
||||
const void* num_sorted_tiles_ptr; // [1]
|
||||
|
||||
index_t hidden_size; // k
|
||||
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
|
||||
index_t intermediate_size; // n / TP, for Gate/UP/Down
|
||||
index_t num_tokens; // input number of tokens for current iteration
|
||||
index_t num_experts; // number of groups
|
||||
index_t topk; // need this?
|
||||
@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
|
||||
return _SS_("fused_moe_") + _SS_(prec_str) + "_" + (IsGateOnly ? "g1u0_":"g1u1_") +
|
||||
_TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
|
||||
_TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
|
||||
_TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
|
||||
@@ -204,7 +204,7 @@ struct FusedMoeGemmKernel
|
||||
const void* num_sorted_tiles_ptr;
|
||||
|
||||
index_t hidden_size; // k
|
||||
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
|
||||
index_t intermediate_size; // n / TP, for Gate/Up/Down
|
||||
index_t num_tokens; // input number of tokens for current iteration
|
||||
index_t num_experts; // number of groups
|
||||
index_t topk; // need this?
|
||||
@@ -239,7 +239,7 @@ struct FusedMoeGemmKernel
|
||||
{
|
||||
if constexpr(UseUK)
|
||||
{
|
||||
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
|
||||
__shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
|
||||
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
|
||||
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
|
||||
|
||||
@@ -298,6 +298,9 @@ struct FusedMoeGemmKernel
|
||||
|
||||
index_t token_id =
|
||||
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
token_id &= 0xffffff;
|
||||
#endif
|
||||
auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
|
||||
kargs.sorted_weight_ptr)[sorted_token_id];
|
||||
|
||||
|
||||
@@ -70,11 +70,16 @@ struct FusedMoeGemmPipeline_FlatmmUk
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
#if 1
|
||||
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
|
||||
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
|
||||
constexpr index_t smem_bridge =
|
||||
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
|
||||
return max(smem_0, max(smem_1, smem_bridge));
|
||||
return max(smem_0 + smem_1, smem_bridge);
|
||||
#else
|
||||
// keep it here purposely in case we have regression
|
||||
return 65536;
|
||||
#endif
|
||||
}
|
||||
|
||||
// this is the thread-offset along row/col
|
||||
@@ -125,6 +130,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
|
||||
array<index_t, n_size> row_ids;
|
||||
static_for<0, n_size, 1>{}([&](auto i) {
|
||||
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
row_ids.at(i) &= 0xffffff;
|
||||
#endif
|
||||
});
|
||||
|
||||
return row_ids;
|
||||
@@ -164,9 +172,12 @@ struct FusedMoeGemmPipeline_FlatmmUk
|
||||
index_t sorted_tile_id,
|
||||
index_t intermediate_tile_id)
|
||||
{
|
||||
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
|
||||
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
|
||||
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
|
||||
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
|
||||
ck_tile::index_t shared_intermediate_size_0 =
|
||||
kargs.intermediate_size * hidden_radio_0; // total gate+up
|
||||
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size;
|
||||
|
||||
// after weight shuffling, gate-only: [nr0, kr0, w0], gate+up: [nr0_gate + nr0_up, kr0, w0]
|
||||
|
||||
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
|
||||
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
|
||||
@@ -200,29 +211,35 @@ struct FusedMoeGemmPipeline_FlatmmUk
|
||||
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
|
||||
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
|
||||
|
||||
auto g_win = [&]() {
|
||||
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
|
||||
static_cast<long_index_t>(expert_id) * expert_stride_0 +
|
||||
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
|
||||
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
g_ptr,
|
||||
auto make_gu_win = [&](const auto* ptr_) {
|
||||
auto view_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
ptr_,
|
||||
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
|
||||
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
|
||||
number<kAlignmentG>{},
|
||||
number<1>{});
|
||||
|
||||
auto g_window_ = make_tile_window_linear_raw(
|
||||
g_view_,
|
||||
auto win_ = make_tile_window_linear_raw(
|
||||
view_,
|
||||
make_tuple(number<BlockShape::Block_Nr0>{},
|
||||
number<BlockShape::Block_Kr0>{},
|
||||
number<BlockShape::Block_W0>{}),
|
||||
{0, 0, 0},
|
||||
Policy::template MakeGlobalTileDistribution_G<Problem>(),
|
||||
sequence<0, 1, 1>{});
|
||||
return g_window_;
|
||||
}();
|
||||
return win_;
|
||||
};
|
||||
|
||||
const GDataType* gu_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
|
||||
static_cast<long_index_t>(expert_id) * expert_stride_0 +
|
||||
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
|
||||
|
||||
auto g_win = make_gu_win(gu_ptr);
|
||||
// Note: gu swizzled, [nr_u+nr_g, kr, w], hence base offset to up is just interm*hidden
|
||||
auto u_win = make_gu_win(gu_ptr + kargs.intermediate_size * kargs.hidden_size);
|
||||
|
||||
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
|
||||
auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
|
||||
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
|
||||
number<decltype(g_win)::NumAccess_NonLinear>{});
|
||||
|
||||
@@ -309,28 +326,73 @@ struct FusedMoeGemmPipeline_FlatmmUk
|
||||
auto w_scale = GetWeightScale(
|
||||
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
|
||||
|
||||
auto uk_0 = Policy::template GetUK_0<Problem>();
|
||||
auto acc_0 = uk_0(a_res,
|
||||
a_coords,
|
||||
g_res,
|
||||
g_coords,
|
||||
smem,
|
||||
kargs.hidden_size,
|
||||
BlockShape::Block_K0, // tile offset for B matrix each unroll
|
||||
BlockShape::Block_Kr0 *
|
||||
BlockShape::Block_W0); // tile offset for B matrix each unroll
|
||||
auto uk_0 = Policy::template GetUK_0<Problem>();
|
||||
|
||||
sweep_tile(
|
||||
acc_0,
|
||||
[&](auto idx0, auto idx1) {
|
||||
fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
|
||||
typename Problem::GateActivation{}(v_, v_);
|
||||
acc_0(idx0) = v_.x;
|
||||
acc_0(idx1) = v_.y;
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
auto y_pre = [&]() {
|
||||
if constexpr(IsGateOnly)
|
||||
{
|
||||
auto acc_0 = uk_0(a_res,
|
||||
a_coords,
|
||||
g_res,
|
||||
g_coords,
|
||||
smem,
|
||||
kargs.hidden_size,
|
||||
BlockShape::Block_K0, // tile offset for B matrix each unroll
|
||||
BlockShape::Block_Kr0 *
|
||||
BlockShape::Block_W0); // tile offset for B matrix each unroll
|
||||
|
||||
auto y_pre = cast_tile<YDataType>(acc_0);
|
||||
sweep_tile(
|
||||
acc_0,
|
||||
[&](auto idx0, auto idx1) {
|
||||
fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
|
||||
typename Problem::GateActivation{}(v_, v_);
|
||||
acc_0(idx0) = v_.x;
|
||||
acc_0(idx1) = v_.y;
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
|
||||
return cast_tile<YDataType>(acc_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32x8_t gu_res;
|
||||
gu_res[0] = g_res[0];
|
||||
gu_res[1] = g_res[1];
|
||||
gu_res[2] = g_res[2];
|
||||
gu_res[3] = g_res[3];
|
||||
gu_res[4] = u_res[0];
|
||||
gu_res[5] = u_res[1];
|
||||
gu_res[6] = u_res[2];
|
||||
gu_res[7] = u_res[3];
|
||||
|
||||
auto acc_0 = uk_0(a_res,
|
||||
a_coords,
|
||||
gu_res,
|
||||
g_coords,
|
||||
smem,
|
||||
kargs.hidden_size,
|
||||
BlockShape::Block_K0, // tile offset for B matrix each unroll
|
||||
BlockShape::Block_Kr0 * BlockShape::Block_W0,
|
||||
bool_constant<true>{}); // tile offset for B matrix each unroll
|
||||
|
||||
sweep_tile(
|
||||
acc_0.at(number<0>{}),
|
||||
[&](auto idx0, auto idx1) {
|
||||
fp32x2_t v_{acc_0.at(number<0>{})(idx0), acc_0.at(number<0>{})(idx1)};
|
||||
typename Problem::GateActivation{}(v_, v_);
|
||||
acc_0.at(number<0>{})(idx0) = v_.x;
|
||||
acc_0.at(number<0>{})(idx1) = v_.y;
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
|
||||
auto reduced_acc_0 =
|
||||
tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
|
||||
acc_0.at(number<0>{}),
|
||||
acc_0.at(number<1>{}));
|
||||
|
||||
return cast_tile<YDataType>(reduced_acc_0);
|
||||
}
|
||||
}();
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user