mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Remove "basic" and universal GEMM tests, and incorporate their test cases into the GEMM pipeline tests (#3094)
* Add missing copyright statements * Use ck_tile::host_tensor_descriptor instead of a custom lambda * Refactor use of check_data_type in test classes * Use TEST_SUITE_NAME with TYPED_TEST_SUITE * Remove an unused namespace * Make dim3 const * Add BF8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add F8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add BF16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add BF16 x BF16 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add BF8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add F8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Add F16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp * Skip failing tests of F16 x I4 for CompV3 with K == 2 * K_Tile * Add missing precision type combinations to CompV4 from CompV3 * Move the INT8 tests around for consistency with KernelTypesCompV3Wmma * Add missing precision type combinations to CompV3Wmma from CompV3 * Remove the basic and universal tests and their dependencies * On __gfx950__, avoid using transposed loading of A with datatype pk_int4_t of B * Use ADataType and BDataType instead of ComputeDataType for WarpGemm * Explicitly set some return types to void * Use more general typenames in InterleavedPKTypeLoader * Add load_interleaved_pk_type.hpp to common.hpp * Use std::is_same_v in load_int4_tile * Add handling of LoadTranspose to load_int4_tile * Factor out common code in several places using load_int4_tile * Add support for pk_int4_t using load_int4_tile * Fix formatting
This commit is contained in:
@@ -8,16 +8,7 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <class T>
|
||||
struct is_pk_int4 : std::false_type
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct is_pk_int4<pk_int4_t> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ComputeDataType, index_t UnaryOpSize>
|
||||
template <typename DstDataType, index_t UnaryOpSize>
|
||||
struct InterleavedPKTypeLoader
|
||||
{
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
@@ -30,24 +21,30 @@ struct InterleavedPKTypeLoader
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BDataType,
|
||||
typename ComputeDataType,
|
||||
template <typename SrcDataType,
|
||||
typename DstDataType,
|
||||
index_t UnaryOpSize,
|
||||
bool LoadTranspose = false,
|
||||
typename WarpTile,
|
||||
typename WarpWindow>
|
||||
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
|
||||
{
|
||||
if constexpr(is_pk_int4<std::remove_cv_t<BDataType>>::value)
|
||||
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
|
||||
{
|
||||
InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
|
||||
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t");
|
||||
InterleavedPKTypeLoader<DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
|
||||
}
|
||||
else if constexpr(LoadTranspose)
|
||||
{
|
||||
dst = load_tile_transpose(src);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -94,7 +94,11 @@ struct BlockUniversalGemmAsBsCr
|
||||
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -196,8 +200,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
@@ -222,22 +226,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
|
||||
a_block_window);
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
|
||||
b_block_window);
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -285,8 +277,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
@@ -300,30 +292,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else if constexpr(ALoadTranspose)
|
||||
{
|
||||
a_warp_tile_ = load_tile_transpose(a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else if constexpr(BLoadTranspose)
|
||||
{
|
||||
b_warp_tile_ = load_tile_transpose(b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
|
||||
a_block_window);
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
|
||||
b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
@@ -396,8 +368,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
@@ -451,30 +423,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else if constexpr(ALoadTranspose)
|
||||
{
|
||||
a_warp_tile_ = load_tile_transpose(a_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_lds_gemm_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else if constexpr(BLoadTranspose)
|
||||
{
|
||||
b_warp_tile_ = load_tile_transpose(b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_lds_gemm_window);
|
||||
}
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
|
||||
a_lds_gemm_window);
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
|
||||
b_lds_gemm_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -26,8 +26,21 @@ struct GemmPipelineAgBgCrImplBase
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool is_a_load_tr = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
static constexpr bool is_b_load_tr = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
// The combination of pk_int4_t and transposed loading causes numerical errors.
|
||||
// Therefore do not use transposed loading in this case.
|
||||
static constexpr bool is_a_load_tr = []() {
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
}();
|
||||
|
||||
static constexpr bool is_b_load_tr = []() {
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
}();
|
||||
#else
|
||||
static constexpr bool is_a_load_tr = false;
|
||||
static constexpr bool is_b_load_tr = false;
|
||||
|
||||
@@ -33,12 +33,27 @@ template <typename Derived>
|
||||
struct UniversalGemmBasePolicy
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// The combination of pk_int4_t and transposed loading causes numerical errors.
|
||||
// Therefore do not use transposed loading in this case.
|
||||
template <typename Problem>
|
||||
static constexpr bool is_a_load_tr =
|
||||
std::is_same_v<remove_cvref_t<typename Problem::ALayout>, tensor_layout::gemm::ColumnMajor>;
|
||||
static constexpr bool is_a_load_tr = []() {
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
|
||||
tensor_layout::gemm::ColumnMajor>;
|
||||
}();
|
||||
|
||||
template <typename Problem>
|
||||
static constexpr bool is_b_load_tr =
|
||||
std::is_same_v<remove_cvref_t<typename Problem::BLayout>, tensor_layout::gemm::RowMajor>;
|
||||
static constexpr bool is_b_load_tr = []() {
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
|
||||
tensor_layout::gemm::RowMajor>;
|
||||
}();
|
||||
#else
|
||||
template <typename Problem>
|
||||
static constexpr bool is_a_load_tr = false;
|
||||
@@ -707,8 +722,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
@@ -718,8 +740,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
@@ -156,7 +155,6 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
|
||||
using Base = BlockGemmAQuantBase<Problem_>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -447,26 +445,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
@@ -155,7 +154,6 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
using Base = BlockGemmBQuantBase<Problem_>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -273,26 +271,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
Reference in New Issue
Block a user