[CK_TILE] Grouped gemm quant tensor layouts (#3414)

* feat: add RRR, CRR, CCR layouts for a/b quant grouped gemm tests and examples. Refactor example setup to improve compile time

* chore: split out bquant preshuffle test, and reduce tile size to 128 to temporarily solve slow compile times

* chore: set m/n warp tile to 16 as configurations with 32 seem to have some support problems

* fix: missing check for transposed load in bquant pipeline

* chore: lower unit test tensors dimensions a bit for faster tests

* chore: set grouped gemm example M/N warp tile to 16

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Erwin Terpstra
2025-12-25 08:01:23 +01:00
committed by GitHub
parent 14668a56e3
commit e08efa551f
20 changed files with 662 additions and 490 deletions

View File

@@ -6,18 +6,21 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
# if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# # Split into three separate test executables for faster parallel compilation
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# Split into three separate test executables for faster parallel compilation
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# endif()
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant_preshuffleb test_grouped_gemm_quant_bquant_preshuffleb.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant_preshuffleb PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -21,13 +21,29 @@ using AQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQ
// clang-format off
using KernelTypes_AQuant = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
// RCR FP8 (with/without TransposeC)
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
// RCR BF8 (with/without TransposeC)
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>,
// RCR non-persistent (with/without TransposeC)
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False>
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False>,
// RRR layout (with/without TransposeC)
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
// CRR layout (with/without TransposeC)
// NOT SUPPORTED: std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
// CCR layout (with/without TransposeC)
// NOT SUPPORTED: std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>
>;
// clang-format on

View File

@@ -21,13 +21,18 @@ using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQ
// clang-format off
using KernelTypes_BQuant = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
// Base instances: RCR FP8/BF16 persistent
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>,
// Non-persistent variant
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False>
// Alternative layouts: RRR, CRR, CCR
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>
>;
// clang-format on

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
// clang-format off
using KernelTypes_BQuant_PreshuffleB = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
// Base instances: RCR FP8/BF16 persistent
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>,
// Non-persistent variant
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_BQuant_PreshuffleB, KernelTypes_BQuant_PreshuffleB);
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_BQuant_PreshuffleB
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME

View File

@@ -5,7 +5,7 @@
TYPED_TEST(TEST_CLASS_NAME, Basic)
{
const int group_count = 8;
const int group_count = 6;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;

View File

@@ -31,8 +31,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
using DsDataType = ck_tile::tuple<>;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using AQLayout = Row;
using BQLayout = Col;
using AQLayout = ALayout;
using BQLayout = BLayout;
static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value;
static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value;
static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value;
@@ -44,8 +44,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
static const bool kPadK = false;
static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 256;
static const ck_tile::index_t N_Tile = 256;
static const ck_tile::index_t M_Tile = 128;
static const ck_tile::index_t N_Tile = 128;
static const ck_tile::index_t K_Tile = 128;
static const ck_tile::index_t M_Warp = 2;
@@ -782,3 +782,6 @@ using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant<Tuple>;
template <typename Tuple>
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;
template <typename Tuple>
using TestCkTileGroupedGemmQuant_BQuant_PreshuffleB = TestCkTileGroupedGemmQuant<Tuple>;