mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4964 (commit 3271d9a)
[CK Tile] Eight Waves pipeline GEMM

## Motivation

The eight waves pipeline was added for ABQuant. The goal of this PR is to enable it for GEMM as well.

## Technical Details

Summary:
- Block:
  - Create a block struct for GEMM using eight warps specific distribution encodings
  - Use this block struct in ABQuant for encodings
- Pipeline:
  - Create an impl pipeline for eight waves which can be used by GEMM and ABQuant as a base (and by AQuant and BQuant in the future)
  - Create an eight waves pipeline for GEMM (this cannot be easily integrated into the existing async pipeline)
- Pipeline policy:
  - Extract the GEMM-specific parts of the ABQuant policy to define a GEMM policy (ABQuant then uses it as a base and adds Quant-specific methods)
- Minor: naming was inconsistent between warp/wave; everything is now referred to as eight waves

So overall we have:
- a block struct directly used by GEMM -> an ABQuant derived struct to implement the operator
- an impl base pipeline with the general implementation -> the GEMM and ABQuant pipelines use it to avoid code duplication but still define their own pipelines
- a pipeline policy struct directly used by GEMM -> an ABQuant derived policy struct for the Quant-specific parts

## Test Plan

Added new tests for the GEMM pipeline: `test_ck_tile_gemm_pipeline_comp_async_eight_waves` (only gfx950 supports it).

