mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK_TILE] Grouped gemm quant tensor layouts (#3414)
* feat: add RRR, CRR, CCR layouts for a/b quant grouped gemm tests and examples. Refactor example setup to improve compile time * chore: split out bquant preshuffle test, and reduce tile size to 128 to temporarily solve slow compile times * chore: set m/n warp tile to 16 as configurations with 32 seem to have some support problems * fix: missing check for transposed load in bquant pipeline * chore: lower unit test tensors dimensions a bit for faster tests * chore: set grouped gemm example M/N warp tile to 16 --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -6,18 +6,21 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
# if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# # Split into three separate test executables for faster parallel compilation
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# Split into three separate test executables for faster parallel compilation
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# endif()
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant_preshuffleb test_grouped_gemm_quant_bquant_preshuffleb.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant_preshuffleb PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -21,13 +21,29 @@ using AQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQ
|
||||
// clang-format off
|
||||
using KernelTypes_AQuant = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
|
||||
// RCR FP8 (with/without TransposeC)
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
|
||||
|
||||
// RCR BF8 (with/without TransposeC)
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>,
|
||||
|
||||
// RCR non-persistent (with/without TransposeC)
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False>
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False>,
|
||||
|
||||
// RRR layout (with/without TransposeC)
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
|
||||
|
||||
// CRR layout (with/without TransposeC)
|
||||
// NOT SUPPORTED: std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
|
||||
|
||||
// CCR layout (with/without TransposeC)
|
||||
// NOT SUPPORTED: std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -21,13 +21,18 @@ using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQ
|
||||
// clang-format off
|
||||
using KernelTypes_BQuant = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
|
||||
|
||||
// Base instances: RCR FP8/BF16 persistent
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>,
|
||||
|
||||
// Non-persistent variant
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False>
|
||||
|
||||
// Alternative layouts: RRR, CRR, CCR
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util_quant.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_BQuant_PreshuffleB = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
|
||||
|
||||
// Base instances: RCR FP8/BF16 persistent
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>,
|
||||
|
||||
// Non-persistent variant
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_BQuant_PreshuffleB, KernelTypes_BQuant_PreshuffleB);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_BQuant_PreshuffleB
|
||||
#include "test_grouped_gemm_quant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
TYPED_TEST(TEST_CLASS_NAME, Basic)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int group_count = 6;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
|
||||
@@ -31,8 +31,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using AQLayout = Row;
|
||||
using BQLayout = Col;
|
||||
using AQLayout = ALayout;
|
||||
using BQLayout = BLayout;
|
||||
static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value;
|
||||
static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value;
|
||||
static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value;
|
||||
@@ -44,8 +44,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
static const bool kPadK = false;
|
||||
|
||||
static const int kBlockPerCu = 1;
|
||||
static const ck_tile::index_t M_Tile = 256;
|
||||
static const ck_tile::index_t N_Tile = 256;
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
static const ck_tile::index_t N_Tile = 128;
|
||||
static const ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static const ck_tile::index_t M_Warp = 2;
|
||||
@@ -782,3 +782,6 @@ using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_BQuant_PreshuffleB = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
Reference in New Issue
Block a user