Remove "basic" and universal GEMM tests, and incorporate their test cases into the GEMM pipeline tests (#3094)

* Add missing copyright statements

* Use ck_tile::host_tensor_descriptor instead of a custom lambda

* Refactor use of check_data_type in test classes

* Use TEST_SUITE_NAME with TYPED_TEST_SUITE

* Remove an unused namespace

* Make dim3 const

* Add BF8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF16 x BF16 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add BF8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Add F16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp

* Skip failing tests of F16 x I4 for CompV3 with K == 2 * K_Tile

* Add missing precision type combinations to CompV4 from CompV3

* Move the INT8 tests around for consistency with KernelTypesCompV3Wmma

* Add missing precision type combinations to CompV3Wmma from CompV3

* Remove the basic and universal tests and their dependencies

* On __gfx950__, avoid using transposed loading of A with datatype pk_int4_t of B

* Use ADataType and BDataType instead of ComputeDataType for WarpGemm

* Explicitly set some return types to void

* Use more general typenames in InterleavedPKTypeLoader

* Add load_interleaved_pk_type.hpp to common.hpp

* Use std::is_same_v in load_int4_tile

* Add handling of LoadTranspose to load_int4_tile

* Factor out common code in several places using load_int4_tile

* Add support for pk_int4_t using load_int4_tile

* Fix formatting
This commit is contained in:
SamiAario-AMD
2025-11-13 21:01:27 +02:00
committed by GitHub
parent 7d57bc169f
commit f2cfc6b94e
38 changed files with 352 additions and 1888 deletions

View File

@@ -8,16 +8,7 @@
namespace ck_tile {
template <class T>
struct is_pk_int4 : std::false_type
{
};
template <>
struct is_pk_int4<pk_int4_t> : std::true_type
{
};
template <typename ComputeDataType, index_t UnaryOpSize>
template <typename DstDataType, index_t UnaryOpSize>
struct InterleavedPKTypeLoader
{
template <typename WarpWindow, typename WarpTile>
@@ -30,24 +21,30 @@ struct InterleavedPKTypeLoader
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
};
template <typename BDataType,
typename ComputeDataType,
template <typename SrcDataType,
typename DstDataType,
index_t UnaryOpSize,
bool LoadTranspose = false,
typename WarpTile,
typename WarpWindow>
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(is_pk_int4<std::remove_cv_t<BDataType>>::value)
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
{
InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t");
InterleavedPKTypeLoader<DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
}
else if constexpr(LoadTranspose)
{
dst = load_tile_transpose(src);
}
else
{