mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
refine activation code & complete moe example
This commit is contained in:
@@ -29,9 +29,8 @@ namespace ck {
|
||||
|
||||
enum Activation
|
||||
{
|
||||
gelu = 0,
|
||||
silu = 1,
|
||||
swiglu = 2
|
||||
gelu_and_mul = 0,
|
||||
silu_and_mul = 1
|
||||
};
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
@@ -1405,6 +1404,12 @@ struct GridwiseMoeGemm
|
||||
// mul scales
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
const float* p_scale_b = p_ds_grid[I1];
|
||||
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
|
||||
if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
|
||||
{
|
||||
if constexpr(PerTokenQuant)
|
||||
@@ -1418,10 +1423,6 @@ struct GridwiseMoeGemm
|
||||
p_scale_b += expert_id;
|
||||
}
|
||||
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
vector_type<int32_t, 4> scale_token_ids;
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
@@ -1462,27 +1463,7 @@ struct GridwiseMoeGemm
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu)
|
||||
{
|
||||
tensor_operation::element_wise::Silu{}(c_thread_buf(cidx),
|
||||
c_thread_buf(cidx));
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
auto gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::swiglu)
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
@@ -1497,6 +1478,21 @@ struct GridwiseMoeGemm
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
auto gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1511,10 +1507,6 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
@@ -1533,25 +1525,20 @@ struct GridwiseMoeGemm
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu)
|
||||
{
|
||||
tensor_operation::element_wise::Silu{}(c_thread_buf(cidx),
|
||||
c_thread_buf(cidx));
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu)
|
||||
{
|
||||
auto gate = c_thread_buf[cidx];
|
||||
auto up = c_thread_buf_up[cidx];
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::swiglu)
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
auto gate = c_thread_buf[cidx];
|
||||
auto up = c_thread_buf_up[cidx];
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
auto gate = c_thread_buf[cidx];
|
||||
auto up = c_thread_buf_up[cidx];
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user