mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
refine activation code & complete moe example
This commit is contained in:
@@ -131,7 +131,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul
|
||||
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
|
||||
@@ -62,7 +62,7 @@ struct MulABScale
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
@@ -74,7 +74,7 @@ struct MulABScale
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
@@ -125,7 +125,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul
|
||||
static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
@@ -203,7 +203,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
@@ -479,7 +479,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user