From 99ab4c4ef7e23588b725a84e30023dfd801cc5ba Mon Sep 17 00:00:00 2001 From: Aviral Goel <191153937+AviralGoelAMD@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:54:16 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7830 (commit 590fe58) [CK_Tile][MI450] Add bf16 output wmma instruction (16x16x32) (#7830) Wire __builtin_amdgcn_wmma_bf16_16x16x32_bf16 into CK Tile for gfx1250, enabling bf16-input bf16-output WMMA at the warp GEMM level. - Add WmmaTraits specialization for - Add WarpGemmAttributeWmmaImpl typedef and WarpGemmWmma alias - Add Dispatcher entry for bf16->bf16 16x16x32 - Add warp_gemm test with reference GEMM validation ## Motivation ## Technical Details ## Test Plan ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../warp/warp_gemm_attribute_wmma_impl.hpp | 3 + ..._gemm_attribute_wmma_impl_16bit_traits.hpp | 24 +++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 2 + .../ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp | 7 + test/ck_tile/warp_gemm/CMakeLists.txt | 3 +- .../test_wmma_bf16_16x16x32_gfx1250.cpp | 150 ++++++++++++++++++ 6 files changed, 188 insertions(+), 1 deletion(-) create mode 100644 test/ck_tile/warp_gemm/test_wmma_bf16_16x16x32_gfx1250.cpp diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp index 8fd185cb42..6f38199828 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp @@ -146,6 +146,9 @@ using WarpGemmAttributeWmmaImpl_f32_16x16x32_f16_f16 = using WarpGemmAttributeWmmaImpl_f32_16x16x32_bf16_bf16 = WarpGemmAttributeWmmaImpl>; +using WarpGemmAttributeWmmaImpl_bf16_16x16x32_bf16_bf16 = + WarpGemmAttributeWmmaImpl>; + using WarpGemmAttributeWmmaImpl_i32_16x16x64_i8_i8 = WarpGemmAttributeWmmaImpl>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp index b5d7365dad..e770efeaeb 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp @@ -142,4 +142,28 @@ struct WmmaTraits #endif } }; + +// bf16 -> bf16 specialization - GFX125 +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + using ArchType = gfx125_t; + + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx125__ + using P = WarpGemmParamsParser; + return __builtin_amdgcn_wmma_bf16_16x16x32_bf16( + 0, a_vec, 0, b_vec, 0, c_vec, P::reuse_a, P::reuse_b); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 2e6fa605ba..b77f09fa9f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -149,6 +149,8 @@ template<> struct Dispatcher { using T #if defined(__gfx125__) template struct Dispatcher : WmmaTag { using Type = WarpGemmWmma_f32_16x16x32_bf16_bf16;}; +template struct Dispatcher + : WmmaTag { using Type = WarpGemmWmma_bf16_16x16x32_bf16_bf16;}; #else template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; diff --git a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp index 1c522d07c1..224937dc6d 100644 --- a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp @@ -68,6 +68,13 @@ using WarpGemmWmma_f32_16x16x32_bf16_bf16 = AttrNumAccess, AttrNumAccess>>; +template +using WarpGemmWmma_bf16_16x16x32_bf16_bf16 = + WarpGemmImpl>; + template using WarpGemmWmma_f32_16x16x16_f8_bf8 = WarpGemmImpl +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +using namespace ck_tile; + +template +struct WGDispCase +{ + using AType = A; + using BType = B; + using AccType = Acc; + static constexpr index_t MPerWave = M; + static constexpr index_t NPerWave = N; + static constexpr index_t KPerWave = K; + static constexpr bool kTransposeC = TransposeC; +}; + +using WGDispatcherTypesList = + ::testing::Types, + WGDispCase>; + +template +struct WarpGemmKernel +{ + static constexpr int kBlockSize = 32; + __device__ void operator()(void* A, void* B, void* C) const + { + using WarpGemm = WarpGemmDispatcher; + + const auto a_view = + make_naive_tensor_view(static_cast(A), + make_tuple(M, K), + make_tuple(K, number<1>{}), + number{}, + number<1>{}); + + const auto b_view = + make_naive_tensor_view(static_cast(B), + make_tuple(N, K), + make_tuple(K, number<1>{}), + number{}, + number<1>{}); + + const auto c_view = + make_naive_tensor_view(static_cast(C), + make_tuple(M, N), + make_tuple(N, number<1>{}), + number{}, + number<1>{}); + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths(); + + auto a_win = make_tile_window( + a_view, a_len, make_multi_index(0, 0), AWarpTensor::get_tile_distribution()); + auto b_win = make_tile_window( + b_view, b_len, make_multi_index(0, 0), BWarpTensor::get_tile_distribution()); + auto c_win = make_tile_window( + c_view, c_len, make_multi_index(0, 0), CWarpTensor::get_tile_distribution()); + + AWarpTensor a_tile; + BWarpTensor b_tile; + load_tile(a_tile, a_win); + load_tile(b_tile, b_win); + + auto c_tile = WarpGemm{}(a_tile, b_tile); + + store_tile(c_win, c_tile); + } +}; + +template +static void RunWarpGemmCase(const HostTensor& A, + const HostTensor& B, + HostTensor& C) +{ + DeviceMem Ad(A), Bd(B), Cd(C); + + using Kernel = WarpGemmKernel; + dim3 grid(1), block{Kernel::kBlockSize}; + + (void)launch_kernel(stream_config{nullptr, true, 0, 0, 1}, + make_kernel(Kernel{}, + grid, + block, + 0, + Ad.GetDeviceBuffer(), + Bd.GetDeviceBuffer(), + Cd.GetDeviceBuffer())); + + Cd.FromDevice(C.mData.data()); +} + +template +class WGRuntimeTest : public ::testing::Test +{ +}; + +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_ReferenceGemm) +{ + using Case = TypeParam; + + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + + constexpr index_t M = Case::MPerWave; + constexpr index_t N = Case::NPerWave; + constexpr index_t K = Case::KPerWave; + + HostTensor A({M, K}); + HostTensor B({N, K}); + HostTensor C({M, N}); + + FillUniformDistribution{-1.f, 1.f, 11939}(A); + FillUniformDistribution{-1.f, 1.f, 11940}(B); + C.SetZero(); + + RunWarpGemmCase(A, B, C); + + HostTensor C_ref({M, N}); + C_ref.SetZero(); + reference_gemm(A, B.transpose(), C_ref); + + EXPECT_TRUE(check_err(C, C_ref, "Warp gemm bf16 result error.")); +}