mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#4816 (commit 17ff961)
[CK] Add split-K support for ABQuantGrouped in block_scale_gemm (#4816) ## Changes ### Split-K support in `gemm_quant_kernel.hpp` - **`SplitKBatchOffset`**: Added `aq_group_offset` and `aq_k_split_offset` fields (mirroring the existing `bq_*` fields for B) to track each split-K batch's position within the AQ scale tensor. For `ABQuantGrouped`, both offsets are computed from `k_id * KRead` divided by `AQuantGroupSize::kK`. - **`MakeAQBlockWindow`**: Added an `aq_group_offset` parameter (defaulting to 0 for non-split-K paths) so the AQ tensor view's K-group dimension reflects only the remaining K-groups from the split-K offset, consistent with how `MakeBQBlockWindow` handles the BQ tensor. - **`RunGemm`**: Threads the `aq_k_split_offset` through to `MakeAQBlockWindow` when in split-K mode. ### Constraints in `IsSupportedArgument()` Four constraints gate split-K (`k_batch > 1`) for ABQuantGrouped: 1. **Mode check** — split-K is only allowed for `BQuantGrouped` (no preshuffle) or `ABQuantGrouped` (no `APreshuffleQuant`). Any other quant mode with `k_batch > 1` returns `false`. 2. **B quant group alignment** — `KRead` (per-batch K slice) must be divisible by `BQuantGroupSize::kK`. Each batch must operate on complete B quantization groups; a partial group would require splitting a scale value across batches. 3. **A quant group alignment** (new, ABQuantGrouped only) — `KRead` must also be divisible by `AQuantGroupSize::kK` for the same reason applied to the AQ scale tensor. 4. **Minimum 2 K-tile iterations per batch** (new) — The software-pipelined GEMM kernels (CompV3 family) prefetch one tile ahead, so they require `per_batch_num_loop = KRead / KPerBlock >= 2`. When `KRead == KPerBlock` (i.e. each batch is exactly one tile), the prefetch reads into the next batch's memory region and produces incorrect results. Configurations where `K == k_batch * KPerBlock` are therefore rejected.
### Example update (`run_gemm_quant_example.inc`) Updated the comment above the `IsSupportedArgument` call to document that split-K is now supported for both `BQuantGrouped` (no preshuffle) and `ABQuantGrouped` (no `APreshuffleQuant`). ## Unit Tests Two new test files covering decode and prefill tile shapes across a range of `k_batch` values (2–8), data types (FP8, BF8), and quantization group sizes (1×1×128 and 1×128×128 for B): - `test_gemm_quant_abquant_splitk_decode.cpp` — uses the decode tile shape (M=16, N=64, K_tile=256) - `test_gemm_quant_abquant_splitk_prefill.cpp` — uses the prefill tile shape (M=128, N=128, K_tile=128) Each test calls `run_test_with_validation` which runs the kernel and checks correctness against a CPU reference. Configurations excluded from tests are annotated with comments explaining which constraint they violate (typically the `per_batch_num_loop >= 2` requirement). ## Prerequisites This PR depends on #4429, which must be merged before this can be merged.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6549c320fc
commit
c8a8449eec
@@ -81,6 +81,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# ABQuant split-K tests
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode
|
||||
test_gemm_quant_abquant_splitk_decode.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_prefill
|
||||
test_gemm_quant_abquant_splitk_prefill.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base
|
||||
test_gemm_quant_abquant_a4w4_base.cpp
|
||||
)
|
||||
@@ -268,7 +279,14 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_tile_gemm_quant_abquant_base
|
||||
test_tile_gemm_quant_abquant_padding
|
||||
test_tile_gemm_quant_abquant_preshuffle
|
||||
test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant
|
||||
test_tile_gemm_quant_abquant_preshuffleQuant
|
||||
test_tile_gemm_quant_abquant_a4w4_base
|
||||
test_tile_gemm_quant_abquant_a4w4_padding
|
||||
test_tile_gemm_quant_abquant_a4w4_preshuffle
|
||||
# ABQuant split-K tests
|
||||
test_tile_gemm_quant_abquant_splitk_decode
|
||||
test_tile_gemm_quant_abquant_splitk_prefill
|
||||
# BQuant tests
|
||||
test_tile_gemm_quant_bquant_1d_128
|
||||
test_tile_gemm_quant_bquant_1d_64
|
||||
@@ -276,6 +294,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_tile_gemm_quant_bquant_2d_medium_n
|
||||
test_tile_gemm_quant_bquant_2d_large_n
|
||||
test_tile_gemm_quant_bquant_transpose
|
||||
# BQuant split-K tests
|
||||
test_tile_gemm_quant_bquant_splitk_decode
|
||||
test_tile_gemm_quant_bquant_splitk_prefill
|
||||
# BQuant preshuffle tests
|
||||
test_tile_gemm_quant_bquant_preshuffle_decode_1d
|
||||
test_tile_gemm_quant_bquant_preshuffle_prefill_1d
|
||||
|
||||
@@ -28,7 +28,7 @@ using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
// 1D B-scales; PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
@@ -36,12 +36,13 @@ using ABQuantTypes = ::testing::Types<
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
// 2D B-scales; PreshuffleQuant = false && TransposeC = true (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -28,9 +28,11 @@ using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantPreshuffleBTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
// 1D B-scales; PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
|
||||
// 2D B-scales; PreshuffleQuant = false && TransposeC = true (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefillTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -28,8 +28,8 @@ using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantPreshuffleQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill<false>, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill<true>, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using ABQuantGrouped =
|
||||
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
|
||||
using GroupSize1x1x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize1x128x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for ABQuant split-K tests - Decode shape
|
||||
// GemmConfigDecode: M_Tile=16, N_Tile=64, K_Tile=256, kPadK=false
|
||||
// Constraints: M % 16 == 0, N % 64 == 0, K % (k_batch * 256) == 0, K >= 2 * k_batch * 256
|
||||
//
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantSplitKDecodeTypes = ::testing::Types<
|
||||
// GroupSize 1x1x128 (kK=128 for both A and B, kN=1)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigDecode, GroupSize1x1x128, GroupSize1x1x128, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigDecode, GroupSize1x1x128, GroupSize1x1x128, ColumnMajor>,
|
||||
// GroupSize 1x128x128 for B (kK=128, kN=128)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigDecode, GroupSize1x1x128, GroupSize1x128x128, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigDecode, GroupSize1x1x128, GroupSize1x128x128, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for ABQuant split-K Decode
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantSplitKDecodeTypes);
|
||||
|
||||
// ---- k_batch=2 ----------------------------------------------------------------
|
||||
// Note: K=512 (= 2*K_Tile) is excluded because KRead=K_Tile=256, giving
|
||||
// per_batch_num_loop=1 which the software-pipelined kernel cannot handle.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK2_MedK_BaseShape)
|
||||
{
|
||||
// K=1024=4*256: standard decode shape
|
||||
this->run_test_with_validation(32, 64, 1024, 2);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_WideN)
|
||||
{
|
||||
// K=2048, larger N (multiple of N_Tile=64)
|
||||
this->run_test_with_validation(32, 256, 2048, 2);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_TallM)
|
||||
{
|
||||
// K=4096, larger M (multiple of M_Tile=16)
|
||||
this->run_test_with_validation(64, 64, 4096, 2);
|
||||
}
|
||||
|
||||
// ---- k_batch=3 ----------------------------------------------------------------
|
||||
// Note: K=768 (= 3*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK3_MedK_BaseShape)
|
||||
{
|
||||
// K=1536=6*256
|
||||
this->run_test_with_validation(32, 64, 1536, 3);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK3_LargeK_BaseShape)
|
||||
{
|
||||
// K=3072=12*256
|
||||
this->run_test_with_validation(32, 64, 3072, 3);
|
||||
}
|
||||
|
||||
// ---- k_batch=4 ----------------------------------------------------------------
|
||||
// Note: K=1024 (= 4*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK4_MedK_BaseShape)
|
||||
{
|
||||
// K=2048=8*256
|
||||
this->run_test_with_validation(32, 64, 2048, 4);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK4_LargeK_WideN)
|
||||
{
|
||||
// K=4096, wider N
|
||||
this->run_test_with_validation(32, 128, 4096, 4);
|
||||
}
|
||||
|
||||
// ---- k_batch=5 ----------------------------------------------------------------
|
||||
// Note: K=1280 (= 5*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK5_MedK_BaseShape)
|
||||
{
|
||||
// K=2560=10*256
|
||||
this->run_test_with_validation(32, 64, 2560, 5);
|
||||
}
|
||||
|
||||
// ---- k_batch=6 ----------------------------------------------------------------
|
||||
// Note: K=1536 (= 6*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK6_LargeK_BaseShape)
|
||||
{
|
||||
// K=3072=12*256
|
||||
this->run_test_with_validation(32, 64, 3072, 6);
|
||||
}
|
||||
|
||||
// ---- k_batch=8 ----------------------------------------------------------------
|
||||
// Note: K=2048 (= 8*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK8_LargeK_BaseShape)
|
||||
{
|
||||
// K=4096=16*256
|
||||
this->run_test_with_validation(32, 64, 4096, 8);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK8_LargeK_LargeMN)
|
||||
{
|
||||
// K=4096, larger M and N
|
||||
this->run_test_with_validation(48, 192, 4096, 8);
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using ABQuantGrouped =
|
||||
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
|
||||
using GroupSize1x1x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize1x128x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for ABQuant split-K tests - Prefill shape
|
||||
// GemmConfigPrefill: M_Tile=128, N_Tile=128, K_Tile=128, kPadK=false
|
||||
// Constraints: M % 128 == 0, N % 128 == 0, K % (k_batch * 128) == 0, K >= 2 * k_batch * 128
|
||||
//
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantSplitKPrefillTypes = ::testing::Types<
|
||||
// GroupSize 1x1x128 (kK=128 for both A and B, kN=1)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPrefill, GroupSize1x1x128, GroupSize1x1x128, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigPrefill, GroupSize1x1x128, GroupSize1x1x128, ColumnMajor>,
|
||||
// GroupSize 1x128x128 for B (kK=128, kN=128)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPrefill, GroupSize1x1x128, GroupSize1x128x128, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigPrefill, GroupSize1x1x128, GroupSize1x128x128, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for ABQuant split-K Prefill
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantSplitKPrefillTypes);
|
||||
|
||||
// ---- k_batch=2 ----------------------------------------------------------------
|
||||
// Note: K=256 (= 2*K_Tile) excluded: KRead=K_Tile=128, per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK2_MedK_BaseShape)
|
||||
{
|
||||
// K=1024=8*128
|
||||
this->run_test_with_validation(128, 128, 1024, 2);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_WideN)
|
||||
{
|
||||
// K=2048, wider N
|
||||
this->run_test_with_validation(128, 256, 2048, 2);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_TallM)
|
||||
{
|
||||
// K=4096, taller M
|
||||
this->run_test_with_validation(256, 128, 4096, 2);
|
||||
}
|
||||
|
||||
// ---- k_batch=3 ----------------------------------------------------------------
|
||||
// Note: K=384 (= 3*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK3_MedK_BaseShape)
|
||||
{
|
||||
// K=768=6*128
|
||||
this->run_test_with_validation(128, 128, 768, 3);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK3_LargeK_BaseShape)
|
||||
{
|
||||
// K=3072=24*128
|
||||
this->run_test_with_validation(128, 128, 3072, 3);
|
||||
}
|
||||
|
||||
// ---- k_batch=4 ----------------------------------------------------------------
|
||||
// Note: K=512 (= 4*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK4_MedK_BaseShape)
|
||||
{
|
||||
// K=2048=16*128
|
||||
this->run_test_with_validation(128, 128, 2048, 4);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK4_LargeK_LargeMN)
|
||||
{
|
||||
// K=4096, larger M and N
|
||||
this->run_test_with_validation(256, 256, 4096, 4);
|
||||
}
|
||||
|
||||
// ---- k_batch=5 ----------------------------------------------------------------
|
||||
// Note: K=640 (= 5*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK5_MedK_BaseShape)
|
||||
{
|
||||
// K=1280=10*128
|
||||
this->run_test_with_validation(128, 128, 1280, 5);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK5_LargeK_BaseShape)
|
||||
{
|
||||
// K=2560=20*128
|
||||
this->run_test_with_validation(128, 128, 2560, 5);
|
||||
}
|
||||
|
||||
// ---- k_batch=6 ----------------------------------------------------------------
|
||||
// Note: K=768 (= 6*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK6_LargeK_BaseShape)
|
||||
{
|
||||
// K=3072=24*128
|
||||
this->run_test_with_validation(128, 128, 3072, 6);
|
||||
}
|
||||
|
||||
// ---- k_batch=8 ----------------------------------------------------------------
|
||||
// Note: K=1024 (= 8*K_Tile) excluded: per_batch_num_loop=1.
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK8_MedK_BaseShape)
|
||||
{
|
||||
// K=2048=16*128
|
||||
this->run_test_with_validation(128, 128, 2048, 8);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmABQuant, SplitK8_LargeK_LargeMN)
|
||||
{
|
||||
// K=4096, larger M and N
|
||||
this->run_test_with_validation(256, 256, 4096, 8);
|
||||
}
|
||||
@@ -158,6 +158,10 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
struct GemmConfigPreshuffleBPrefillTransposeC : public GemmConfigPreshuffleBPrefill
|
||||
{
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill
|
||||
{
|
||||
@@ -170,14 +174,18 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <bool TransposeC_ = false>
|
||||
struct GemmConfigPreshuffleBPreshuffleQuantPrefill : public GemmConfigPreshuffleBPrefill
|
||||
{
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
};
|
||||
|
||||
template <bool TransposeC_ = false>
|
||||
struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
|
||||
{
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
@@ -980,7 +988,10 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
void run_test_with_validation(ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t k_batch = 1)
|
||||
{
|
||||
const ck_tile::index_t stride_A =
|
||||
ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{}));
|
||||
@@ -1091,6 +1102,13 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_bqk_bqn.data());
|
||||
}
|
||||
|
||||
// For split-K (k_batch > 1), the kernel uses atomic_add to accumulate partial results
|
||||
// into C. Zero the output buffer before launching so atomic additions start from zero.
|
||||
if(k_batch > 1)
|
||||
{
|
||||
c_m_n_dev_buf.SetZero();
|
||||
}
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
@@ -1098,7 +1116,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales)
|
||||
bq_bqk_bqn_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
k_batch, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
@@ -1136,12 +1154,12 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
// Calculate error tolerances (adjusted for split-K accumulation error)
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
K, k_batch, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
@@ -1151,7 +1169,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "ABQuantGrouped validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
<< ", K=" << K << ", k_batch=" << k_batch;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user