mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Remove "basic" and universal GEMM tests, and incorporate their test cases into the GEMM pipeline tests (#3094)
* Add missing copyright statements
* Use ck_tile::host_tensor_descriptor instead of a custom lambda
* Refactor use of check_data_type in test classes
* Use TEST_SUITE_NAME with TYPED_TEST_SUITE
* Remove an unused namespace
* Make dim3 const
* Add BF8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
* Add F8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
* Add BF16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
* Add BF16 x BF16 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
* Add BF8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
* Add F8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
* Add F16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
* Skip failing tests of F16 x I4 for CompV3 with K == 2 * K_Tile
* Add missing precision type combinations to CompV4 from CompV3
* Move the INT8 tests around for consistency with KernelTypesCompV3Wmma
* Add missing precision type combinations to CompV3Wmma from CompV3
* Remove the basic and universal tests and their dependencies
* On __gfx950__, avoid using transposed loading of A with datatype pk_int4_t of B
* Use ADataType and BDataType instead of ComputeDataType for WarpGemm
* Explicitly set some return types to void
* Use more general typenames in InterleavedPKTypeLoader
* Add load_interleaved_pk_type.hpp to common.hpp
* Use std::is_same_v in load_int4_tile
* Add handling of LoadTranspose to load_int4_tile
* Factor out common code in several places using load_int4_tile
* Add support for pk_int4_t using load_int4_tile
* Fix formatting
This commit is contained in:
@@ -11,6 +11,14 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
@@ -93,7 +101,7 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompAsync, Problem>
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
public:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
@@ -118,6 +126,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
static constexpr bool Persistent =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
|
||||
|
||||
protected:
|
||||
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
@@ -228,7 +237,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -266,51 +275,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type()
|
||||
{
|
||||
return static_cast<Derived*>(this)
|
||||
->template check_data_type_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>();
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type_impl()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if(!check_data_type<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>())
|
||||
if constexpr(!Derived::check_data_type())
|
||||
{
|
||||
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
|
||||
}
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV4)
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV4 ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Only do k_batch = 1 when pipeline is CompV4
|
||||
// Only do k_batch = 1 when pipeline is CompV4, or BDataType is I4
|
||||
k_batches_ = {1};
|
||||
}
|
||||
else
|
||||
@@ -328,9 +305,13 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
const int StrideB = 0,
|
||||
const int StrideC = 0)
|
||||
{
|
||||
for(auto kb : k_batches_)
|
||||
// Some unsupported tests don't compile, so we check here before attempting to.
|
||||
if constexpr(Derived::check_data_type())
|
||||
{
|
||||
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
for(auto kb : k_batches_)
|
||||
{
|
||||
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -343,49 +324,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
const int StrideC,
|
||||
int kbatch = 1)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
ck_tile::index_t stride_A =
|
||||
ck_tile::get_default_stride(M, K, StrideA, is_row_major(ALayout{}));
|
||||
ck_tile::index_t stride_B =
|
||||
ck_tile::get_default_stride(K, N, StrideB, is_row_major(BLayout{}));
|
||||
ck_tile::index_t stride_C =
|
||||
ck_tile::get_default_stride(M, N, StrideC, is_row_major(CLayout{}));
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
|
||||
@@ -394,8 +345,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
@@ -416,7 +378,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
bool pass = true;
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
|
||||
Reference in New Issue
Block a user