[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)

* WMMA GEMM F16 Implementation

Signed-off-by: root <tianyuwu@amd.com>

* Self-review

Signed-off-by: root <tianyuwu@amd.com>

* ASIC check minor tweak

Signed-off-by: root <tianyuwu@amd.com>

* add missing include file

* Set GPU_TARGETS to gfx11/12 generic

Signed-off-by: root <tianyuwu@amd.com>

* INT8 GFX12

Signed-off-by: root <tianyuwu@amd.com>

* add int8x16 branch

* Fix CI script

Signed-off-by: root <tianyuwu@amd.com>

* Fix typo

Signed-off-by: root <tianyuwu@amd.com>

* Add CK_Tile WMMA example

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Fix CI

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* fix clang format

* Set M/N_Warp Back to Constant

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Use GemmConfigComputeV3 by default

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Remove CK_Tile wmma gemm examples from the CI list

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add atomic add fallback method for gfx11

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Omit copyright year

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Support non-square cases

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add get_device_ip()

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Add atomic add fallback method for gfx11"

This reverts commit 4f664969c01b37976c8518c19833d9f1574cd746.

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12"

This reverts commit 949129a3858a825b2a2c4d3ec01663df18a165a5.

* Revise method name and typos

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Try fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Try fix CI"

This reverts commit 084c683227e64ab6a8137db00c8165fb05bdc902.

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo caused by merge

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Fix typo caused by merging

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

---------

Signed-off-by: root <tianyuwu@amd.com>
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>

[ROCm/composable_kernel commit: 68134b60e4]
This commit is contained in:
Tianyuan Wu
2025-08-16 07:22:27 +08:00
committed by GitHub
parent 42d775e488
commit ec7ee5b7b7
54 changed files with 1388 additions and 403 deletions

View File

@@ -30,6 +30,14 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
# On Radeon devices, build the WMMA version instead
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
endif()
@@ -46,4 +54,7 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MAT
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -3,7 +3,8 @@
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline<T>
class TestCkTileGemmPipelineCompV3
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV3<T>>
{
};

View File

@@ -0,0 +1,17 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_wmma_base.hpp"
#include "gtest/gtest.h"
// Typed-test fixture for the CompV3 WMMA GEMM pipeline tests.
// The class adds no members of its own; it hands itself to
// TestCkTileGemmPipelineWmmaBase as the second (Derived) template argument.
template <typename T>
class TestCkTileGemmPipelineCompV3Wmma
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV3Wmma<T>>
{
};
// TEST_SUITE_NAME is defined only for the duration of the .inc include below.
// NOTE(review): the .inc presumably declares the shared test cases against
// TEST_SUITE_NAME - confirm against test_gemm_pipeline_ut_cases.inc.
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3Wmma
// Bind the fixture to the KernelTypesCompV3Wmma type list.
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3Wmma, KernelTypesCompV3Wmma);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -3,7 +3,8 @@
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline<T>
class TestCkTileGemmPipelineCompV4
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV4<T>>
{
};

View File

@@ -0,0 +1,17 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_wmma_base.hpp"
#include "gtest/gtest.h"
// Typed-test fixture for the CompV4 WMMA GEMM pipeline tests.
// The class adds no members of its own; it hands itself to
// TestCkTileGemmPipelineWmmaBase as the second (Derived) template argument.
template <typename T>
class TestCkTileGemmPipelineCompV4Wmma
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV4Wmma<T>>
{
};
// TEST_SUITE_NAME is defined only for the duration of the .inc include below.
// NOTE(review): the .inc presumably declares the shared test cases against
// TEST_SUITE_NAME - confirm against test_gemm_pipeline_ut_cases.inc.
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4Wmma
// Bind the fixture to the KernelTypesCompV4Wmma type list.
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4Wmma, KernelTypesCompV4Wmma);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -9,13 +9,16 @@
#include "ck_tile/host.hpp"
#include "test_gemm_pipeline_util.hpp"
using I8 = ck_tile::int8_t;
using I32 = ck_tile::int32_t;
using INT8 = ck_tile::int8_t;
using INT32 = ck_tile::int32_t;
using F16 = ck_tile::half_t;
using F32 = float;
using F8 = ck_tile::fp8_t;
using BF16 = ck_tile::bf16_t;
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
@@ -30,52 +33,119 @@ using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Co
using Persistent = std::true_type;
using NonPersistent = std::false_type;
using I16 = ck_tile::number<16>;
using I32 = ck_tile::number<32>;
using I64 = ck_tile::number<64>;
using I256 = ck_tile::number<256>;
// clang-format off
using KernelTypesMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>
>;
// Type list for the Mem-pipeline WMMA suite.
// Columns: ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType,
//          M_Tile, N_Tile, K_Tile, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, Scheduler, PipelineType
// Every entry uses a 64x64x32 block tile and a 16x16x16 warp tile, and each
// layout / data-type combination is listed once per scheduler (Intrawave and
// Interwave).
// The Row/Row/Row F16 and BF16 pairs previously listed Interwave twice, which
// instantiated the same test type two times and left their Intrawave variants
// untested; the first row of each pair is now Intrawave, like every other pair.
using KernelTypesMemWmma = ::testing::Types<
    std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
    std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
    std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>
