[CK_TILE] add preshuffleB mode for ABQuant GEMM (#3495)

* [CK_TILE] add preshuffleB mode for ABQuant GEMM

* fix precommit error

* use template method call for cvt_scale_to_fp32

* fix precommit error

* add test code

* fix precommit error

* switch abquant gemmconfig to default

* Add changelog.md

* fix precommit error

* fix conflict
This commit is contained in:
kensclin
2026-01-07 04:35:01 +08:00
committed by GitHub
parent 960ef551bf
commit 2309c86054
10 changed files with 1161 additions and 27 deletions

View File

@@ -39,6 +39,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_abquant_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# ABQuant preshuffleB test target, built from test_gemm_quant_abquant_preshuffle_2d.cpp
# with the shared TEST_GEMM_COMPILE_OPTIONS.
add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle
test_gemm_quant_abquant_preshuffle_2d.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# AQuant tests
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
test_gemm_quant_aquant_prefill.cpp
)

View File

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
// Matrix layout tags used as tuple elements below.
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
// Element data types. BF8 and PkInt4 are not referenced by the type list in this file.
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
// QuantType::ABQuantGrouped wrapped in an integral_constant so the enum value can be
// carried as a type inside the std::tuple test parameters.
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// Quant group shape sequence<1, 1, 128>; used below as the A quant group size in both
// tuples and as the B quant group size in the first tuple.
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2D quant group shape sequence<1, 128, 128>; used below as the B quant group size in the
// second tuple.
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
// Type list that instantiates the typed test suite. Both tuples use the
// GemmConfigPreshuffleBPrefill config with FP8 A/B inputs, float quant scales and Half output;
// they differ only in the B quant group size.
using ABQuantPreshuffleBTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
// Case 1: GroupSize for both the A and B quant group sizes.
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize, ColumnMajor>,
// Case 2: GroupSize for A, GroupSize2D128N for the B quant group size.
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes);
// ABQuant preshuffleB tests
// Instantiated once per tuple in ABQuantPreshuffleBTypes; each instance calls the fixture's
// run_test_with_validation with the same three size arguments.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
// NOTE(review): presumably the arguments are the M, N, K problem sizes - confirm against
// run_test_with_validation in test_gemm_quant_fixtures.hpp.
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -894,10 +894,10 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline =
std::conditional_t<PreshuffleB == false,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
using BaseGemmPipeline = std::conditional_t<
PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -926,8 +926,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
tail_number_v>;
using GemmPipeline =
std::conditional_t<PreshuffleB == false,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<