mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)
* WMMA GEMM F16 Implementation
Signed-off-by: root <tianyuwu@amd.com>
* Self-review
Signed-off-by: root <tianyuwu@amd.com>
* ASIC check minor tweak
Signed-off-by: root <tianyuwu@amd.com>
* add missing include file
* Set GPU_TARGETS to gfx11/12 generic
Signed-off-by: root <tianyuwu@amd.com>
* INT8 GFX12
Signed-off-by: root <tianyuwu@amd.com>
* add int8x16 branch
* Fix CI script
Signed-off-by: root <tianyuwu@amd.com>
* Fix typo
Signed-off-by: root <tianyuwu@amd.com>
* Add CK_Tile WMMA example
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
* Fix CI
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
* fix clang format
* Set M/N_Warp Back to Constant
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
* Use GemmConfigComputeV3 by default
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Remove CK_Tile wmma gemm examples from the CI list
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Add atomic add fallback method for gfx11
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Fix typo
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Omit copyright year
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Support non-square cases
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Fix CI
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Add get_device_ip()
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Revert "Add atomic add fallback method for gfx11"
This reverts commit 4f664969c01b37976c8518c19833d9f1574cd746.
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
* Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12"
This reverts commit 949129a3858a825b2a2c4d3ec01663df18a165a5.
* Revise method name and typos
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
* clang-format
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Try fix CI
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Revert "Try fix CI"
This reverts commit 084c683227e64ab6a8137db00c8165fb05bdc902.
* clang-format
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Fix typo caused by merge
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
* Fix typo caused by merging
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
---------
Signed-off-by: root <tianyuwu@amd.com>
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
[ROCm/composable_kernel commit: 68134b60e4]
This commit is contained in:
@@ -30,6 +30,14 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
# On Radeon devices, build the WMMA version instead
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
@@ -46,4 +54,7 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MAT
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline<T>
|
||||
class TestCkTileGemmPipelineCompV3
|
||||
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV3<T>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
17
test/ck_tile/gemm/test_gemm_pipeline_compv3_wmma.cpp
Normal file
17
test/ck_tile/gemm/test_gemm_pipeline_compv3_wmma.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_wmma_base.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV3Wmma
|
||||
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV3Wmma<T>>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3Wmma
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3Wmma, KernelTypesCompV3Wmma);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -3,7 +3,8 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline<T>
|
||||
class TestCkTileGemmPipelineCompV4
|
||||
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV4<T>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
17
test/ck_tile/gemm/test_gemm_pipeline_compv4_wmma.cpp
Normal file
17
test/ck_tile/gemm/test_gemm_pipeline_compv4_wmma.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_wmma_base.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV4Wmma
|
||||
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV4Wmma<T>>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4Wmma
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4Wmma, KernelTypesCompV4Wmma);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -9,13 +9,16 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
|
||||
using I8 = ck_tile::int8_t;
|
||||
using I32 = ck_tile::int32_t;
|
||||
using INT8 = ck_tile::int8_t;
|
||||
using INT32 = ck_tile::int32_t;
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
@@ -30,52 +33,119 @@ using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Co
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
|
||||
using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I64 = ck_tile::number<64>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem>
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, M_TileSize, K_TileSize, Scheduler, PipelineType
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>
|
||||
>;
|
||||
|
||||
using KernelTypesMemWmma = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, I8, I8, I32, I32, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, I8, I8, I32, I32, Intrawave, CompV3>
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3Wmma = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV4Wmma = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>
|
||||
>;
|
||||
|
||||
|
||||
using KernelTypesPersistent = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent>
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, M_TileSize, K_TileSize, Scheduler, PipelineType, Persistent
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, NonPersistent>
|
||||
>;
|
||||
|
||||
using KernelTypesPersistentWmma = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, NonPersistent>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T>
|
||||
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineMem<T>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
17
test/ck_tile/gemm/test_gemm_pipeline_mem_wmma.cpp
Normal file
17
test/ck_tile/gemm/test_gemm_pipeline_mem_wmma.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_wmma_base.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineMemWmma
|
||||
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineMemWmma<T>>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineMemWmma
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineMemWmma, KernelTypesMemWmma);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -3,7 +3,8 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline<T>
|
||||
class TestCkTileGemmPipelinePersistent
|
||||
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelinePersistent<T>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
17
test/ck_tile/gemm/test_gemm_pipeline_persistent_wmma.cpp
Normal file
17
test/ck_tile/gemm/test_gemm_pipeline_persistent_wmma.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_wmma_base.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelinePersistentWmma
|
||||
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelinePersistentWmma<T>>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistentWmma
|
||||
|
||||
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistentWmma);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -69,7 +69,7 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
@@ -80,32 +80,30 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<13, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>{};
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>{};
|
||||
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>{};
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>{};
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>{};
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>{};
|
||||
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
|
||||
static constexpr bool Persistent =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
|
||||
// TODO: expose tile size through test t-param ?
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
@@ -247,11 +245,48 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type()
|
||||
{
|
||||
return static_cast<Derived*>(this)
|
||||
->template check_data_type_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>();
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type_impl()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if(!check_data_type<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>())
|
||||
{
|
||||
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
|
||||
}
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV4)
|
||||
{
|
||||
// Only do k_batch = 1 when pipeline is CompV4
|
||||
|
||||
24
test/ck_tile/gemm/test_gemm_pipeline_wmma_base.hpp
Normal file
24
test/ck_tile/gemm/test_gemm_pipeline_wmma_base.hpp
Normal file
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmPipelineWmmaBase : public TestCkTileGemmPipeline<Tuple, Derived>
|
||||
{
|
||||
public:
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type_impl()
|
||||
{
|
||||
return ck_tile::check_wmma_supported<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>();
|
||||
}
|
||||
};
|
||||
0
test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc
Executable file → Normal file
0
test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc
Executable file → Normal file
Reference in New Issue
Block a user