Remove "basic" and universal GEMM tests, and incorporate their test cases into the GEMM pipeline tests (#3094)

* Add missing copyright statements

* Use ck_tile::host_tensor_descriptor instead of a custom lambda

* Refactor use of check_data_type in test classes

* Use TEST_SUITE_NAME with TYPED_TEST_SUITE

* Remove an unused namespace

* Make dim3 const

* Add BF8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF16 x BF16 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Skip failing tests of F16 x I4 for CompV3 with K == 2 * K_Tile

* Add missing precision type combinations to CompV4 from CompV3

* Move the INT8 tests around for consistency with KernelTypesCompV3Wmma

* Add missing precision type combinations to CompV3Wmma from CompV3

* Remove the basic and universal tests and their dependencies

* On __gfx950__, avoid using transposed loading of A with datatype pk_int4_t of B

* Use ADataType and BDataType instead of ComputeDataType for WarpGemm

* Explicitly set some return types to void

* Use more general typenames in InterleavedPKTypeLoader

* Add load_interleaved_pk_type.hpp to common.hpp

* Use std::is_same_v in load_int4_tile

* Add handling of LoadTranspose to load_int4_tile

* Factor out common code in several places using load_int4_tile

* Add support for pk_int4_t using load_int4_tile

* Fix formatting
This commit is contained in:
SamiAario-AMD
2025-11-13 21:01:27 +02:00
committed by GitHub
parent 7d57bc169f
commit f2cfc6b94e
38 changed files with 352 additions and 1888 deletions

View File

@@ -5,7 +5,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
@@ -156,7 +155,6 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
using Base = BlockGemmAQuantBase<Problem_>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -447,26 +445,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
}
// C += A * B

View File

@@ -5,7 +5,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
@@ -155,7 +154,6 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using Base = BlockGemmBQuantBase<Problem_>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -273,26 +271,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
}
// C += A * B

View File

@@ -7,7 +7,6 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"