>;
using KernelTypesCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>,
std::tuple< Row, Col, Row, I8, I8, I32, I32, Intrawave, CompV3>,
std::tuple< Col, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>,
std::tuple< Col, Col, Row, I8, I8, I32, I32, Intrawave, CompV3>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
>;
// Type list for the CompV3 WMMA suite.
// Columns: ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType,
//          M_Tile, N_Tile, K_Tile, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, Scheduler, PipelineType
// For each of the four A/B layout combinations (C is always Row) it lists
// F16, BF16, INT8, F8 and BF8 inputs. Every entry uses a 64x64x32 block tile,
// a 16x16x16 warp tile, the Intrawave scheduler and the CompV3 pipeline.
// NOTE(review): the BF16 rows use F16 as CDataType, whereas the BF16 rows of
// KernelTypesMemWmma use BF16 - confirm this difference is intended.
using KernelTypesCompV3Wmma = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
>;
using KernelTypesCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
>;
// Type list for the CompV4 WMMA suite.
// Columns: ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType,
//          M_Tile, N_Tile, K_Tile, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, Scheduler, PipelineType
// F16 inputs only, one entry per A/B layout combination (C is always Row),
// each with a 64x64x32 block tile, a 16x16x16 warp tile and the Intrawave scheduler.
using KernelTypesCompV4Wmma = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>
>;
using KernelTypesPersistent = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, PipelineType, Persistent
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, NonPersistent>
>;
// Type list for the persistent-kernel WMMA suite.
// Columns: ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType,
//          M_Tile, N_Tile, K_Tile, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, Scheduler, PipelineType, Persistent
// The two entries are identical (Row/Col/Row F16, 64x64x32 block tile,
// 16x16x16 warp tile, Intrawave, CompV3) except for the trailing element,
// which is Persistent (std::true_type) or NonPersistent (std::false_type).
using KernelTypesPersistentWmma = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, NonPersistent>
>;
// clang-format on

View File

@@ -3,7 +3,7 @@
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T>
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineMem<T>>
{
};

View File

@@ -0,0 +1,17 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_wmma_base.hpp"
#include "gtest/gtest.h"
// Typed-test fixture for the Mem-pipeline WMMA GEMM tests.
// The class adds no members of its own; it hands itself to
// TestCkTileGemmPipelineWmmaBase as the second (Derived) template argument.
template <typename T>
class TestCkTileGemmPipelineMemWmma
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineMemWmma<T>>
{
};
// TEST_SUITE_NAME is defined only for the duration of the .inc include below.
// NOTE(review): the .inc presumably declares the shared test cases against
// TEST_SUITE_NAME - confirm against test_gemm_pipeline_ut_cases.inc.
#define TEST_SUITE_NAME TestCkTileGemmPipelineMemWmma
// Bind the fixture to the KernelTypesMemWmma type list.
TYPED_TEST_SUITE(TestCkTileGemmPipelineMemWmma, KernelTypesMemWmma);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -3,7 +3,8 @@
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline<T>
class TestCkTileGemmPipelinePersistent
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelinePersistent<T>>
{
};

View File

@@ -0,0 +1,17 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_wmma_base.hpp"
#include "gtest/gtest.h"
// Typed-test fixture for the persistent-kernel WMMA GEMM tests.
// The class adds no members of its own; it hands itself to
// TestCkTileGemmPipelineWmmaBase as the second (Derived) template argument.
template <typename T>
class TestCkTileGemmPipelinePersistentWmma
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelinePersistentWmma<T>>
{
};
// TEST_SUITE_NAME is defined only for the duration of the .inc include below.
// NOTE(review): the .inc presumably declares the shared test cases against
// TEST_SUITE_NAME - confirm against test_gemm_pipeline_ut_cases.inc.
#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistentWmma
// Bind the fixture to the KernelTypesPersistentWmma type list. The suite name is
// passed through the TEST_SUITE_NAME macro defined on the previous line, i.e.
// TestCkTileGemmPipelinePersistentWmma.
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistentWmma);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -69,7 +69,7 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
};
template <typename Tuple>
template <typename Tuple, typename Derived>
class TestCkTileGemmPipeline : public ::testing::Test
{
protected:
@@ -80,32 +80,30 @@ class TestCkTileGemmPipeline : public ::testing::Test
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
static constexpr auto Scheduler = std::tuple_element_t<13, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>{};
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>{};
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>{};
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>{};
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>{};
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>{};
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
static constexpr bool Persistent =
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
// TODO: expose tile size through test t-param ?
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
@@ -247,11 +245,48 @@ class TestCkTileGemmPipeline : public ::testing::Test
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
ck_tile::index_t M_Warp_Tile,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile>
bool check_data_type()
{
return static_cast<Derived*>(this)
->template check_data_type_impl<ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>();
}
// Default support check: reports every data-type / warp-tile combination as
// supported. A fixture that declares its own check_data_type_impl hides this one.
template <typename ADataType,
typename BDataType,
typename AccDataType,
ck_tile::index_t M_Warp_Tile,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile>
bool check_data_type_impl()
{
return true;
}
public:
std::vector<int> k_batches_;
void SetUp() override
{
if(!check_data_type<ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>())
{
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
}
if constexpr(PipelineType == GemmPipelineType::CompV4)
{
// Only do k_batch = 1 when pipeline is CompV4

View File

@@ -0,0 +1,24 @@
#pragma once
#include "test_gemm_pipeline_util.hpp"
template <typename Tuple, typename Derived>
class TestCkTileGemmPipelineWmmaBase : public TestCkTileGemmPipeline<Tuple, Derived>
{
public:
template <typename ADataType,
typename BDataType,
typename AccDataType,
ck_tile::index_t M_Warp_Tile,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile>
bool check_data_type_impl()
{
return ck_tile::check_wmma_supported<ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>();
}
};

View File