Remove "basic" and universal GEMM tests, and incorporate their test cases into the GEMM pipeline tests (#3094)

* Add missing copyright statements

* Use ck_tile::host_tensor_descriptor instead of a custom lambda

* Refactor use of check_data_type in test classes

* Use TEST_SUITE_NAME with TYPED_TEST_SUITE

* Remove an unused namespace

* Make dim3 const

* Add BF8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF16 x BF16 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Skip failing tests of F16 x I4 for CompV3 with K == 2 * K_Tile

* Add missing precision type combinations to CompV4 from CompV3

* Move the INT8 tests around for consistency with KernelTypesCompV3Wmma

* Add missing precision type combinations to CompV3Wmma from CompV3

* Remove the basic and universal tests and their dependencies

* On __gfx950__, avoid using transposed loading of A when the datatype of B is pk_int4_t

* Use ADataType and BDataType instead of ComputeDataType for WarpGemm

* Explicitly set some return types to void

* Use more general typenames in InterleavedPKTypeLoader

* Add load_interleaved_pk_type.hpp to common.hpp

* Use std::is_same_v in load_int4_tile

* Add handling of LoadTranspose to load_int4_tile

* Factor out common code in several places using load_int4_tile

* Add support for pk_int4_t using load_int4_tile

* Fix formatting
This commit is contained in:
SamiAario-AMD
2025-11-13 21:01:27 +02:00
committed by GitHub
parent 7d57bc169f
commit f2cfc6b94e
38 changed files with 352 additions and 1888 deletions

View File

@@ -11,6 +11,14 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
@@ -93,7 +101,7 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompAsync, Problem>
template <typename Tuple, typename Derived>
class TestCkTileGemmPipeline : public ::testing::Test
{
protected:
public:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
@@ -118,6 +126,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
static constexpr bool Persistent =
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
protected:
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
@@ -228,7 +237,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -266,51 +275,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
ck_tile::index_t M_Warp_Tile,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile>
bool check_data_type()
{
return static_cast<Derived*>(this)
->template check_data_type_impl<ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>();
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
ck_tile::index_t M_Warp_Tile,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile>
bool check_data_type_impl()
{
return true;
}
public:
std::vector<int> k_batches_;
void SetUp() override
{
if(!check_data_type<ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>())
if constexpr(!Derived::check_data_type())
{
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
}
if constexpr(PipelineType == GemmPipelineType::CompV4)
if constexpr(PipelineType == GemmPipelineType::CompV4 ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Only do k_batch = 1 when pipeline is CompV4
// Only do k_batch = 1 when pipeline is CompV4, or BDataType is I4
k_batches_ = {1};
}
else
@@ -328,9 +305,13 @@ class TestCkTileGemmPipeline : public ::testing::Test
const int StrideB = 0,
const int StrideC = 0)
{
for(auto kb : k_batches_)
// Some unsupported tests don't compile, so we check here before attempting to.
if constexpr(Derived::check_data_type())
{
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, kb);
for(auto kb : k_batches_)
{
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, kb);
}
}
}
@@ -343,49 +324,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
const int StrideC,
int kbatch = 1)
{
using namespace ck_tile::literals;
ck_tile::index_t stride_A =
ck_tile::get_default_stride(M, K, StrideA, is_row_major(ALayout{}));
ck_tile::index_t stride_B =
ck_tile::get_default_stride(K, N, StrideB, is_row_major(BLayout{}));
ck_tile::index_t stride_C =
ck_tile::get_default_stride(M, N, StrideC, is_row_major(CLayout{}));
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
@@ -394,8 +345,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
@@ -416,7 +378,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
bool pass = true;
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(