From f2fcc4a4613f2f7525dbf90f81631da81abab7df Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 28 Nov 2025 09:20:19 +0000 Subject: [PATCH] Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding --- .../block/block_universal_gemm_as_bs_cr.hpp | 16 ++- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 9 +- .../gemm/warp/warp_gemm_attribute_mfma.hpp | 106 +++++++++++++----- .../gemm/warp/warp_gemm_attribute_wmma.hpp | 22 +++- .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 4 +- 5 files changed, 121 insertions(+), 36 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 6fb5cf433b..235885a8dd 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -139,6 +139,7 @@ struct BlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + template CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { constexpr index_t KPerThread = Traits::KPerThread; @@ -158,12 +159,18 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; + using Attr = typename WarpGemm::WarpGemmAttribute; + constexpr auto NumAccessA = + convert ? Attr::AttrNumAccessV * sizeof(ADataType) / sizeof(ComputeDataType) + : Attr::AttrNumAccessV; constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + a_block_outer_dstr_encoding, + WarpGemm::WarpGemmAttribute::template get_awarp_dstr_encoding()); return a_block_dstr_encode; } + template CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { constexpr index_t KPerThread = Traits::KPerThread; @@ -183,8 +190,13 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; + using Attr = typename WarpGemm::WarpGemmAttribute; + constexpr auto NumAccessB = + convert ? Attr::AttrNumAccessV * sizeof(BDataType) / sizeof(ComputeDataType) + : Attr::AttrNumAccessV; constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + b_block_outer_dstr_encoding, + WarpGemm::WarpGemmAttribute::template get_bwarp_dstr_encoding()); return b_block_dstr_encode; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 8fae704203..71e69621f5 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -440,10 +440,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + constexpr bool is_load_tr = is_a_load_tr_v || is_b_load_tr_v; + constexpr auto a_lds_load_tile_distr = make_static_tile_distribution( + BlockGemm::template MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = make_static_tile_distribution( + BlockGemm::template MakeBBlockDistributionEncode()); // A DRAM tile window for load // A LDS tile window for store diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 3c7944a427..553f94fff3 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -44,12 +44,12 @@ struct WarpGemmAttributeMfma static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - template + template static constexpr auto get_warp_dstr_encoding() { - static_assert(kKPerThread % AttrNumAccessV == 0, + static_assert(NumAccess != 0 && kKPerThread % NumAccess == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(NumAccess == 1) return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -61,14 +61,30 @@ struct WarpGemmAttributeMfma return tile_distribution_encoding< sequence<>, tuple, - sequence>, + sequence>, tuple>, tuple>, sequence<2, 2>, sequence<0, 2>>{}; } - using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + + template + static constexpr auto get_awarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + static constexpr auto get_bwarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -151,14 +167,14 @@ struct WarpGemmAttributeMfmaIterateK static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, "Multi-block on both M & N directions is not supported"); - template + template CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding() { if constexpr(kMNBlock == 1 && kNMBlock == 1) { static_assert(kKPerThread % AttrNumAccessV == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(NumAccess == 1) return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -170,9 +186,9 @@ struct WarpGemmAttributeMfmaIterateK return tile_distribution_encoding< sequence<>, tuple, - sequence>, + Impl::kABKPerLane * kKIter / NumAccess>>, tuple>, tuple>, sequence<2, 2>, @@ -180,7 +196,7 @@ struct WarpGemmAttributeMfmaIterateK } else if constexpr(kMNBlock == 1 && 1 < kNMBlock) { - static_assert(AttrNumAccessV == 1, + static_assert(NumAccess == 1, "Multiple access is not supported when using multi-block"); // each M/N blocks share the same data return tile_distribution_encoding< @@ -193,7 +209,7 @@ struct WarpGemmAttributeMfmaIterateK } else if constexpr(1 < kMNBlock && kNMBlock == 1) { - static_assert(AttrNumAccessV == 1, + static_assert(NumAccess == 1, "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< @@ -207,6 +223,18 @@ struct WarpGemmAttributeMfmaIterateK } } + template + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) @@ -245,10 +273,12 @@ struct WarpGemmAttributeMfmaIterateK } } - using AWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); + using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec @@ -327,10 +357,23 @@ struct WarpGemmAttributeMfmaTransposedCDistribution static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - using AWarpDstrEncoding = - typename WarpGemmAttributeMfma::BWarpDstrEncoding; - using BWarpDstrEncoding = - typename WarpGemmAttributeMfma::AWarpDstrEncoding; + template + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + return WarpGemmAttributeMfma::template get_bwarp_dstr_encoding(); + } + + template + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + return WarpGemmAttributeMfma::template get_awarp_dstr_encoding(); + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -384,6 +427,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + template using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -392,6 +436,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB sequence<2>, sequence<1>>; #if 0 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>; #else // TODO: more test not only 32x32 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution { - using Impl = remove_cvref_t; - static constexpr auto AttrNumAccess = AttrNumAccess_; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); // swap A and B using ADataType = typename Impl::BDataType; @@ -521,10 +568,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution } } - using AWarpDstrEncoding = - typename WarpGemmAttributeMfmaIterateK::BWarpDstrEncoding; - using BWarpDstrEncoding = - typename WarpGemmAttributeMfmaIterateK::AWarpDstrEncoding; + template + using AWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK:: + template BWarpDstrEncoding; + template + using BWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK:: + template AWarpDstrEncoding; using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec @@ -603,6 +652,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + template using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -611,6 +661,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB sequence<2>, sequence<1>>; #if 0 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>; #else // TODO: more test not only 32x32 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence<1>>; + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index ef31d06c9c..4a275848b2 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -67,6 +67,10 @@ template struct WarpGemmAttributeWmma { using Impl = remove_cvref_t; + // AttrNumAccessV is required for compatibility with the block GEMM, and is currently ignored + // within WarpGemmAttributeWmma + static constexpr auto AttrNumAccess = WGAttrNumAccessEnum::Single; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); // When kTransC is true and A/B types differ, we need an impl with swapped types using TransposedImpl = @@ -99,8 +103,22 @@ struct WarpGemmAttributeWmma // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2 // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4 - using AWarpDstrEncoding = typename AWarpDstrEncodingTrait::type; - using BWarpDstrEncoding = typename BWarpDstrEncodingTrait::type; + template + static constexpr auto get_awarp_dstr_encoding() + { + return typename AWarpDstrEncodingTrait::type{}; + } + + template + static constexpr auto get_bwarp_dstr_encoding() + { + return typename BWarpDstrEncodingTrait::type{}; + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 using CWarpDstrEncoding = diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index ca7c32b6af..5ff0660f49 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -25,8 +25,8 @@ struct WarpGemmImpl using BDataType = typename WarpGemmAttribute::BDataType; using CDataType = typename WarpGemmAttribute::CDataType; - using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding; - using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding; + using AWarpDstrEncoding = typename WarpGemmAttribute::template AWarpDstrEncoding<>; + using BWarpDstrEncoding = typename WarpGemmAttribute::template BWarpDstrEncoding<>; using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding; using AWarpDstr = remove_cvref_t;