feat: add split_k support for block scale gemm bquant mode. (#3653)

* WIP: add splitk to bquant

* feat: add support for bf8i4 and fp8i4 by calculating correct stride for packed data types

* chore: remove temporary test script

* fix: incorrect tile window length for splitted bq tensor window

* chore: improve comments

* test: add unit tests to cover bquant splitk functionality

* fix: conflict resolution by renaming variables
This commit is contained in:
Aviral Goel
2026-02-03 02:41:53 +04:00
committed by GitHub
parent 301eb5cf08
commit 3e77721755
11 changed files with 273 additions and 208 deletions

View File

@@ -128,6 +128,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# BQuant split-K tests (no preshuffle)
add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode
test_gemm_quant_bquant_splitk_decode.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill
test_gemm_quant_bquant_splitk_prefill.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# BQuant tests (with PreshuffleB) - split into 5 files
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d
test_gemm_quant_bquant_preshuffle_decode_1d.cpp

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for BQuant split-K tests - Decode shape, GroupSize 128
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuantSplitKDecodeTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>
>;
// clang-format on
// Test suite for BQuant split-K Decode
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKDecodeTypes);
// BQuant split-K tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test)
{
// K=1024 for split_k=2: 1024/2=512=4×128 ✓
this->run_test_with_validation(32, 128, 1024, 2);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test)
{
// K=3072 for split_k=3: 3072/3=1024=8×128 ✓
this->run_test_with_validation(32, 128, 3072, 3);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test)
{
// K=2048 for split_k=4: 2048/4=512=4×128 ✓
this->run_test_with_validation(32, 128, 2048, 4);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test)
{
// K=2560 for split_k=5: 2560/5=512=4×128 ✓
// Also K must be divisible by K_Tile(256)*split_k(5)=1280
this->run_test_with_validation(32, 128, 2560, 5);
}

View File

@@ -0,0 +1,64 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for BQuant split-K tests - Prefill shape, GroupSize 128
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuantSplitKPrefillTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>
>;
// clang-format on
// Test suite for BQuant split-K Prefill
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKPrefillTypes);
// BQuant split-K tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test)
{
// K=1024 for split_k=2: 1024/2=512=4×128 ✓
// K must be divisible by K_Tile(128)*split_k(2)=256
this->run_test_with_validation(128, 128, 1024, 2);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test)
{
// K=3072 for split_k=3: 3072/3=1024=8×128 ✓
// K must be divisible by K_Tile(128)*split_k(3)=384
this->run_test_with_validation(128, 128, 3072, 3);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test)
{
// K=2048 for split_k=4: 2048/4=512=4×128 ✓
// K must be divisible by K_Tile(128)*split_k(4)=512
this->run_test_with_validation(128, 128, 2048, 4);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test)
{
// K=1920 for split_k=5: 1920/5=384=3×128 ✓
// K must be divisible by K_Tile(128)*split_k(5)=640
this->run_test_with_validation(128, 128, 1920, 5);
}

View File

@@ -655,7 +655,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
void run_test_with_validation(ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t k_batch = 1)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B =
@@ -698,6 +701,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
sizeof(QDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
// Zero C buffer - required for split-K atomic_add accumulation
c_m_n_dev_buf.SetZero();
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
@@ -746,12 +752,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
nullptr, // aq_ptr (not used for BQuant)
bq_bqk_bqn_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
1, // k_batch
k_batch, // k_batch (split-K)
M,
N,
K, // M, N, K
0, // QK_A (not used for BQuant)
BQK, // QK_B - TODO: we can remove BQK and BQN from args later?
BQK, // QK_B
stride_A,
stride_B,
stride_C,
@@ -796,7 +802,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1, max_accumulated_value);
K, k_batch, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
@@ -806,7 +812,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
<< ", K=" << K << ", k_batch=" << k_batch;
if(!pass)
{