[rocm-libraries] ROCm/rocm-libraries#7830 (commit 590fe58)

[CK_Tile][MI450] Add bf16 output wmma instruction (16x16x32)
 (#7830)

Wire __builtin_amdgcn_wmma_bf16_16x16x32_bf16 into CK Tile for gfx1250,
enabling bf16-input bf16-output WMMA at the warp GEMM level.

- Add WmmaTraits specialization for <gfx125_t, bf16, bf16, bf16,
16,16,32>
- Add WarpGemmAttributeWmmaImpl typedef and WarpGemmWmma alias
- Add Dispatcher entry for bf16->bf16 16x16x32
- Add warp_gemm test with reference GEMM validation

## Motivation

<!-- Explain the purpose of this PR and the goals it aims to achieve.
-->

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Aviral Goel
2026-06-02 13:54:16 +00:00
committed by assistant-librarian[bot]
parent 919096fde8
commit 99ab4c4ef7
6 changed files with 188 additions and 1 deletions

View File

@@ -146,6 +146,9 @@ using WarpGemmAttributeWmmaImpl_f32_16x16x32_f16_f16 =
// gfx125 WMMA attribute impl built on WmmaTraits<gfx125_t, bf16_t, bf16_t, float, 16, 16, 32>.
// Presumably the type arguments are the A, B and C/accumulator element types followed by the
// M, N, K tile extents -- confirm against the primary WmmaTraits template.
using WarpGemmAttributeWmmaImpl_f32_16x16x32_bf16_bf16 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, bf16_t, bf16_t, float, 16, 16, 32>>;
// Same shape as the alias above, but the third element type is bf16_t instead of float
// (the "bf16 output" variant named in the alias).
using WarpGemmAttributeWmmaImpl_bf16_16x16x32_bf16_bf16 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, bf16_t, bf16_t, bf16_t, 16, 16, 32>>;
// gfx125 WMMA attribute impl built on WmmaTraits<gfx125_t, int8_t, int8_t, int32_t, 16, 16, 64>.
using WarpGemmAttributeWmmaImpl_i32_16x16x64_i8_i8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, int8_t, int8_t, int32_t, 16, 16, 64>>;

View File

@@ -142,4 +142,28 @@ struct WmmaTraits<gfx125_t, bf16_t, bf16_t, float, 16, 16, 32>
#endif
}
};
// GFX125 WMMA traits for bf16 operands producing a bf16 result vector (16x16x32).
// NOTE(review): the base class is instantiated with gfx12_t while ArchType is overridden to
// gfx125_t below -- presumably this reuses the gfx12 base definitions on purpose; confirm.
template <>
struct WmmaTraits<gfx125_t, bf16_t, bf16_t, bf16_t, 16, 16, 32>
    : WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, bf16_t, 32>
{
    using ArchType = gfx125_t;

    // Wraps the bf16 16x16x32 WMMA builtin.
    //   Params : optional warp-GEMM parameters, decoded by WarpGemmParamsParser to obtain the
    //            reuse_a / reuse_b flags forwarded to the builtin.
    //   a_vec, b_vec, c_vec : operand vectors handed to the builtin unchanged.
    // Returns the builtin's result when compiling for gfx125, otherwise CVecType{0}.
    template <typename... Params>
    CK_TILE_DEVICE static CVecType
    wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
    {
#if !defined(__gfx125__)
        // The builtin is unavailable in this compilation pass: consume the arguments so they
        // do not trigger unused-parameter diagnostics and hand back a placeholder value.
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        ck_tile::ignore = c_vec;
        return CVecType{0};
#else
        using ParsedParams = WarpGemmParamsParser<Params...>;
        // The three literal 0 arguments are the builtin's modifier operands (presumably
        // neg-A, neg-B and the C modifier -- confirm against the compiler's builtin list).
        return __builtin_amdgcn_wmma_bf16_16x16x32_bf16(
            0, a_vec, 0, b_vec, 0, c_vec, ParsedParams::reuse_a, ParsedParams::reuse_b);
#endif
    }
};
} // namespace ck_tile

View File

