test(grouped_gemm): add gtests for the example/grouped_gemm_preshuffle to ensure its integrity (#2811)

* test(grouped_gemm): add gtests for the example to maintain its integrity

* test(grouped_gemm_preshuffle): add prefill variant to testbed to cover wider range

* fix: removed residue code to make b_shuffle() work again

* test(grouped_gemm_preshuffle): limit the test suite to gfx942 arch as it fails on gfx90a

* build: add gfx950 as build target for gtests

* test(grouped_gemm_preshuffle): temporarily disable fp8 prec tests due to numerical errors

* fix(grouped_gemm_preshuffle): resolved fp8 tests failure on gfx950 by adding correct compiler flag
This commit is contained in:
Aviral Goel
2025-09-16 18:43:30 -04:00
committed by GitHub
parent dee185d80c
commit 48e08c6429
7 changed files with 563 additions and 0 deletions

View File

@@ -1,3 +1,10 @@
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp)
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})

View File

@@ -3,6 +3,7 @@ add_subdirectory(gemm)
add_subdirectory(gemm_weight_preshuffle)
add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(grouped_gemm_preshuffle)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_streamk)
add_subdirectory(data_type)

View File

@@ -0,0 +1,9 @@
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95")
add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp)
target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_preshuffle_util.hpp"
using F16 = ck_tile::half_t;
using F8 = ck_tile::fp8_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Custom tuple-like structure for kernel configuration
template <typename ALayout_,
typename BLayout_,
typename CLayout_,
typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
int M_Tile_val_,
int N_Tile_val_,
int K_Tile_val_,
int BlockPerCu_val_>
struct KernelConfig
{
using ALayoutType = ALayout_;
using BLayoutType = BLayout_;
using CLayoutType = CLayout_;
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CDataType = CDataType_;
static constexpr int M_Tile_ = M_Tile_val_;
static constexpr int N_Tile_ = N_Tile_val_;
static constexpr int K_Tile_ = K_Tile_val_;
static constexpr int BlockPerCu_ = BlockPerCu_val_;
};
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_Tile, N_Tile, K_Tile, BlockPerCu
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 128, 128, 128, 2>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmPreshuffle, KernelTypes);
#include "test_grouped_gemm_preshuffle_ut_cases.inc"
#include "test_grouped_gemm_preshuffle_prefill_cases.inc"

View File

@@ -0,0 +1,61 @@
#pragma once
// Test with prefill config struct
TYPED_TEST(TestCkTileGroupedGemmPreshuffle, PrefillVariant)
{
const int group_count = 4;
const int kbatch = 1;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Cs;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 128 * i);
Ns.push_back(256 + 128 * i);
Ks.push_back(128 * (i + 1));
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}
TYPED_TEST(TestCkTileGroupedGemmPreshuffle, VariedDimensions)
{
const int group_count = 6;
const int kbatch = 1;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Cs;
std::vector<std::tuple<int, int, int>> test_cases = {{64, 128, 256},
{128, 256, 512},
{256, 512, 1024},
{512, 256, 128},
{128, 128, 128},
{64, 512, 256}};
for(int i = 0; i < group_count; i++)
{
auto [M, N, K] = test_cases[i];
Ms.push_back(M);
Ns.push_back(N);
Ks.push_back(K);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}

View File

@@ -0,0 +1,53 @@
#pragma once
// kPadK is not needed for these k values
TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKFalse)
{
const int group_count = 4;
const int kbatch = 1;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Cs;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 256 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}
// kPadK is needed to be true for these k values
TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKTrue)
{
const int group_count = 4;
const int kbatch = 1;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Cs;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}

View File

@@ -0,0 +1,374 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename Tuple>
class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
{
protected:
using ALayout = typename Tuple::ALayoutType;
using BLayout = typename Tuple::BLayoutType;
using CLayout = typename Tuple::CLayoutType;
using ADataType = typename Tuple::ADataType;
using BDataType = typename Tuple::BDataType;
using AccDataType = typename Tuple::AccDataType;
using CDataType = typename Tuple::CDataType;
using PrecType = BDataType;
using DsLayout = ck_tile::tuple<>; // not used
using DsDataType = ck_tile::tuple<>; // not used
static const bool kPadM = false;
static const bool kPadN = false;
static const bool kPadK = true; // preshuffle pipeline requires k padding
static const int kBlockPerCu = Tuple::BlockPerCu_;
// Tile dimensions from tuple
static const ck_tile::index_t M_Tile = Tuple::M_Tile_;
static const ck_tile::index_t N_Tile = Tuple::N_Tile_;
static const ck_tile::index_t K_Tile = Tuple::K_Tile_;
static const ck_tile::index_t M_Warp = 1;
static const ck_tile::index_t N_Warp = 4;
static const ck_tile::index_t K_Warp = 1;
static const ck_tile::index_t M_Warp_Tile = 16;
static const ck_tile::index_t N_Warp_Tile = 16;
static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<BDataType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem
static constexpr bool TransposeC = false; // transpose c is not supported
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
}
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view(
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
// for testing purposes, we can hardcode the values here as we what is compatible with
// pipeline
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC,
/*UseStructuredSparsity*/ false,
/*Persistent*/ false,
/*NumWaveGroups*/ 1,
/*Preshuffle*/ true>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Default,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline =
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
std::vector<int>& stride_As,
std::vector<int>& stride_Bs,
std::vector<int>& stride_Cs,
const int kbatch = 1,
const int group_count = 16)
{
using namespace ck_tile::literals;
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
a_m_k_tensors.reserve(group_count);
b_k_n_tensors.reserve(group_count);
c_m_n_tensors.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
c_m_n_dev_buf.reserve(group_count);
std::vector<grouped_gemm_kargs> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = f_get_default_stride(M, K, stride_As[i], ALayout{});
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{});
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
f_host_tensor_descriptor(M, K, stride_As[i], ALayout{})));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{})));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
// Host-side preshuffle of B
auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_shuffle_host.get_element_space_size_in_bytes()));
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
c_m_n_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
c_m_n_dev_buf[i]->SetZero();
c_m_n_tensors[i].SetZero();
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back(
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
invoke_grouped_gemm<ALayout, BLayout, CLayout>(gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
// Copy results back to host for validation
for(int i = 0; i < group_count; i++)
{
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
}
bool pass{true};
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
Ks[i], kbatch, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
}
EXPECT_TRUE(pass);
}
};