Refine activation code and complete MoE example

This commit is contained in:
lalala-sh
2025-04-11 02:18:16 +00:00
parent d2a82f56e2
commit 0353337447
4 changed files with 59 additions and 76 deletions

View File

@@ -131,7 +131,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,

View File

@@ -62,7 +62,7 @@ struct MulABScale
(void)d0;
(void)d1;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * 16);
e = ck::type_convert<EDataType>(c);
#else
e = ck::type_convert<EDataType>(c);
#endif
@@ -74,7 +74,7 @@ struct MulABScale
(void)d0;
(void)d1;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * 16);
e = ck::type_convert<EDataType>(c);
#else
e = ck::type_convert<EDataType>(c);
#endif
@@ -125,7 +125,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul
static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
@@ -203,7 +203,7 @@ int main(int argc, char* argv[])
{
expert_ids.mData[i] = eids[i];
}
int token_per_tile = tokens * topk / valid_tile_num;
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
@@ -479,7 +479,7 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
}