@@ -149,6 +149,8 @@ template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, true> { using T
#if defined(__gfx125__)
// gfx125-only entries (this region is guarded by `#if defined(__gfx125__)`): bf16 x bf16
// 16x16x32 warp GEMMs are routed to the WMMA implementations and tagged with WmmaTag.
// Both entries pin the two bool arguments after TransposeC to false and use the same
// AttrNumAccess for the last two arguments.
// float result type -> WarpGemmWmma_f32_16x16x32_bf16_bf16
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, TransposeC, false, false, AttrNumAccess, AttrNumAccess>
: WmmaTag { using Type = WarpGemmWmma_f32_16x16x32_bf16_bf16<TransposeC, AttrNumAccess>;};
// bf16_t result type -> WarpGemmWmma_bf16_16x16x32_bf16_bf16
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<bf16_t, bf16_t, bf16_t, 16, 16, 32, TransposeC, false, false, AttrNumAccess, AttrNumAccess>
: WmmaTag { using Type = WarpGemmWmma_bf16_16x16x32_bf16_bf16<TransposeC, AttrNumAccess>;};
#else
// Fallback when __gfx125__ is not defined: bf16 x bf16 -> float 16x16x32 maps to the MFMA
// warp GEMMs, with the last argument (false / true) choosing the regular or the
// TransposedCDistribution variant.
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };

View File

@@ -68,6 +68,13 @@ using WarpGemmWmma_f32_16x16x32_bf16_bf16 =
AttrNumAccess,
AttrNumAccess>>;
// Warp GEMM built from the bf16-result 16x16x32 WMMA attribute implementation.
//   kTransC       : forwarded to WarpGemmAttributeWmma (defaults to false).
//   AttrNumAccess : forwarded twice, so both access arguments of WarpGemmAttributeWmma
//                   receive the same value (defaults to WGAttrNumAccessEnum::Default).
template <bool kTransC = false, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Default>
using WarpGemmWmma_bf16_16x16x32_bf16_bf16 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_bf16_16x16x32_bf16_bf16,
kTransC,
AttrNumAccess,
AttrNumAccess>>;
template <bool kTransC = false, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Default>
using WarpGemmWmma_f32_16x16x16_f8_bf8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8,

View File

@@ -4,7 +4,8 @@
# The fp4 16x16x128 warp-GEMM test is only built when GPU_TARGETS mentions a gfx95 target.
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp)
endif()
# Tests built only when GPU_TARGETS mentions a gfx125 target: the bf16 16x16x32 WMMA
# warp-GEMM test and the fp4 32x16x128 warp-GEMM test.
if(GPU_TARGETS MATCHES "gfx125")
add_gtest_executable(test_ck_tile_wmma_bf16_16x16x32_gfx1250 test_wmma_bf16_16x16x32_gfx1250.cpp)
add_gtest_executable(test_ck_tile_wg_32x16x128_fp4 test_f32_32x16x128_fp4.cpp)
endif()

View File

@@ -0,0 +1,150 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
using namespace ck_tile;
template <typename A, typename B, typename Acc, index_t M, index_t N, index_t K, bool TransposeC>
struct WGDispCase
{
using AType = A;
using BType = B;
using AccType = Acc;
static constexpr index_t MPerWave = M;
static constexpr index_t NPerWave = N;
static constexpr index_t KPerWave = K;
static constexpr bool kTransposeC = TransposeC;
};
// Cases instantiated by the typed test suite, both bf16 x bf16 at 16x16x32 without
// transposed C: first with a bf16_t result type, then with a float result type.
using WGDispatcherTypesList =
::testing::Types<WGDispCase<bf16_t, bf16_t, bf16_t, 16, 16, 32, false>,
WGDispCase<bf16_t, bf16_t, float, 16, 16, 32, false>>;
// Device functor that runs one dispatcher-selected warp GEMM on a single M x N x K tile.
//   A : M x K elements of AType, row stride K
//   B : N x K elements of BType, row stride K
//   C : M x N elements of CType, row stride N (written by the kernel)
// kBlockSize is the thread-block size the host launcher uses for this functor.
template <typename AType,
          typename BType,
          typename CType,
          index_t M,
          index_t N,
          index_t K,
          bool TransposeC>
struct WarpGemmKernel
{
    static constexpr int kBlockSize = 32;