Note: the K padding test is disabled for this pipeline because it is not implemented yet.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8108662da
commit
eb033ef208
@@ -49,6 +49,13 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
|
||||
list(APPEND CK_TILE_GEMM_TEST_TARGETS
|
||||
test_ck_tile_gemm_pipeline_comp_async
|
||||
)
|
||||
|
||||
# Build the eight-waves async compute GEMM pipeline test from
# test_gemm_pipeline_comp_async_eight_waves.cpp, compile it with
# ${EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS}, and append the target to
# CK_TILE_GEMM_TEST_TARGETS.
# NOTE(review): the commit description states that only gfx950 supports this
# pipeline, while the enclosing GPU_TARGETS condition appears to match more
# architectures - confirm that unsupported targets are handled elsewhere.
add_gtest_executable(test_ck_tile_gemm_pipeline_comp_async_eight_waves test_gemm_pipeline_comp_async_eight_waves.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_comp_async_eight_waves PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS})
|
||||
|
||||
list(APPEND CK_TILE_GEMM_TEST_TARGETS
|
||||
test_ck_tile_gemm_pipeline_comp_async_eight_waves
|
||||
)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx11|gfx12")
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed test fixture for the CompAsyncEightWaves GEMM pipeline.
// It derives from TestCkTileGemmPipeline and passes itself as the second
// template argument (the base's `Derived` parameter).
template <typename T>
|
||||
class TestCkTileGemmPipelineCompAsyncEightWaves
|
||||
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompAsyncEightWaves<T>>
|
||||
{
|
||||
public:
|
||||
// Unconditionally reports the data type combination as valid.
// NOTE(review): presumably queried by the base fixture before it decides to
// skip a test with "Unsupported data type combination" - confirm.
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
// TEST_SUITE_NAME is defined only for the duration of the include below and
// is removed again afterwards so the macro does not leak past this point.
// NOTE(review): the included .inc file is expected to declare its test cases
// as TYPED_TEST(TEST_SUITE_NAME, ...) - confirm against that file.
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompAsyncEightWaves
|
||||
|
||||
// Register the fixture with the eight-waves kernel type list.
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesCompAsyncEightWaves);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -22,6 +22,8 @@ using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType:
|
||||
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
|
||||
using CompV6 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV6>;
|
||||
using CompAsync = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompAsync>;
|
||||
// Tag type carrying GemmPipelineType::CompAsyncEightWaves as a compile-time
// constant, for use as the pipeline entry of a test configuration tuple.
using CompAsyncEightWaves =
|
||||
ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompAsyncEightWaves>;
|
||||
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
@@ -30,6 +32,7 @@ using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I64 = ck_tile::number<64>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
using I192 = ck_tile::number<192>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
|
||||
// clang-format off
|
||||
@@ -242,6 +245,23 @@ using CompAsyncConfig16x16x128 = std::tuple<ALayout,
|
||||
Intrawave,
|
||||
CompAsync>;
|
||||
|
||||
// Test configuration tuple for the CompAsyncEightWaves pipeline.
// The three layouts and the input type are parameters; A and B share
// InputType. Fixed entries: F32 accumulation, F16 output, a 192x256x128
// (M x N x K) block tile, a 16x16x128 (M x N x K) warp tile, the Intrawave
// tag and the CompAsyncEightWaves pipeline tag.
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
|
||||
using CompAsyncEightWavesConfig = std::tuple<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InputType, // AType
|
||||
InputType, // BType
|
||||
F32, // AccType
|
||||
F16, // OutputType
|
||||
I192, // MBlockTileSize
|
||||
I256, // NBlockTileSize
|
||||
I128, // KBlockTileSize
|
||||
I16, // MWarpTileSize
|
||||
I16, // NWarpTileSize
|
||||
I128, // KWarpTileSize
|
||||
Intrawave,
|
||||
CompAsyncEightWaves>;
|
||||
|
||||
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
|
||||
CompAsyncConfig<Row, Col, Row, F16>,
|
||||
CompAsyncConfig<Col, Row, Row, F16>,
|
||||
@@ -254,8 +274,10 @@ using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16
|
||||
using KernelTypesCompAsync16x16x128 = ::testing::Types<CompAsyncConfig16x16x128<Row, Col, Row, F4>,
|
||||
CompAsyncConfig16x16x128<Row, Col, Row, F8>>;
|
||||
|
||||
// clang-format off
|
||||
// Type list for the eight-waves typed test suite. It currently holds a single
// configuration: ALayout = Row, BLayout = Col, CLayout = Row, InputType = F8.
using KernelTypesCompAsyncEightWaves =
|
||||
::testing::Types<CompAsyncEightWavesConfig<Row, Col, Row, F8>>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesCompV6 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
|
||||
|
||||
@@ -115,6 +115,9 @@ TYPED_TEST(TEST_SUITE_NAME, PaddK)
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 432;
|
||||
|
||||
if constexpr(TestFixture::PipelineType == GemmPipelineType::CompAsyncEightWaves)
|
||||
return;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
this->Run(M, N, K);
|
||||
|
||||
@@ -46,7 +46,8 @@ enum struct GemmPipelineType
|
||||
CompV3,
|
||||
CompV4,
|
||||
CompV6,
|
||||
CompAsync
|
||||
CompAsync,
|
||||
CompAsyncEightWaves
|
||||
};
|
||||
|
||||
template <GemmPipelineType PT, typename Problem>
|
||||
@@ -97,6 +98,15 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompAsync, Problem>
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; }
|
||||
};
|
||||
|
||||
// Selector specialization for GemmPipelineType::CompAsyncEightWaves.
// It exposes the two pipeline types for this pipeline kind and a name string.
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::CompAsyncEightWaves, Problem>
|
||||
{
|
||||
// The eight-waves pipeline uses the CompV3 base pipeline type here.
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves<Problem>;
|
||||
|
||||
// Returns the pipeline's name as a string literal.
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsyncEightWaves"; }
|
||||
};
|
||||
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
@@ -129,7 +139,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t M_Warp =
|
||||
PipelineType == GemmPipelineType::CompAsyncEightWaves ? 4 : 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
@@ -246,6 +257,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
|
||||
}
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV4 ||
|
||||
PipelineType == GemmPipelineType::CompAsyncEightWaves ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Only do k_batch = 1 when pipeline is CompV4 or CompAsyncEightWaves, or BDataType is I4
|
||||
|
||||
@@ -81,10 +81,10 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_eightwarps
|
||||
test_gemm_quant_abquant_eightwarps.cpp
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_eightwaves
|
||||
test_gemm_quant_abquant_eightwaves.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_eightwarps PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_tile_gemm_quant_abquant_eightwaves PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# ABQuant split-K tests
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode
|
||||
@@ -280,7 +280,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_tile_gemm_quant_abquant_a4w4_base
|
||||
test_tile_gemm_quant_abquant_a4w4_padding
|
||||
test_tile_gemm_quant_abquant_a4w4_preshuffle
|
||||
test_tile_gemm_quant_abquant_eightwarps
|
||||
test_tile_gemm_quant_abquant_eightwaves
|
||||
# ABQuant split-K tests
|
||||
test_tile_gemm_quant_abquant_splitk_decode
|
||||
test_tile_gemm_quant_abquant_splitk_prefill
|
||||
|
||||
@@ -27,15 +27,15 @@ using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantEightWarpsTypes = ::testing::Types<
|
||||
using ABQuantEightWavesTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = true (RCR layout with ColumnMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWarps, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWarps_PreshuffleB, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWaves, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWaves_PreshuffleB, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for ABQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWarpsTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWavesTypes);
|
||||
|
||||
// ABQuant tests
|
||||
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
|
||||
@@ -193,7 +193,7 @@ struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleBPrefi
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
struct GemmConfigEightWarps : public GemmConfigBase
|
||||
struct GemmConfigEightWaves : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
|
||||
@@ -210,7 +210,7 @@ struct GemmConfigEightWarps : public GemmConfigBase
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
struct GemmConfigEightWarps_PreshuffleB : public GemmConfigEightWarps
|
||||
struct GemmConfigEightWaves_PreshuffleB : public GemmConfigEightWaves
|
||||
{
|
||||
static constexpr bool PreshuffleB = true;
|
||||
};
|
||||
@@ -1221,7 +1221,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
(std::is_same_v<BDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::bf8_t>);
|
||||
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
|
||||
constexpr bool eight_warps =
|
||||
constexpr bool eight_waves =
|
||||
#ifdef CK_GFX950_SUPPORT
|
||||
IS_FP8BLOCKSCALE &&
|
||||
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
|
||||
@@ -1237,7 +1237,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
ComputeDataType>;
|
||||
|
||||
constexpr auto base_gemm_pipeline = []() {
|
||||
if constexpr(eight_warps)
|
||||
if constexpr(eight_waves)
|
||||
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
|
||||
else if constexpr(PreshuffleB)
|
||||
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
|
||||
@@ -1275,8 +1275,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
eight_warps,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrEightWarps<PipelineProblem>,
|
||||
eight_waves,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrEightWaves<PipelineProblem>,
|
||||
std::conditional_t<PreshuffleB,
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
@@ -1316,7 +1316,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for ABQuant kernel");
|
||||
}
|
||||
using k_attr_t = ck_tile::kernel_attr<eight_warps>;
|
||||
using k_attr_t = ck_tile::kernel_attr<eight_waves>;
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu, k_attr_t>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
Reference in New Issue
Block a user