mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Remove "basic" and universal GEMM tests, and incorporate their test cases into the GEMM pipeline tests (#3094)
* Add missing copyright statements * Use ck_tile::host_tensor_descriptor instead of a custom lambda * Refactor use of check_data_type in test classes * Use TEST_SUITE_NAME with TYPED_TEST_SUITE * Remove an unused namespace * Make dim3 const * Add BF8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add F8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add BF16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add BF16 x BF16 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add BF8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add F8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add F16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Skip failing tests of F16 x I4 for CompV3 with K == 2 * K_Tile * Add missing precision type combinations to CompV4 from CompV3 * Move the INT8 tests around for consistency with KernelTypesCompV3Wmma * Add missing precision type combinations to CompV3Wmma from CompV3 * Remove the basic and universal tests and their dependencies * On __gfx950__, avoid using transposed loading of A with datatype pk_int4_t of B * Use ADataType and BDataType instead of ComputeDataType for WarpGemm * Explicitly set some return types to void * Use more general typenames in InterleavedPKTypeLoader * Add load_interleaved_pk_type.hpp to common.hpp * Use std::is_same_v in load_int4_tile * Add handling of LoadTranspose to load_int4_tile * Factor out common code in several places using load_int4_tile * Add support for pk_int4_t using load_int4_tile * Fix formatting
This commit is contained in:
@@ -5,7 +5,6 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
@@ -156,7 +155,6 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
|
||||
using Base = BlockGemmAQuantBase<Problem_>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -447,26 +445,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
@@ -155,7 +154,6 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
using Base = BlockGemmBQuantBase<Problem_>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -273,26 +271,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
Reference in New Issue
Block a user