refine activation code & complete moe example

This commit is contained in:
lalala-sh
2025-04-11 02:18:16 +00:00
parent d2a82f56e2
commit 0353337447
4 changed files with 59 additions and 76 deletions

View File

@@ -29,9 +29,8 @@ namespace ck {
enum Activation
{
gelu = 0,
silu = 1,
swiglu = 2
gelu_and_mul = 0,
silu_and_mul = 1
};
template <typename GridwiseGemm,
@@ -1405,6 +1404,12 @@ struct GridwiseMoeGemm
// mul scales
const float* p_sorted_weights_0 = p_ds_grid[I0];
const float* p_scale_b = p_ds_grid[I1];
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
{
if constexpr(PerTokenQuant)
@@ -1418,10 +1423,6 @@ struct GridwiseMoeGemm
p_scale_b += expert_id;
}
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
vector_type<int32_t, 4> scale_token_ids;
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
@@ -1462,27 +1463,7 @@ struct GridwiseMoeGemm
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu)
{
tensor_operation::element_wise::Silu{}(c_thread_buf(cidx),
c_thread_buf(cidx));
}
else if(ActivationOperation == Activation::gelu)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
auto gate = scale_a * scale_b * c_thread_buf[cidx];
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::swiglu)
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
@@ -1497,6 +1478,21 @@ struct GridwiseMoeGemm
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
auto gate = scale_a * scale_b * c_thread_buf[cidx];
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
}
else
{
@@ -1511,10 +1507,6 @@ struct GridwiseMoeGemm
}
else
{
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
@@ -1533,25 +1525,20 @@ struct GridwiseMoeGemm
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu)
{
tensor_operation::element_wise::Silu{}(c_thread_buf(cidx),
c_thread_buf(cidx));
}
else if(ActivationOperation == Activation::gelu)
{
auto gate = c_thread_buf[cidx];
auto up = c_thread_buf_up[cidx];
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::swiglu)
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
auto gate = c_thread_buf[cidx];
auto up = c_thread_buf_up[cidx];
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
auto gate = c_thread_buf[cidx];
auto up = c_thread_buf_up[cidx];
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
}
else
{