diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index f05881cb11..d26686ec37 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -408,7 +408,6 @@ struct HostTensor return sizeof(T) * get_element_space_size(); } - // void SetZero() { ck_tile::ranges::fill(mData, 0); } void SetZero() { if constexpr(std::is_same_v) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 0b83ebfe65..3c7944a427 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -115,7 +115,7 @@ struct WarpGemmAttributeMfma const BVecType& b_vec, const int32_t& b_scale) const { - auto c_vec = Impl{}.template operator()(a_vec, a_scale, b_vec, b_scale); + return Impl{}.template operator()(a_vec, a_scale, b_vec, b_scale); } }; diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 7b482d91e2..6378bb8e43 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -30,6 +30,7 @@ add_subdirectory(add_rmsnorm2d_rdquant) # add_subdirectory(rmsnorm2d) add_subdirectory(gemm_block_scale) add_subdirectory(utility) +add_subdirectory(warp_gemm) add_subdirectory(reduce) add_subdirectory(core) add_subdirectory(epilogue) diff --git a/test/ck_tile/warp_gemm/CMakeLists.txt b/test/ck_tile/warp_gemm/CMakeLists.txt new file mode 100644 index 0000000000..664ebc003b --- /dev/null +++ b/test/ck_tile/warp_gemm/CMakeLists.txt @@ -0,0 +1,3 @@ +if(GPU_TARGETS MATCHES "gfx95") + add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp) +endif() diff --git a/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp b/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp new file mode 100644 index 0000000000..7878fda618 --- /dev/null +++ b/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +using namespace ck_tile; + +template +struct WGDispCase +{ + using AType = A; + using BType = B; + using AccType = Acc; + static constexpr index_t MPerWave = M; + static constexpr index_t NPerWave = N; + static constexpr index_t KPerWave = K; + static constexpr bool kTransposeC = TransposeC; + static constexpr bool kSwizzleA = SwizzleA; + static constexpr bool kUSS = UseStructuredSparsity; + static constexpr WGAttrNumAccessEnum kNA = NA; +}; + +using WGDispatcherTypesList = + ::testing::Types>; + +template +struct WarpGemmKernel +{ + static constexpr int kBlockSize = 64; + __device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const + { + using WarpGemm = ck_tile::WarpGemmDispatcher; + // A: [M,K] row-major (packed) + const auto a_view = ck_tile::make_naive_tensor_view( + static_cast(A), + ck_tile::make_tuple(M, K), + ck_tile::make_tuple(K, ck_tile::number<1>{}), + ck_tile::number{}, + ck_tile::number<1>{}); + // B: expose as logical [N,K] with strides (1, N) over the original row-major [K,N] buffer + const auto b_view = ck_tile::make_naive_tensor_view( + static_cast(B), + ck_tile::make_tuple(N, K), + ck_tile::make_tuple(K, ck_tile::number<1>{}), + ck_tile::number{}, + ck_tile::number<1>{}); + // C: [M,N] row-major (packed) + const auto c_view = ck_tile::make_naive_tensor_view( + static_cast(C), + ck_tile::make_tuple(M, N), + ck_tile::make_tuple(N, ck_tile::number<1>{}), + ck_tile::number{}, + ck_tile::number<1>{}); + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths(); + + auto a_win = ck_tile::make_tile_window( + a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution()); + auto b_win = ck_tile::make_tile_window( + b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution()); + auto c_win = ck_tile::make_tile_window( + c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution()); + + AWarpTensor a_tile; + BWarpTensor b_tile; + ck_tile::load_tile(a_tile, a_win); + ck_tile::load_tile(b_tile, b_win); + + auto scale_a = static_cast(static_cast(ScaleA)[0].get()); + auto scale_b = static_cast(static_cast(ScaleB)[0].get()); + + auto c_tile = WarpGemm{}.template operator()<0, 0>(a_tile, b_tile, scale_a, scale_b); + + ck_tile::store_tile(c_win, c_tile); + } +}; + +template +static void RunWarpGemmCase(const ck_tile::HostTensor& A, + const ck_tile::HostTensor& B, + const ck_tile::HostTensor& ScaleA, + const ck_tile::HostTensor& ScaleB, + ck_tile::HostTensor& C) +{ + ck_tile::DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB); + dim3 grid(1), block{64}; + + using Kernel = WarpGemmKernel; + + (void)ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel(Kernel{}, + grid, + block, + 0, + Ad.GetDeviceBuffer(), + Bd.GetDeviceBuffer(), + Cd.GetDeviceBuffer(), + SAd.GetDeviceBuffer(), + SBd.GetDeviceBuffer())); + + Cd.FromDevice(C.mData.data()); +} + +template +class WGRuntimeTest : public ::testing::Test +{ +}; + +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) +{ + using Case = TypeParam; + + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + using ck_tile::e8m0_t; + + constexpr index_t M = Case::MPerWave; + constexpr index_t N = Case::NPerWave; + constexpr index_t K = Case::KPerWave; + + auto ScaleA = e8m0_t{2.f}; + auto ScaleB = e8m0_t{4.f}; + + ck_tile::HostTensor A({M, K}); + ck_tile::HostTensor B({N, K}); + ck_tile::HostTensor C({M, N}); + ck_tile::HostTensor sA({M, 1}); + ck_tile::HostTensor sB({N, 1}); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(A); + ck_tile::FillUniformDistribution{-5.f, 5.f}(B); + C.SetZero(); + ck_tile::FillConstant{ScaleA}(sA); + ck_tile::FillConstant{ScaleB}(sB); + + RunWarpGemmCase(A, B, sA, sB, C); + + ck_tile::HostTensor C_ref({M, N}); + C_ref.SetZero(); + ck_tile::reference_mx_gemm( + A, B.transpose(), C_ref, sA, sB.transpose()); + + EXPECT_TRUE(ck_tile::check_err(C, C_ref, "Warp gemm result error.")); +}