refine activation code & complete moe example

This commit is contained in:
lalala-sh
2025-04-11 02:18:16 +00:00
parent d2a82f56e2
commit 0353337447
4 changed files with 59 additions and 76 deletions

View File

@@ -131,7 +131,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,

View File

@@ -62,7 +62,7 @@ struct MulABScale
(void)d0;
(void)d1;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * 16);
e = ck::type_convert<EDataType>(c);
#else
e = ck::type_convert<EDataType>(c);
#endif
@@ -74,7 +74,7 @@ struct MulABScale
(void)d0;
(void)d1;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * 16);
e = ck::type_convert<EDataType>(c);
#else
e = ck::type_convert<EDataType>(c);
#endif
@@ -125,7 +125,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul
static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
@@ -203,7 +203,7 @@ int main(int argc, char* argv[])
{
expert_ids.mData[i] = eids[i];
}
int token_per_tile = tokens * topk / valid_tile_num;
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
@@ -479,7 +479,7 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
}

View File

@@ -29,9 +29,8 @@ namespace ck {
enum Activation
{
gelu = 0,
silu = 1,
swiglu = 2
gelu_and_mul = 0,
silu_and_mul = 1
};
template <typename GridwiseGemm,
@@ -1405,6 +1404,12 @@ struct GridwiseMoeGemm
// mul scales
const float* p_sorted_weights_0 = p_ds_grid[I0];
const float* p_scale_b = p_ds_grid[I1];
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
{
if constexpr(PerTokenQuant)
@@ -1418,10 +1423,6 @@ struct GridwiseMoeGemm
p_scale_b += expert_id;
}
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
vector_type<int32_t, 4> scale_token_ids;
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
@@ -1462,27 +1463,7 @@ struct GridwiseMoeGemm
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu)
{
tensor_operation::element_wise::Silu{}(c_thread_buf(cidx),
c_thread_buf(cidx));
}
else if(ActivationOperation == Activation::gelu)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
auto gate = scale_a * scale_b * c_thread_buf[cidx];
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::swiglu)
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
@@ -1497,6 +1478,21 @@ struct GridwiseMoeGemm
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
auto gate = scale_a * scale_b * c_thread_buf[cidx];
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
}
else
{
@@ -1511,10 +1507,6 @@ struct GridwiseMoeGemm
}
else
{
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
@@ -1533,25 +1525,20 @@ struct GridwiseMoeGemm
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu)
{
tensor_operation::element_wise::Silu{}(c_thread_buf(cidx),
c_thread_buf(cidx));
}
else if(ActivationOperation == Activation::gelu)
{
auto gate = c_thread_buf[cidx];
auto up = c_thread_buf_up[cidx];
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::swiglu)
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
auto gate = c_thread_buf[cidx];
auto up = c_thread_buf_up[cidx];
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
auto gate = c_thread_buf[cidx];
auto up = c_thread_buf_up[cidx];
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
}
else
{

View File

@@ -80,10 +80,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
float Run(const Argument& arg)
{
if constexpr(ActivationType > 2)
{
static_assert(false, "Not supported activation type");
}
static_assert(ActivationType < 2, "Not supported activation type");
const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2];
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_t_k_.mDesc.GetLengths()[1];
@@ -148,44 +145,43 @@ struct ReferenceMoeGemm : public device::BaseOperator
else
{
arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n));
if constexpr(ActivationType == 2)
{
arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n));
}
arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n));
}
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
if constexpr(ActivationType == 2)
{
v_acc_up += ck::type_convert<AccDataType>(v_a) *
ck::type_convert<AccDataType>(v_b_up);
}
v_acc_up += ck::type_convert<AccDataType>(v_a) *
ck::type_convert<AccDataType>(v_b_up);
}
CDataType v_c{0};
CDataType v_c_up{0};
arg.c_element_op_(v_c, v_acc);
if constexpr(ActivationType == 2)
if constexpr(ActivationType == 1)
{
arg.c_element_op_(v_c_up, v_acc_up);
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
v_c = v_c * (1.0 / (1.0 + math::exp(-v_c)));
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
if constexpr(is_same_v<BDataType, pk_i4_t>)
{
v_c_up *= 16;
v_c *= 16;
}
tensor_operation::element_wise::Silu{}(v_c, v_c);
v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t);
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
}
else
else if constexpr(ActivationType == 0)
{
if constexpr(ActivationType == 1)
arg.c_element_op_(v_c_up, v_acc_up);
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
if constexpr(is_same_v<BDataType, pk_i4_t>)
{
tensor_operation::element_wise::Silu{}(v_c, v_c);
v_c_up *= 16;
v_c *= 16;
}
else if constexpr(ActivationType == 0)
{
tensor_operation::element_wise::Gelu{}(v_c, v_c);
}
arg.c_t_k_n_(t, topk_id, n) = v_c;
tensor_operation::element_wise::Gelu{}(v_c, v_c);
v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t);
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
}
}
};