mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 00:39:02 +00:00
[rocm-libraries] ROCm/rocm-libraries#7830 (commit 590fe58)
[CK_Tile][MI450] Add bf16 output wmma instruction (16x16x32) (#7830) Wire __builtin_amdgcn_wmma_bf16_16x16x32_bf16 into CK Tile for gfx1250, enabling bf16-input bf16-output WMMA at the warp GEMM level. - Add WmmaTraits specialization for <gfx125_t, bf16, bf16, bf16, 16,16,32> - Add WarpGemmAttributeWmmaImpl typedef and WarpGemmWmma alias - Add Dispatcher entry for bf16->bf16 16x16x32 - Add warp_gemm test with reference GEMM validation ## Motivation <!-- Explain the purpose of this PR and the goals it aims to achieve. --> ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
919096fde8
commit
99ab4c4ef7
@@ -146,6 +146,9 @@ using WarpGemmAttributeWmmaImpl_f32_16x16x32_f16_f16 =
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x32_bf16_bf16 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, bf16_t, bf16_t, float, 16, 16, 32>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_bf16_16x16x32_bf16_bf16 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, bf16_t, bf16_t, bf16_t, 16, 16, 32>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_i32_16x16x64_i8_i8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, int8_t, int8_t, int32_t, 16, 16, 64>>;
|
||||
|
||||
|
||||
@@ -142,4 +142,28 @@ struct WmmaTraits<gfx125_t, bf16_t, bf16_t, float, 16, 16, 32>
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// bf16 -> bf16 specialization - GFX125
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, bf16_t, bf16_t, bf16_t, 16, 16, 32>
|
||||
: WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, bf16_t, 32>
|
||||
{
|
||||
using ArchType = gfx125_t;
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx125__
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return __builtin_amdgcn_wmma_bf16_16x16x32_bf16(
|
||||
0, a_vec, 0, b_vec, 0, c_vec, P::reuse_a, P::reuse_b);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -149,6 +149,8 @@ template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, true> { using T
|
||||
#if defined(__gfx125__)
|
||||
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, TransposeC, false, false, AttrNumAccess, AttrNumAccess>
|
||||
: WmmaTag { using Type = WarpGemmWmma_f32_16x16x32_bf16_bf16<TransposeC, AttrNumAccess>;};
|
||||
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<bf16_t, bf16_t, bf16_t, 16, 16, 32, TransposeC, false, false, AttrNumAccess, AttrNumAccess>
|
||||
: WmmaTag { using Type = WarpGemmWmma_bf16_16x16x32_bf16_bf16<TransposeC, AttrNumAccess>;};
|
||||
#else
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
|
||||
|
||||
@@ -68,6 +68,13 @@ using WarpGemmWmma_f32_16x16x32_bf16_bf16 =
|
||||
AttrNumAccess,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <bool kTransC = false, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Default>
|
||||
using WarpGemmWmma_bf16_16x16x32_bf16_bf16 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_bf16_16x16x32_bf16_bf16,
|
||||
kTransC,
|
||||
AttrNumAccess,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <bool kTransC = false, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Default>
|
||||
using WarpGemmWmma_f32_16x16x16_f8_bf8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8,
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp)
|
||||
endif()
|
||||
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx125")
|
||||
add_gtest_executable(test_ck_tile_wmma_bf16_16x16x32_gfx1250 test_wmma_bf16_16x16x32_gfx1250.cpp)
|
||||
add_gtest_executable(test_ck_tile_wg_32x16x128_fp4 test_f32_32x16x128_fp4.cpp)
|
||||
endif()
|
||||
|
||||
150
test/ck_tile/warp_gemm/test_wmma_bf16_16x16x32_gfx1250.cpp
Normal file
150
test/ck_tile/warp_gemm/test_wmma_bf16_16x16x32_gfx1250.cpp
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template <typename A, typename B, typename Acc, index_t M, index_t N, index_t K, bool TransposeC>
|
||||
struct WGDispCase
|
||||
{
|
||||
using AType = A;
|
||||
using BType = B;
|
||||
using AccType = Acc;
|
||||
static constexpr index_t MPerWave = M;
|
||||
static constexpr index_t NPerWave = N;
|
||||
static constexpr index_t KPerWave = K;
|
||||
static constexpr bool kTransposeC = TransposeC;
|
||||
};
|
||||
|
||||
using WGDispatcherTypesList =
|
||||
::testing::Types<WGDispCase<bf16_t, bf16_t, bf16_t, 16, 16, 32, false>,
|
||||
WGDispCase<bf16_t, bf16_t, float, 16, 16, 32, false>>;
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
bool TransposeC>
|
||||
struct WarpGemmKernel
|
||||
{
|
||||
static constexpr int kBlockSize = 32;
|
||||
__device__ void operator()(void* A, void* B, void* C) const
|
||||
{
|
||||
using WarpGemm = WarpGemmDispatcher<AType, BType, CType, M, N, K, TransposeC>;
|
||||
|
||||
const auto a_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<AType*>(A),
|
||||
make_tuple(M, K),
|
||||
make_tuple(K, number<1>{}),
|
||||
number<K>{},
|
||||
number<1>{});
|
||||
|
||||
const auto b_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<BType*>(B),
|
||||
make_tuple(N, K),
|
||||
make_tuple(K, number<1>{}),
|
||||
number<K>{},
|
||||
number<1>{});
|
||||
|
||||
const auto c_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<CType*>(C),
|
||||
make_tuple(M, N),
|
||||
make_tuple(N, number<1>{}),
|
||||
number<N>{},
|
||||
number<1>{});
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths();
|
||||
constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths();
|
||||
constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths();
|
||||
|
||||
auto a_win = make_tile_window(
|
||||
a_view, a_len, make_multi_index(0, 0), AWarpTensor::get_tile_distribution());
|
||||
auto b_win = make_tile_window(
|
||||
b_view, b_len, make_multi_index(0, 0), BWarpTensor::get_tile_distribution());
|
||||
auto c_win = make_tile_window(
|
||||
c_view, c_len, make_multi_index(0, 0), CWarpTensor::get_tile_distribution());
|
||||
|
||||
AWarpTensor a_tile;
|
||||
BWarpTensor b_tile;
|
||||
load_tile(a_tile, a_win);
|
||||
load_tile(b_tile, b_win);
|
||||
|
||||
auto c_tile = WarpGemm{}(a_tile, b_tile);
|
||||
|
||||
store_tile(c_win, c_tile);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Case>
|
||||
static void RunWarpGemmCase(const HostTensor<typename Case::AType>& A,
|
||||
const HostTensor<typename Case::BType>& B,
|
||||
HostTensor<typename Case::AccType>& C)
|
||||
{
|
||||
DeviceMem Ad(A), Bd(B), Cd(C);
|
||||
|
||||
using Kernel = WarpGemmKernel<typename Case::AType,
|
||||
typename Case::BType,
|
||||
typename Case::AccType,
|
||||
Case::MPerWave,
|
||||
Case::NPerWave,
|
||||
Case::KPerWave,
|
||||
Case::kTransposeC>;
|
||||
dim3 grid(1), block{Kernel::kBlockSize};
|
||||
|
||||
(void)launch_kernel(stream_config{nullptr, true, 0, 0, 1},
|
||||
make_kernel(Kernel{},
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
Ad.GetDeviceBuffer(),
|
||||
Bd.GetDeviceBuffer(),
|
||||
Cd.GetDeviceBuffer()));
|
||||
|
||||
Cd.FromDevice(C.mData.data());
|
||||
}
|
||||
|
||||
template <typename Case>
|
||||
class WGRuntimeTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_ReferenceGemm)
|
||||
{
|
||||
using Case = TypeParam;
|
||||
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
|
||||
constexpr index_t M = Case::MPerWave;
|
||||
constexpr index_t N = Case::NPerWave;
|
||||
constexpr index_t K = Case::KPerWave;
|
||||
|
||||
HostTensor<AType> A({M, K});
|
||||
HostTensor<BType> B({N, K});
|
||||
HostTensor<CType> C({M, N});
|
||||
|
||||
FillUniformDistribution<AType>{-1.f, 1.f, 11939}(A);
|
||||
FillUniformDistribution<BType>{-1.f, 1.f, 11940}(B);
|
||||
C.SetZero();
|
||||
|
||||
RunWarpGemmCase<Case>(A, B, C);
|
||||
|
||||
HostTensor<CType> C_ref({M, N});
|
||||
C_ref.SetZero();
|
||||
reference_gemm<AType, BType, float, CType>(A, B.transpose(), C_ref);
|
||||
|
||||
EXPECT_TRUE(check_err(C, C_ref, "Warp gemm bf16 result error."));
|
||||
}
|
||||
Reference in New Issue
Block a user