From 1e65b3ab35a640fda8ee38b9ebd795383f93e08c Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 25 Dec 2024 23:57:28 +0800 Subject: [PATCH] Correct the dtype checking logics (#1775) [ROCm/composable_kernel commit: 4c2eff023a26821512a100171531dc8757ad0e8f] --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index df5b9cecc6..2f7edd5477 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -261,7 +261,7 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F static_assert({F_bn1} % 32 == 0); if (t.has_lse) {{ - if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{ + if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ return -1; }} else {{ using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; @@ -614,7 +614,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), }