    __device__ void operator()(void* A, void* B, void* C) const
    {
        using WarpGemm    = WarpGemmDispatcher<AType, BType, CType, M, N, K, TransposeC>;
        using AWarpTensor = typename WarpGemm::AWarpTensor;
        using BWarpTensor = typename WarpGemm::BWarpTensor;
        using CWarpTensor = typename WarpGemm::CWarpTensor;

        // Per-operand tile distributions as defined by the selected warp GEMM.
        constexpr auto a_dstr = AWarpTensor::get_tile_distribution();
        constexpr auto b_dstr = BWarpTensor::get_tile_distribution();
        constexpr auto c_dstr = CWarpTensor::get_tile_distribution();

        // ---- A operand: global view -> window at the origin -> distributed tile ----
        // The trailing number<> arguments are presumably the guaranteed last-dimension
        // vector length and stride -- confirm against make_naive_tensor_view.
        AType* a_ptr = static_cast<AType*>(A);
        const auto a_global =
            make_naive_tensor_view<address_space_enum::global>(a_ptr,
                                                               make_tuple(M, K),
                                                               make_tuple(K, number<1>{}),
                                                               number<K>{},
                                                               number<1>{});
        auto a_window =
            make_tile_window(a_global, a_dstr.get_lengths(), make_multi_index(0, 0), a_dstr);
        AWarpTensor a_frag;
        load_tile(a_frag, a_window);

        // ---- B operand: same recipe, N x K layout ----
        BType* b_ptr = static_cast<BType*>(B);
        const auto b_global =
            make_naive_tensor_view<address_space_enum::global>(b_ptr,
                                                               make_tuple(N, K),
                                                               make_tuple(K, number<1>{}),
                                                               number<K>{},
                                                               number<1>{});
        auto b_window =
            make_tile_window(b_global, b_dstr.get_lengths(), make_multi_index(0, 0), b_dstr);
        BWarpTensor b_frag;
        load_tile(b_frag, b_window);

        // ---- C result: M x N view and window used only for the final store ----
        CType* c_ptr = static_cast<CType*>(C);
        const auto c_global =
            make_naive_tensor_view<address_space_enum::global>(c_ptr,
                                                               make_tuple(M, N),
                                                               make_tuple(N, number<1>{}),
                                                               number<N>{},
                                                               number<1>{});
        auto c_window =
            make_tile_window(c_global, c_dstr.get_lengths(), make_multi_index(0, 0), c_dstr);

        // Run the warp GEMM on the loaded tiles and write the result tile out.
        auto c_frag = WarpGemm{}(a_frag, b_frag);
        store_tile(c_window, c_frag);
    }
};
template <typename Case>
static void RunWarpGemmCase(const HostTensor<typename Case::AType>& A,
const HostTensor<typename Case::BType>& B,
HostTensor<typename Case::AccType>& C)
{
DeviceMem Ad(A), Bd(B), Cd(C);
using Kernel = WarpGemmKernel<typename Case::AType,
typename Case::BType,
typename Case::AccType,
Case::MPerWave,
Case::NPerWave,
Case::KPerWave,
Case::kTransposeC>;
dim3 grid(1), block{Kernel::kBlockSize};
(void)launch_kernel(stream_config{nullptr, true, 0, 0, 1},
make_kernel(Kernel{},
grid,
block,
0,
Ad.GetDeviceBuffer(),
Bd.GetDeviceBuffer(),
Cd.GetDeviceBuffer()));
Cd.FromDevice(C.mData.data());
}
// Empty typed-test fixture; it only carries the Case type parameter into each test body.
template <typename Case>
class WGRuntimeTest : public ::testing::Test
{
};
// Instantiate the fixture once for every case listed in WGDispatcherTypesList.
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
// Runs the dispatcher-selected warp GEMM on the device for the current case and compares
// the result against the host reference GEMM.
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_ReferenceGemm)
{
    using AType = typename TypeParam::AType;
    using BType = typename TypeParam::BType;
    using CType = typename TypeParam::AccType;

    constexpr index_t M = TypeParam::MPerWave;
    constexpr index_t N = TypeParam::NPerWave;
    constexpr index_t K = TypeParam::KPerWave;

    // Inputs: A is M x K, B is N x K, both filled from [-1, 1] with fixed seeds so every
    // run sees the same data.
    HostTensor<AType> A({M, K});
    FillUniformDistribution<AType>{-1.f, 1.f, 11939}(A);

    HostTensor<BType> B({N, K});
    FillUniformDistribution<BType>{-1.f, 1.f, 11940}(B);

    // Device result and host reference both start from zero.
    HostTensor<CType> C({M, N});
    HostTensor<CType> C_ref({M, N});
    C.SetZero();
    C_ref.SetZero();

    RunWarpGemmCase<TypeParam>(A, B, C);

    // The reference is computed with float as its third type argument and is handed the
    // transposed view of the N x K tensor B.
    reference_gemm<AType, BType, float, CType>(A, B.transpose(), C_ref);

    EXPECT_TRUE(check_err(C, C_ref, "Warp gemm bf16 result error."));
}