From 16dd90a523e14aefbe72c8617af1394fbdb956e2 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Wed, 26 Nov 2025 10:57:15 +0800 Subject: [PATCH] [CK_TILE] Refine warp_gemm_attribute_mfma (#3272) [ROCm/composable_kernel commit: 8fa90025d0da22683dabe721d77a75a536388683] --- .../gemm/warp/warp_gemm_attribute_mfma.hpp | 251 +++--------------- 1 file changed, 36 insertions(+), 215 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index d1b14721f2..896bb31b42 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -47,8 +47,9 @@ struct WarpGemmAttributeMfma template static constexpr auto get_warp_dstr_encoding() { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); if constexpr(AttrNumAccessV == 1) - { return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -56,11 +57,7 @@ struct WarpGemmAttributeMfma tuple>, sequence<2>, sequence<1>>{}; - } else - { - static_assert(kKPerThread % AttrNumAccessV == 0, - "kKPerThread must be divisible by NumAccess"); return tile_distribution_encoding< sequence<>, tuple, @@ -69,7 +66,6 @@ struct WarpGemmAttributeMfma tuple>, sequence<2, 2>, sequence<0, 2>>{}; - } } using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); @@ -155,28 +151,25 @@ struct WarpGemmAttributeMfmaIterateK static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, "Multi-block on both M & N directions is not supported"); - CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + template + CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding() { - if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + if constexpr(kMNBlock == 1 && kNMBlock == 1) { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); if constexpr(AttrNumAccessV == 1) - { return tile_distribution_encoding< sequence<>, - tuple, - sequence>, + tuple, sequence>, tuple>, tuple>, sequence<2>, sequence<1>>{}; - } else - { - static_assert(kKPerThread % AttrNumAccessV == 0, - "kKPerThread must be divisible by NumAccess"); return tile_distribution_encoding< sequence<>, - tuple, + tuple, sequence>, @@ -184,30 +177,28 @@ struct WarpGemmAttributeMfmaIterateK tuple>, sequence<2, 2>, sequence<0, 2>>{}; - } } - else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + else if constexpr(kMNBlock == 1 && 1 < kNMBlock) { static_assert(AttrNumAccessV == 1, "Multiple access is not supported when using multi-block"); - // each M blocks share the same data + // each M/N blocks share the same data return tile_distribution_encoding< - sequence, - tuple, - sequence>, + sequence, + tuple, sequence>, tuple>, tuple>, sequence<2>, sequence<1>>{}; } - else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + else if constexpr(1 < kMNBlock && kNMBlock == 1) { static_assert(AttrNumAccessV == 1, "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< sequence<>, - tuple, + tuple, sequence>, tuple>, tuple>, @@ -216,68 +207,6 @@ struct WarpGemmAttributeMfmaIterateK } } - CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() - { - if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) - { - if constexpr(AttrNumAccessV == 1) - { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else - { - - static_assert(kKPerThread % AttrNumAccessV == 0, - "kKPerThread must be divisible by NumAccess"); - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 2>, - sequence<0, 2>>{}; - } - } - else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) - { - static_assert(AttrNumAccessV == 1, - "Multiple access is not supported when using multi-block"); - // single block to multi-block thread mapping - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) - { - static_assert(AttrNumAccessV == 1, - "Multiple access is not supported when using multi-block"); - // each N blocks share the same data - return tile_distribution_encoding< - sequence, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - } - CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) @@ -316,10 +245,10 @@ struct WarpGemmAttributeMfmaIterateK } } - using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); - - using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); - + using AWarpDstrEncoding = + decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = + decltype(get_warp_dstr_encoding()); using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec @@ -329,17 +258,7 @@ struct WarpGemmAttributeMfmaIterateK const BVecType& b_vec, bool_constant = {}) const { - using buf_a = thread_buffer; - using buf_b = thread_buffer; - - static_for<0, kKIter, 1>{}([&](auto iKIter) { - Impl{}(c_vec, - reinterpret_cast(a_vec) - .template get_as()[iKIter], - reinterpret_cast(b_vec) - .template get_as()[iKIter], - bool_constant{}); - }); + static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); }); } template @@ -354,14 +273,12 @@ struct WarpGemmAttributeMfmaIterateK static_assert(iKIter < kKIter); - // static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, reinterpret_cast(a_vec) .template get_as()[iKIter], reinterpret_cast(b_vec) .template get_as()[iKIter], bool_constant{}); - //}); } // c_vec = a_vec * b_vec @@ -377,13 +294,7 @@ struct WarpGemmAttributeMfmaIterateK reinterpret_cast(b_vec).template get_as()[I0]); // c += a * b - static_for<1, kKIter, 1>{}([&](auto iKIter) { - Impl{}(c_vec, - reinterpret_cast(a_vec) - .template get_as()[iKIter], - reinterpret_cast(b_vec) - .template get_as()[iKIter]); - }); + static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); }); return c_vec; } @@ -416,35 +327,10 @@ struct WarpGemmAttributeMfmaTransposedCDistribution static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - template - static constexpr auto get_warp_dstr_encoding() - { - if constexpr(AttrNumAccessV == 1) - { - return tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else - { - static_assert(kKPerThread % AttrNumAccessV == 0, - "kKPerThread must be divisible by NumAccess"); - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 2>, - sequence<0, 2>>{}; - } - } - using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using AWarpDstrEncoding = + typename WarpGemmAttributeMfma::BWarpDstrEncoding; + using BWarpDstrEncoding = + typename WarpGemmAttributeMfma::AWarpDstrEncoding; using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -597,18 +483,6 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, "Multi-block on both M & N directions is not supported"); - CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() - { - return WarpGemmAttributeMfmaIterateK:: - get_bwarp_dstr_encoding(); - } - - CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() - { - return WarpGemmAttributeMfmaIterateK:: - get_awarp_dstr_encoding(); - } - CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) @@ -647,30 +521,20 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution } } - using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); - - using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); - + using AWarpDstrEncoding = + typename WarpGemmAttributeMfmaIterateK::BWarpDstrEncoding; + using BWarpDstrEncoding = + typename WarpGemmAttributeMfmaIterateK::AWarpDstrEncoding; using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); - template // c_vec += a_vec * b_vec + template CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec, bool_constant = {}) const { - using buf_a = thread_buffer; - using buf_b = thread_buffer; - // swap A and B, value and type - static_for<0, kKIter, 1>{}([&](auto iKIter) { - Impl{}(c_vec, - reinterpret_cast(b_vec) - .template get_as()[iKIter], - reinterpret_cast(a_vec) - .template get_as()[iKIter], - bool_constant{}); - }); + static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); }); } template @@ -686,14 +550,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution static_assert(iKIter < kKIter); // swap A and B, value and type - // static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, reinterpret_cast(b_vec) .template get_as()[iKIter], reinterpret_cast(a_vec) .template get_as()[iKIter], bool_constant{}); - //}); } // c_vec = a_vec * b_vec @@ -708,13 +570,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution reinterpret_cast(b_vec).template get_as()[I0], reinterpret_cast(a_vec).template get_as()[I0]); - static_for<1, kKIter, 1>{}([&](auto iKIter) { - Impl{}(c_vec, - reinterpret_cast(b_vec) - .template get_as()[iKIter], - reinterpret_cast(a_vec) - .template get_as()[iKIter]); - }); + static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); }); return c_vec; } @@ -805,17 +661,8 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB const BVecType& b_vec, bool_constant = {}) const { - using buf_a = thread_buffer; - using buf_b = thread_buffer; // swap A and B, value and type - static_for<0, kKIter, 1>{}([&](auto iKIter) { - Impl{}(c_vec, - reinterpret_cast(b_vec) - .template get_as()[iKIter], - reinterpret_cast(a_vec) - .template get_as()[iKIter], - bool_constant{}); - }); + static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); }); } template @@ -830,14 +677,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB static_assert(iKIter < kKIter); // swap A and B, value and type - // static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, reinterpret_cast(b_vec) .template get_as()[iKIter], reinterpret_cast(a_vec) .template get_as()[iKIter], bool_constant{}); - //}); } // c_vec = a_vec * b_vec @@ -852,13 +697,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB reinterpret_cast(b_vec).template get_as()[I0], reinterpret_cast(a_vec).template get_as()[I0]); - static_for<1, kKIter, 1>{}([&](auto iKIter) { - Impl{}(c_vec, - reinterpret_cast(b_vec) - .template get_as()[iKIter], - reinterpret_cast(a_vec) - .template get_as()[iKIter]); - }); + static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); }); return c_vec; } @@ -926,17 +765,7 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA const BVecType& b_vec, bool_constant = {}) const { - using buf_a = thread_buffer; - using buf_b = thread_buffer; - - static_for<0, kKIter, 1>{}([&](auto iKIter) { - Impl{}(c_vec, - reinterpret_cast(a_vec) - .template get_as()[iKIter], - reinterpret_cast(b_vec) - .template get_as()[iKIter], - bool_constant{}); - }); + static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); }); } template @@ -951,14 +780,12 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA static_assert(iKIter < kKIter); - // static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, reinterpret_cast(a_vec) .template get_as()[iKIter], reinterpret_cast(b_vec) .template get_as()[iKIter], bool_constant{}); - //}); } // c_vec = a_vec * b_vec @@ -972,13 +799,7 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA reinterpret_cast(a_vec).template get_as()[I0], reinterpret_cast(b_vec).template get_as()[I0]); - static_for<1, kKIter, 1>{}([&](auto iKIter) { - Impl{}(c_vec, - reinterpret_cast(a_vec) - .template get_as()[iKIter], - reinterpret_cast(b_vec) - .template get_as()[iKIter]); - }); + static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); }); return c_vec; }