diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index e063873d47..c338ba39c4 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -3,7 +3,6 @@ #pragma once -#include "ck_tile/core.hpp" #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" @@ -583,53 +582,51 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps, typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - constexpr index_t WarpGemmM = - Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}); - constexpr index_t WarpGemmK = - Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}); - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + auto warp_gemm = [&]() { + if constexpr((std::is_same_v || + std::is_same_v)&&std:: + is_same_v) + { + constexpr index_t WarpGemmM = + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}); - if constexpr(std::is_same_v && - std::is_same_v) - { - if constexpr(WarpGemmM == 32) - return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; - else if constexpr(WarpGemmM == 16) - { - if constexpr(WarpGemmK == 32) - return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>{}; - else - return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; - } - else // WarpGemmM == 4 - return WarpGemmMfmaF16F16F32M4N64K16{}; - } - else if constexpr(std::is_same_v && - std::is_same_v) - { - if constexpr(WarpGemmM == 32) - return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; - else if constexpr(WarpGemmM == 16) - { - if constexpr(WarpGemmK == 32) - return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>{}; - else - return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; - } - else // WarpGemmM == 4 - return WarpGemmMfmaBf16Bf16F32M4N64K16{}; - } - else if constexpr(std::is_same_v && - std::is_same_v) - { - static_assert(WarpGemmM == 32); +#ifdef __gfx950__ + static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!"); - // TODO: hard coded here. Otherwise, it may incorrect result - constexpr index_t swizzle_factor = 4; - return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< - swizzle_factor>{}; - } // TODO - bf8_t + return WarpGemmDispatcher< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; +#else + static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)), + "Not supported WarpGemm sizes!"); + + return WarpGemmDispatcher< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; +#endif + } + else + { + static_assert(false, "Not supported data types!"); + } }(); using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy< @@ -667,25 +664,46 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy typename Problem::HstuAttentionTileSetting::Gemm1WarpTile>>; auto warp_gemm = [&]() { - if constexpr(std::is_same_v && - std::is_same_v) + if constexpr((std::is_same_v || + std::is_same_v)&&std:: + is_same_v) { - return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{}; - // return - // WarpGemmImpl>>{}; + constexpr index_t WarpGemmM = + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}); + + static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)), + "Not supported WarpGemm sizes!"); + + if constexpr(WarpGemmK == 32) + return WarpGemmDispatcher< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>{}; + else + return WarpGemmDispatcher< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; } else { - return WarpGemmDispatcher< - typename Problem::QKVDataType, - typename Problem::QKVDataType, - typename Problem::GemmAccDataType, - Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}), - Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}), - Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}), - true>{}; + static_assert(false, "Not supported data types!"); } }();