mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_Tile] Refactor Permute and MOE Smoothquant ctests to gtests (#2622)
* Refactor CK tile permute ctests to gtests * Refactor CK tile MOE smoothquant ctests to gtests * fix typo in comment Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update invalid case in else clause for get_precision_string * Refactor permute gtests to use templated versions of matrix_core_swizzle and permute functions --------- Co-authored-by: root <root@splinter-126-wr-c2.aus.dcgpu> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -24,9 +24,7 @@ using trait_ = moe_smoothquant_traits_<InType,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename in_type, typename out_type>
|
||||
float moe_smoothquant_dispatch(moe_smoothquant_traits /*t*/,
|
||||
moe_smoothquant_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
float moe_smoothquant_dispatch(moe_smoothquant_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
float r = -1;
|
||||
// clang-format off
|
||||
@@ -130,26 +128,30 @@ float moe_smoothquant_dispatch(moe_smoothquant_traits /*t*/,
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
float moe_smoothquant(moe_smoothquant_traits t,
|
||||
moe_smoothquant_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
template <>
|
||||
float moe_smoothquant<ck_tile::fp16_t, ck_tile::int8_t>(moe_smoothquant_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.in_type.compare("fp16") == 0 && t.out_type == "int8")
|
||||
{
|
||||
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::int8_t>(t, a, s);
|
||||
}
|
||||
else if(t.in_type.compare("fp16") == 0 && t.out_type == "fp8")
|
||||
{
|
||||
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::fp8_t>(t, a, s);
|
||||
}
|
||||
else if(t.in_type.compare("bf16") == 0 && t.out_type == "int8")
|
||||
{
|
||||
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::int8_t>(t, a, s);
|
||||
}
|
||||
else if(t.in_type.compare("bf16") == 0 && t.out_type == "fp8")
|
||||
{
|
||||
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::fp8_t>(t, a, s);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("Without supported instances!");
|
||||
}
|
||||
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::int8_t>(a, s);
|
||||
};
|
||||
|
||||
template <>
|
||||
float moe_smoothquant<ck_tile::fp16_t, ck_tile::fp8_t>(moe_smoothquant_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::fp8_t>(a, s);
|
||||
};
|
||||
|
||||
template <>
|
||||
float moe_smoothquant<ck_tile::bf16_t, ck_tile::int8_t>(moe_smoothquant_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::int8_t>(a, s);
|
||||
};
|
||||
|
||||
template <>
|
||||
float moe_smoothquant<ck_tile::bf16_t, ck_tile::fp8_t>(moe_smoothquant_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::fp8_t>(a, s);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user