Remove "basic" and universal GEMM tests, and incorporate their test cases into the GEMM pipeline tests (#3094)

* Add missing copyright statements

* Use ck_tile::host_tensor_descriptor instead of a custom lambda

* Refactor use of check_data_type in test classes

* Use TEST_SUITE_NAME with TYPED_TEST_SUITE

* Remove an unused namespace

* Make dim3 const

* Add BF8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF16 x BF16 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Skip failing tests of F16 x I4 for CompV3 with K == 2 * K_Tile

* Add missing precision type combinations to CompV4 from CompV3

* Move the INT8 tests around for consistency with KernelTypesCompV3Wmma

* Add missing precision type combinations to CompV3Wmma from CompV3

* Remove the basic and universal tests and their dependencies

* On __gfx950__, avoid using transposed loading of A with datatype pk_int4_t of B

* Use ADataType and BDataType instead of ComputeDataType for WarpGemm

* Explicitly set some return types to void

* Use more general typenames in InterleavedPKTypeLoader

* Add load_interleaved_pk_type.hpp to common.hpp

* Use std::is_same_v in load_int4_tile

* Add handling of LoadTranspose to load_int4_tile

* Factor out common code in several places using load_int4_tile

* Add support for pk_int4_t using load_int4_tile

* Fix formatting
This commit is contained in:
SamiAario-AMD
2025-11-13 21:01:27 +02:00
committed by GitHub
parent 7d57bc169f
commit f2cfc6b94e
38 changed files with 352 additions and 1888 deletions

View File

@@ -94,7 +94,11 @@ struct BlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -196,8 +200,8 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
@@ -222,22 +226,10 @@ struct BlockUniversalGemmAsBsCr
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -285,8 +277,8 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
@@ -300,30 +292,10 @@ struct BlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
b_warp_tile_ = load_tile_transpose(b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
}
// C += A * B
@@ -396,8 +368,8 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto BLdsTileDistr =
make_static_tile_distribution(MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
@@ -451,30 +423,10 @@ struct BlockUniversalGemmAsBsCr
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_lds_gemm_window);
}
else
{
load_tile(a_warp_tile_, a_lds_gemm_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
b_warp_tile_ = load_tile_transpose(b_lds_gemm_window);
}
else
{
load_tile(b_warp_tile_, b_lds_gemm_window);
}
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_lds_gemm_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_lds_gemm_window);
}
// C += A * B