mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
refine activation code & complete moe example
This commit is contained in:
@@ -131,7 +131,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul
|
||||
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
|
||||
@@ -62,7 +62,7 @@ struct MulABScale
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
@@ -74,7 +74,7 @@ struct MulABScale
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
@@ -125,7 +125,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul
|
||||
static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
@@ -203,7 +203,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
@@ -479,7 +479,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
@@ -29,9 +29,8 @@ namespace ck {
|
||||
|
||||
enum Activation
|
||||
{
|
||||
gelu = 0,
|
||||
silu = 1,
|
||||
swiglu = 2
|
||||
gelu_and_mul = 0,
|
||||
silu_and_mul = 1
|
||||
};
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
@@ -1405,6 +1404,12 @@ struct GridwiseMoeGemm
|
||||
// mul scales
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
const float* p_scale_b = p_ds_grid[I1];
|
||||
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
|
||||
if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
|
||||
{
|
||||
if constexpr(PerTokenQuant)
|
||||
@@ -1418,10 +1423,6 @@ struct GridwiseMoeGemm
|
||||
p_scale_b += expert_id;
|
||||
}
|
||||
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
vector_type<int32_t, 4> scale_token_ids;
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
@@ -1462,27 +1463,7 @@ struct GridwiseMoeGemm
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu)
|
||||
{
|
||||
tensor_operation::element_wise::Silu{}(c_thread_buf(cidx),
|
||||
c_thread_buf(cidx));
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
auto gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::swiglu)
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
@@ -1497,6 +1478,21 @@ struct GridwiseMoeGemm
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
auto gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1511,10 +1507,6 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
@@ -1533,25 +1525,20 @@ struct GridwiseMoeGemm
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu)
|
||||
{
|
||||
tensor_operation::element_wise::Silu{}(c_thread_buf(cidx),
|
||||
c_thread_buf(cidx));
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu)
|
||||
{
|
||||
auto gate = c_thread_buf[cidx];
|
||||
auto up = c_thread_buf_up[cidx];
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::swiglu)
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
auto gate = c_thread_buf[cidx];
|
||||
auto up = c_thread_buf_up[cidx];
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
auto gate = c_thread_buf[cidx];
|
||||
auto up = c_thread_buf_up[cidx];
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -80,10 +80,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
if constexpr(ActivationType > 2)
|
||||
{
|
||||
static_assert(false, "Not supported activation type");
|
||||
}
|
||||
static_assert(ActivationType < 2, "Not supported activation type");
|
||||
const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2];
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_t_k_.mDesc.GetLengths()[1];
|
||||
@@ -148,44 +145,43 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n));
|
||||
if constexpr(ActivationType == 2)
|
||||
{
|
||||
arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n));
|
||||
}
|
||||
arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n));
|
||||
}
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
|
||||
if constexpr(ActivationType == 2)
|
||||
{
|
||||
v_acc_up += ck::type_convert<AccDataType>(v_a) *
|
||||
ck::type_convert<AccDataType>(v_b_up);
|
||||
}
|
||||
v_acc_up += ck::type_convert<AccDataType>(v_a) *
|
||||
ck::type_convert<AccDataType>(v_b_up);
|
||||
}
|
||||
CDataType v_c{0};
|
||||
CDataType v_c_up{0};
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
if constexpr(ActivationType == 2)
|
||||
if constexpr(ActivationType == 1)
|
||||
{
|
||||
arg.c_element_op_(v_c_up, v_acc_up);
|
||||
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
|
||||
v_c = v_c * (1.0 / (1.0 + math::exp(-v_c)));
|
||||
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
v_c_up *= 16;
|
||||
v_c *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(v_c, v_c);
|
||||
v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t);
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
|
||||
}
|
||||
else
|
||||
else if constexpr(ActivationType == 0)
|
||||
{
|
||||
if constexpr(ActivationType == 1)
|
||||
arg.c_element_op_(v_c_up, v_acc_up);
|
||||
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
tensor_operation::element_wise::Silu{}(v_c, v_c);
|
||||
v_c_up *= 16;
|
||||
v_c *= 16;
|
||||
}
|
||||
else if constexpr(ActivationType == 0)
|
||||
{
|
||||
tensor_operation::element_wise::Gelu{}(v_c, v_c);
|
||||
}
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c;
|
||||
tensor_operation::element_wise::Gelu{}(v_c, v_c);
|
||||
v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t);
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user