mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
feat: add split_k support for block scale gemm bquant mode. (#3653)
* WIP: add splitk to bquant * feat: add support for bf8i4 and fp8i4 by calculating correct stride for packed data types * chore: remove temporary test script * fix: incorrect tile window length for splitted bq tensor window * chore: improve comments * test: add unit tests to cover bquant splitk functionality * fix: conflict resolution by renaming variables
This commit is contained in:
@@ -128,6 +128,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant split-K tests (no preshuffle)
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode
|
||||
test_gemm_quant_bquant_splitk_decode.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill
|
||||
test_gemm_quant_bquant_splitk_prefill.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (with PreshuffleB) - split into 5 files
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d
|
||||
test_gemm_quant_bquant_preshuffle_decode_1d.cpp
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant split-K tests - Decode shape, GroupSize 128
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuantSplitKDecodeTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant split-K Decode
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKDecodeTypes);
|
||||
|
||||
// BQuant split-K tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test)
|
||||
{
|
||||
// K=1024 for split_k=2: 1024/2=512=4×128 ✓
|
||||
this->run_test_with_validation(32, 128, 1024, 2);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test)
|
||||
{
|
||||
// K=3072 for split_k=3: 3072/3=1024=8×128 ✓
|
||||
this->run_test_with_validation(32, 128, 3072, 3);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test)
|
||||
{
|
||||
// K=2048 for split_k=4: 2048/4=512=4×128 ✓
|
||||
this->run_test_with_validation(32, 128, 2048, 4);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test)
|
||||
{
|
||||
// K=2560 for split_k=5: 2560/5=512=4×128 ✓
|
||||
// Also K must be divisible by K_Tile(256)*split_k(5)=1280
|
||||
this->run_test_with_validation(32, 128, 2560, 5);
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant split-K tests - Prefill shape, GroupSize 128
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuantSplitKPrefillTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant split-K Prefill
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKPrefillTypes);
|
||||
|
||||
// BQuant split-K tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test)
|
||||
{
|
||||
// K=1024 for split_k=2: 1024/2=512=4×128 ✓
|
||||
// K must be divisible by K_Tile(128)*split_k(2)=256
|
||||
this->run_test_with_validation(128, 128, 1024, 2);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test)
|
||||
{
|
||||
// K=3072 for split_k=3: 3072/3=1024=8×128 ✓
|
||||
// K must be divisible by K_Tile(128)*split_k(3)=384
|
||||
this->run_test_with_validation(128, 128, 3072, 3);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test)
|
||||
{
|
||||
// K=2048 for split_k=4: 2048/4=512=4×128 ✓
|
||||
// K must be divisible by K_Tile(128)*split_k(4)=512
|
||||
this->run_test_with_validation(128, 128, 2048, 4);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test)
|
||||
{
|
||||
// K=1920 for split_k=5: 1920/5=384=3×128 ✓
|
||||
// K must be divisible by K_Tile(128)*split_k(5)=640
|
||||
this->run_test_with_validation(128, 128, 1920, 5);
|
||||
}
|
||||
@@ -655,7 +655,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
void run_test_with_validation(ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t k_batch = 1)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B =
|
||||
@@ -698,6 +701,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Zero C buffer - required for split-K atomic_add accumulation
|
||||
c_m_n_dev_buf.SetZero();
|
||||
|
||||
// Copy to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
@@ -746,12 +752,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
nullptr, // aq_ptr (not used for BQuant)
|
||||
bq_bqk_bqn_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
k_batch, // k_batch (split-K)
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
0, // QK_A (not used for BQuant)
|
||||
BQK, // QK_B - TODO: we can remove BQK and BQN from args later?
|
||||
BQK, // QK_B
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
@@ -796,7 +802,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
K, k_batch, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
@@ -806,7 +812,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
<< ", K=" << K << ", k_batch=" << k_batch;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user