mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK] Add split-K support for ABQuantGrouped in block_scale_gemm (#4816)
## Changes ### Split-K support in `gemm_quant_kernel.hpp` - **`SplitKBatchOffset`**: Added `aq_group_offset` and `aq_k_split_offset` fields (mirroring the existing `bq_*` fields for B) to track each split-K batch's position within the AQ scale tensor. For `ABQuantGrouped`, both offsets are computed from `k_id * KRead` divided by `AQuantGroupSize::kK`. - **`MakeAQBlockWindow`**: Added an `aq_group_offset` parameter (defaulting to 0 for non-split-K paths) so the AQ tensor view's K-group dimension reflects only the remaining K-groups from the split-K offset, consistent with how `MakeBQBlockWindow` handles the BQ tensor. - **`RunGemm`**: Threads the `aq_k_split_offset` through to `MakeAQBlockWindow` when in split-K mode. ### Constraints in `IsSupportedArgument()` Four constraints gate split-K (`k_batch > 1`) for ABQuantGrouped: 1. **Mode check** — split-K is only allowed for `BQuantGrouped` (no preshuffle) or `ABQuantGrouped` (no `APreshuffleQuant`). Any other quant mode with `k_batch > 1` returns `false`. 2. **B quant group alignment** — `KRead` (per-batch K slice) must be divisible by `BQuantGroupSize::kK`. Each batch must operate on complete B quantization groups; a partial group would require splitting a scale value across batches. 3. **A quant group alignment** (new, ABQuantGrouped only) — `KRead` must also be divisible by `AQuantGroupSize::kK` for the same reason applied to the AQ scale tensor. 4. **Minimum 2 K-tile iterations per batch** (new) — The software-pipelined GEMM kernels (CompV3 family) prefetch one tile ahead, so they require `per_batch_num_loop = KRead / KPerBlock >= 2`. When `KRead == KPerBlock` (i.e. each batch is exactly one tile), the prefetch reads into the next batch's memory region and produces incorrect results. Configurations where `K == k_batch * KPerBlock` are therefore rejected. 
### Example update (`run_gemm_quant_example.inc`) Updated the comment above the `IsSupportedArgument` call to document that split-K is now supported for both `BQuantGrouped` (no preshuffle) and `ABQuantGrouped` (no `APreshuffleQuant`). ## Unit Tests Two new test files covering decode and prefill tile shapes across a range of `k_batch` values (2–8), data types (FP8, BF8), and quantization group sizes (1×1×128 and 1×128×128 for B): - `test_gemm_quant_abquant_splitk_decode.cpp` — uses the decode tile shape (M=16, N=64, K_tile=256) - `test_gemm_quant_abquant_splitk_prefill.cpp` — uses the prefill tile shape (M=128, N=128, K_tile=128) Each test calls `run_test_with_validation` which runs the kernel and checks correctness against a CPU reference. Configurations excluded from tests are annotated with comments explaining which constraint they violate (typically the `per_batch_num_loop >= 2` requirement). ## Prerequisites This PR depends on #4429, which must be merged before this can be merged. --------- Co-authored-by: Erwin Terpstra <erwin.terpstra@streamhpc.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -14,7 +14,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
set(EXE_NAME tile_example_gemm_quant)
|
||||
add_executable(${EXE_NAME}
|
||||
gemm_quant.cpp
|
||||
gemm_abquant_quantgrouped.cpp
|
||||
gemm_abquant_quantgrouped_fp8.cpp
|
||||
gemm_abquant_quantgrouped_fp4.cpp
|
||||
gemm_abquant_quantgrouped_bf8.cpp
|
||||
gemm_abquant_quantgrouped_preshuffleb_fp8.cpp
|
||||
gemm_abquant_quantgrouped_preshuffleb_bf8.cpp
|
||||
gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp
|
||||
gemm_aquant_quantgrouped.cpp
|
||||
gemm_aquant_quantgrouped_preshufflequant.cpp
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if defined(CK_USE_GFX950)
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigEightWarps<T>;
|
||||
template <typename T>
|
||||
using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigABQuantPrefill<T>;
|
||||
template <typename T>
|
||||
using GemmConfigPrefill = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::pk_fp4_raw_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<
|
||||
GemmConfigPreshuffleB_ABQuant_Prefill<ck_tile::pk_fp4_raw_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,18 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
#if defined(CK_USE_GFX950)
|
||||
template <typename T, bool TransposeC = true>
|
||||
using GemmConfig = GemmConfigEightWarps<T, TransposeC>;
|
||||
template <typename T, bool TransposeC = true>
|
||||
using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps<T, TransposeC>;
|
||||
#else
|
||||
template <typename T, bool TransposeC = true>
|
||||
using GemmConfig = GemmConfigABQuantPrefill<T, TransposeC>;
|
||||
template <typename T, bool TransposeC = true>
|
||||
using GemmConfigPrefill = GemmConfigPreshuffleB_ABQuant_Prefill<T, TransposeC>;
|
||||
#endif
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
#include "gemm_abquant_quantgrouped.h"
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::bf8_t, false>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
#include "gemm_abquant_quantgrouped.h"
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
lut[hash_multiple_strings(
|
||||
{"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::pk_fp4_raw_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<
|
||||
GemmConfigPreshuffleB_ABQuant_Prefill<ck_tile::pk_fp4_raw_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,70 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
#include "gemm_abquant_quantgrouped.h"
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::fp8_t, false>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
#include "gemm_abquant_quantgrouped.h"
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
#include "gemm_abquant_quantgrouped.h"
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -4,41 +4,41 @@
|
||||
#include "38_block_scale_gemm/gemm_utils.hpp"
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool TransposeC>
|
||||
using GemmConfigPreshuffleB_PreshuffleBQuant =
|
||||
GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill<T>;
|
||||
GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill<T, TransposeC>;
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
auto& lut = get_kernel_lut();
|
||||
lut[hash_multiple_strings({"fp8", "abquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<
|
||||
GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t, false>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "abquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<
|
||||
GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t, true>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -223,7 +223,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
template <typename PrecType, bool TransposeC_ = true>
|
||||
struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
@@ -231,17 +231,17 @@ struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQua
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool TransposeC = true;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
template <typename PrecType, bool TransposeC_ = true>
|
||||
struct GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill
|
||||
: public GemmConfigPreshuffleB_ABQuant_Prefill<PrecType>
|
||||
: public GemmConfigPreshuffleB_ABQuant_Prefill<PrecType, TransposeC_>
|
||||
{
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
template <typename PrecType, bool TransposeC_ = true>
|
||||
struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
@@ -249,7 +249,7 @@ struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuan
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool TransposeC = true;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -271,11 +271,11 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
// static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
template <typename PrecType, bool TransposeC_ = true>
|
||||
struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool TransposeC = true;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
};
|
||||
|
||||
// Used for A=16bit and B=8bit. The warp tile has KPack=16
|
||||
@@ -296,8 +296,8 @@ struct GemmConfigMixedPrecision : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType>
|
||||
template <typename PrecType, bool TransposeC_ = true>
|
||||
struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType, TransposeC_>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
|
||||
@@ -308,12 +308,11 @@ struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType>
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType) * K_Warp;
|
||||
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool TransposeC = true;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps<PrecType>
|
||||
template <typename PrecType, bool TransposeC_ = true>
|
||||
struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps<PrecType, TransposeC_>
|
||||
{
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
@@ -235,8 +235,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
// Split-K validation is handled by Kernel::IsSupportedArgument
|
||||
// Split-K is only supported for BQuantGrouped without preshuffle
|
||||
// Split-K validation is handled by Kernel::IsSupportedArgument.
|
||||
// Split-K is supported for:
|
||||
// - BQuantGrouped without preshuffle
|
||||
// - ABQuantGrouped without APreshuffleQuant
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
|
||||
@@ -19,13 +19,13 @@ template <typename TileDistributedSpan_, // tile_distributed_span<...>
|
||||
>
|
||||
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
|
||||
{
|
||||
using DstrSpanImpl = typename remove_cvref_t<TileDistributedSpan_>::Impl;
|
||||
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
|
||||
|
||||
if constexpr(DstrSpanImpl::size() == 0) // handle the 0-dim span case
|
||||
f(detail::make_tile_distributed_index(sequence<>{}));
|
||||
else
|
||||
static_ford<DstrSpanImpl>{}(
|
||||
[&](auto dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)); });
|
||||
static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
|
||||
constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
|
||||
|
||||
f(dstr_idx);
|
||||
});
|
||||
}
|
||||
|
||||
// unpacked span, this version support span with unpack(multi-arg) functor
|
||||
|
||||
@@ -15,7 +15,7 @@ namespace ck_tile {
|
||||
// B is block window on block distributed tensor.
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename BlockPolicy_>
|
||||
struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase
|
||||
{
|
||||
private:
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
@@ -121,6 +121,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
};
|
||||
|
||||
public:
|
||||
using Base = BlockGemmQuantBase;
|
||||
using Traits = GemmTraits_<Problem_, BlockPolicy_>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
|
||||
@@ -217,22 +218,6 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto q_block_tensor = aq_block_tensor;
|
||||
constexpr bool SimpleDequant =
|
||||
Traits::NQPerBlock == 1 &&
|
||||
AccTensor::get_distributed_spans()[I0].impl_.size() == 0; // c_transpose
|
||||
if constexpr(SimpleDequant)
|
||||
{
|
||||
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
|
||||
sweep_tile_span(aq_spans[I0], [&](auto im) {
|
||||
sweep_tile_span(aq_spans[I1], [&](auto ik) {
|
||||
q_block_tensor(make_tuple(im, ik)) *=
|
||||
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
|
||||
});
|
||||
});
|
||||
}
|
||||
// hot loop:
|
||||
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
zero_accumulators();
|
||||
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
@@ -265,29 +250,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
}
|
||||
});
|
||||
});
|
||||
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
|
||||
auto nIter) {
|
||||
if constexpr(SimpleDequant)
|
||||
{
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
|
||||
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
|
||||
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
|
||||
c_ref += acc_val * q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
|
||||
aq_block_tensor);
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(aq_block_tensor);
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
@@ -305,9 +270,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
return nIter * KPerBlockBQ + kQScale;
|
||||
}
|
||||
}();
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float b_scale_reg_f =
|
||||
aq_picker.template cvt_scale_to_fp32<BQDataType>(scale_reg);
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float b_scale_reg_f = Base::cvt_scale_to_fp32<BQDataType>(scale_reg);
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
float a_scale_reg_f = aq_picker.template pick<c_row>();
|
||||
@@ -315,7 +279,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
|
||||
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -291,66 +291,37 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
"C block tensor data type!");
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
|
||||
// Start from AQ block tensor and then scale it using BQ; this represents
|
||||
// the combined A/B quantization scales for the block.
|
||||
auto q_block_tensor = aq_block_tensor;
|
||||
constexpr bool SimpleDequant =
|
||||
Traits::NQPerBlock == 1 &&
|
||||
CWarpTensor::get_distributed_spans()[I0{}].impl_.size() == 0; // c_transpose
|
||||
if constexpr(SimpleDequant)
|
||||
{
|
||||
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
|
||||
sweep_tile_span(aq_spans[I0{}], [&](auto im) {
|
||||
sweep_tile_span(aq_spans[I1{}], [&](auto ik) {
|
||||
q_block_tensor(make_tuple(im, ik)) *=
|
||||
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// hot loop:
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
|
||||
auto nIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
|
||||
if constexpr(kIterInQScale == 0)
|
||||
{
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
});
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() =
|
||||
a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
if constexpr(SimpleDequant)
|
||||
{
|
||||
constexpr auto cw_spans = CWarpTensor::get_distributed_spans();
|
||||
sweep_tile_span(cw_spans[I1{}], [&](auto in) {
|
||||
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
|
||||
constexpr auto block_idx_n = detail::make_tile_distributed_index(
|
||||
merge_sequences(sequence<nIter>{}, in.impl_));
|
||||
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
|
||||
constexpr auto empty_idx = tile_distributed_index<>{};
|
||||
c_block_tensor(make_tuple(block_idx_m, block_idx_n)) +=
|
||||
c_warp_tensor(make_tuple(empty_idx, in)) *
|
||||
q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
if constexpr(kIterInQScale == 0)
|
||||
{
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
@@ -435,7 +406,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
b_scale_reg_f);
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -448,18 +448,46 @@ struct QuantGemmKernel
|
||||
// offset = bq_group_offset
|
||||
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
|
||||
}
|
||||
|
||||
aq_group_offset = 0;
|
||||
aq_k_split_offset = 0;
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
|
||||
{
|
||||
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
|
||||
// Compute AQ K-group offset for this split-K batch.
|
||||
// AQ tensor layout is RowMajor [M, QK_A] with stride [stride_AQ, 1].
|
||||
// Advancing to column aq_group_offset means a pointer offset of aq_group_offset
|
||||
// elements (column stride = 1).
|
||||
const index_t k_offset_aq = amd_wave_read_first_lane(k_id * KRead);
|
||||
aq_group_offset = amd_wave_read_first_lane(k_offset_aq / AQuantGroupSize::kK);
|
||||
aq_k_split_offset = amd_wave_read_first_lane(aq_group_offset);
|
||||
|
||||
// Compute BQ K-group offset for this split-K batch.
|
||||
// BQ tensor layout is ColumnMajor [N/kN, K/kK] with stride [K/kK, 1] for
|
||||
// ABQuantGrouped. Advancing to column bq_group_offset means a pointer offset of
|
||||
// bq_group_offset elements (column stride = 1).
|
||||
const index_t k_offset_bq = amd_wave_read_first_lane(k_id * KRead);
|
||||
bq_group_offset = amd_wave_read_first_lane(k_offset_bq / BQuantGroupSize::kK);
|
||||
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
|
||||
}
|
||||
else
|
||||
{
|
||||
bq_group_offset = 0;
|
||||
bq_k_split_offset = 0;
|
||||
aq_group_offset = 0;
|
||||
aq_k_split_offset = 0;
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension)
|
||||
index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride)
|
||||
index_t aq_group_offset; // Logical offset in K-groups for AQ (K/kK dimension)
|
||||
index_t aq_k_split_offset; // Memory pointer offset for AQ
|
||||
index_t bq_group_offset; // Logical offset in K-groups for BQ (K/kK dimension)
|
||||
index_t bq_k_split_offset; // Memory pointer offset for BQ (accounting for layout/stride)
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
@@ -532,7 +560,8 @@ struct QuantGemmKernel
|
||||
CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
const index_t i_n,
|
||||
const index_t aq_group_offset = 0)
|
||||
{
|
||||
// Step 1: Create tensor view for AQ
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
@@ -615,11 +644,14 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
|
||||
{
|
||||
// For split-K, aq_ptr is already offset by aq_k_split_offset elements.
|
||||
// The remaining K-groups from this offset position = QK_A - aq_group_offset.
|
||||
const index_t remaining_qk_a = kargs.QK_A - aq_group_offset;
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK_A),
|
||||
make_tuple(kargs.M, remaining_qk_a),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
@@ -628,9 +660,8 @@ struct QuantGemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK_A),
|
||||
make_tuple(kargs.M, remaining_qk_a),
|
||||
make_tuple(1, kargs.stride_AQ),
|
||||
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
@@ -1100,26 +1131,32 @@ struct QuantGemmKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
|
||||
{
|
||||
// Split-K is supported for BQuantGrouped mode without preshuffle
|
||||
// Split-K is supported for BQuantGrouped (without preshuffle) and
|
||||
// ABQuantGrouped (without APreshuffleQuant) modes.
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
constexpr bool is_bquant_non_preshuffle =
|
||||
(kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant;
|
||||
if constexpr(!is_bquant_non_preshuffle)
|
||||
constexpr bool is_abquant_non_preshuffle =
|
||||
(kQuantType == QuantType::ABQuantGrouped) && !APreshuffleQuant;
|
||||
constexpr bool is_splitk_supported =
|
||||
is_bquant_non_preshuffle || is_abquant_non_preshuffle;
|
||||
|
||||
if constexpr(!is_splitk_supported)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 ! "
|
||||
"Split-K only supported for BQuantGrouped without preshuffle.");
|
||||
"Split-K is supported for BQuantGrouped without preshuffle "
|
||||
"and ABQuantGrouped without APreshuffleQuant.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; // per-batch K read size
|
||||
constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
@@ -1137,22 +1174,67 @@ struct QuantGemmKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
// Constraint 2: KRead must align with quantization group boundaries.
|
||||
// Each split-K batch reads KRead consecutive K elements. If KRead is not
|
||||
// a multiple of BQuantGroupSize::kK, the batch will span partial quantization
|
||||
// groups, requiring split access to a quantization scale. This violates the
|
||||
// atomic processing requirement where each batch must work with complete groups.
|
||||
if(KRead % BQuantGroupSize::kK != 0)
|
||||
// Constraint 2: KRead must align with B quantization group boundaries.
|
||||
if constexpr(is_bquant_non_preshuffle || is_abquant_non_preshuffle)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
if(KRead % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Split-K batch size must be aligned with quantization group "
|
||||
"size! KRead=" +
|
||||
std::to_string(KRead) +
|
||||
" is not divisible by BQuantGroupSize::kK=" +
|
||||
std::to_string(BQuantGroupSize::kK));
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Split-K batch size must be aligned with B quantization group "
|
||||
"size! KRead=" +
|
||||
std::to_string(KRead) +
|
||||
" is not divisible by BQuantGroupSize::kK=" +
|
||||
std::to_string(BQuantGroupSize::kK));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Constraint 3: KRead must align with A quantization group boundaries
|
||||
// (only needed for ABQuantGrouped since AQ also indexes into K).
|
||||
if constexpr(is_abquant_non_preshuffle)
|
||||
{
|
||||
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
if(KRead % AQuantGroupSize::kK != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Split-K batch size must be aligned with A quantization group "
|
||||
"size! KRead=" +
|
||||
std::to_string(KRead) +
|
||||
" is not divisible by AQuantGroupSize::kK=" +
|
||||
std::to_string(AQuantGroupSize::kK));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Constraint 4: per-batch K must span at least 2 K_Tile iterations.
|
||||
// The software-pipelined GEMM kernels (CompV3 family) prefetch one tile
|
||||
// ahead and require num_loop >= 2 per batch. When KRead == KPerBlock
|
||||
// (i.e. per_batch_num_loop == 1) the prefetch would read the tile
|
||||
// belonging to the next split-K batch, producing incorrect results.
|
||||
{
|
||||
const index_t per_batch_num_loop =
|
||||
KRead / static_cast<index_t>(TilePartitioner::KPerBlock);
|
||||
if(per_batch_num_loop < 2)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Split-K requires at least 2 K-tile iterations per batch. "
|
||||
"KRead=" +
|
||||
std::to_string(KRead) + " < 2 * KPerBlock=" +
|
||||
std::to_string(2 *
|
||||
static_cast<index_t>(TilePartitioner::KPerBlock)) +
|
||||
". Increase K or decrease k_batch.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1243,6 +1325,18 @@ struct QuantGemmKernel
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// For RowMajor C, M is the row dimension — check M alignment here because
|
||||
// ALayout=RowMajor does not check M (it only checks K), leaving a gap for
|
||||
// the RowMajorA + RowMajorC combination.
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
@@ -1315,7 +1409,10 @@ struct QuantGemmKernel
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
// Note: Pass aq_group_offset so the tensor view dimension reflects
|
||||
// the remaining K-groups from the split-K offset position.
|
||||
const auto& aq_block_window = MakeAQBlockWindow(
|
||||
aq_ptr, kargs, block_idx_m, block_idx_n, splitk_batch_offset.aq_group_offset);
|
||||
// Note: Pass bq_group_offset so the tensor view dimension reflects
|
||||
// the remaining K-groups from the split-K offset position.
|
||||
const auto& bq_block_window = MakeBQBlockWindow(
|
||||
@@ -1445,7 +1542,10 @@ struct QuantGemmKernel
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
|
||||
// For ABQuantGrouped split-K, aq_ptr is offset by aq_k_split_offset elements to point
|
||||
// to the start of this batch's AQ K-groups (aq_group_offset columns in RowMajor AQ).
|
||||
const AQDataType* aq_ptr =
|
||||
static_cast<const AQDataType*>(kargs.aq_ptr) + splitk_batch_offset.aq_k_split_offset;
|
||||
const BQDataType* bq_ptr =
|
||||
static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
@@ -108,14 +108,10 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
/**
|
||||
* @tparam nloop The number of iterations in the hot loop,
|
||||
* used to normalize scheduling costs.
|
||||
*/
|
||||
|
||||
template <index_t nloop>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
static_assert(nloop > 0, "nloop must be greater than 0");
|
||||
// Estimated number of VMEM vector loads for A per block:
|
||||
// total A bytes / (threads per block * vector width)
|
||||
constexpr index_t Aload_inst =
|
||||
@@ -138,13 +134,12 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
// Total VMEM load instructions (A + B + quant data)
|
||||
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
|
||||
// Approximate number of LDS reads per block
|
||||
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle / nloop;
|
||||
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
|
||||
// Approximate number of LDS writes per block
|
||||
// (e.g., writing A from VMEM into LDS once per A load)
|
||||
constexpr index_t ds_write_inst = Aload_inst;
|
||||
// Number of MFMA instructions per wave for one block tile:
|
||||
constexpr index_t mfma_inst =
|
||||
((kMPerBlock / WG::kM) / nloop) * ((kNPerBlock / WG::kN) / nloop);
|
||||
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
|
||||
// How often (in MFMA units) we should insert DS (LDS) operations.
|
||||
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
|
||||
// How often (in MFMA units) we should insert VMEM buffer loads.
|
||||
@@ -181,7 +176,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
}
|
||||
// Always mark some VALU work in the loop to reflect auxiliary scalar
|
||||
// or vector ALU instructions that coexist with MFMA (Blockscale calculation).
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -409,6 +404,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
|
||||
// Prefetch A1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
@@ -437,7 +433,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
while(iCounter > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// Prefill A(2i+1) ds_write
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
@@ -465,14 +461,10 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// prefetch Q(2i+1)
|
||||
aq_block_tile_2 = load_tile(aq_copy_dram_window);
|
||||
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
|
||||
bq_block_tile_2 = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step);
|
||||
|
||||
// Preload A(2i+1) ds_read
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
@@ -494,8 +486,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// prefetch Q(2i+1)
|
||||
aq_block_tile = load_tile(aq_copy_dram_window);
|
||||
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
@@ -517,7 +507,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
aq_block_tile_2,
|
||||
bq_block_tile_2,
|
||||
a_warp_windows_pong);
|
||||
// Preload A(2i+2) ds_read
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
@@ -557,7 +547,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
aq_block_tile,
|
||||
bq_block_tile,
|
||||
a_warp_windows_ping);
|
||||
// Preload A ds_read
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
|
||||
@@ -81,6 +81,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# ABQuant split-K tests
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode
|
||||
test_gemm_quant_abquant_splitk_decode.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_prefill
|
||||
test_gemm_quant_abquant_splitk_prefill.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base
|
||||
test_gemm_quant_abquant_a4w4_base.cpp
|
||||
)
|
||||
@@ -268,7 +279,14 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_tile_gemm_quant_abquant_base
|
||||
test_tile_gemm_quant_abquant_padding
|
||||
test_tile_gemm_quant_abquant_preshuffle
|
||||
test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant
|
||||
test_tile_gemm_quant_abquant_preshuffleQuant
|
||||
test_tile_gemm_quant_abquant_a4w4_base
|
||||
test_tile_gemm_quant_abquant_a4w4_padding
|
||||
test_tile_gemm_quant_abquant_a4w4_preshuffle
|
||||
# ABQuant split-K tests
|
||||
test_tile_gemm_quant_abquant_splitk_decode
|
||||
test_tile_gemm_quant_abquant_splitk_prefill
|
||||
# BQuant tests
|
||||
test_tile_gemm_quant_bquant_1d_128
|
||||
test_tile_gemm_quant_bquant_1d_64
|
||||
@@ -276,6 +294,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_tile_gemm_quant_bquant_2d_medium_n
|
||||
test_tile_gemm_quant_bquant_2d_large_n
|
||||
test_tile_gemm_quant_bquant_transpose
|
||||
# BQuant split-K tests
|
||||
test_tile_gemm_quant_bquant_splitk_decode
|
||||
test_tile_gemm_quant_bquant_splitk_prefill
|
||||
# BQuant preshuffle tests
|
||||
test_tile_gemm_quant_bquant_preshuffle_decode_1d
|
||||
test_tile_gemm_quant_bquant_preshuffle_prefill_1d
|
||||
|
||||
@@ -28,7 +28,7 @@ using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
// 1D BScales; PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
@@ -36,12 +36,13 @@ using ABQuantTypes = ::testing::Types<
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
// 2D B-scales; PreshuffleQuant = false && TransposeC = true (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -28,9 +28,11 @@ using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantPreshuffleBTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
// 1D B-scales; PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
|
||||
/// 2D B-scales; PreshuffleQuant = false && TransposeC = true (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefillTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -28,8 +28,8 @@ using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
// Type combinations for ABQuant with preshuffled B and preshuffled B quant scales.
// GemmConfigPreshuffleBPreshuffleQuantPrefill<TransposeC_> selects whether C is
// written transposed; both TransposeC variants are covered below.
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
//               QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantPreshuffleQuantTypes = ::testing::Types<
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill<false>, GroupSize, GroupSize, ColumnMajor>,
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill<true>, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
// Quantization mode under test: grouped quant scales on both the A and B operands.
using ABQuantGrouped =
    std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// Quant-group shapes (sequence order matches the kM/kN/kK naming used in the
// comments below): 1x1x128 = per-row groups of 128 along K; 1x128x128 = 2D
// groups of 128 along N by 128 along K.
using GroupSize1x1x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using GroupSize1x128x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for ABQuant split-K tests - Decode shape
// GemmConfigDecode: M_Tile=16, N_Tile=64, K_Tile=256, kPadK=false
// Constraints: M % 16 == 0, N % 64 == 0, K % (k_batch * 256) == 0
// Split-K alignment: KRead (a multiple of K_Tile=256) is divisible by the quant
// group kK=128 for both A and B, so every batch covers whole quant groups.
//
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
//               QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantSplitKDecodeTypes = ::testing::Types<
    // GroupSize 1x1x128 (kK=128 for both A and B, kN=1)
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigDecode, GroupSize1x1x128, GroupSize1x1x128, ColumnMajor>,
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigDecode, GroupSize1x1x128, GroupSize1x1x128, ColumnMajor>,
    // GroupSize 1x128x128 for B (kK=128, kN=128)
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigDecode, GroupSize1x1x128, GroupSize1x128x128, ColumnMajor>,
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigDecode, GroupSize1x1x128, GroupSize1x128x128, ColumnMajor>
>;
// clang-format on
|
||||
|
||||
// Test suite for ABQuant split-K Decode.
// Registers the decode-shape split-K type list with the shared ABQuant fixture.
// NOTE(review): this reuses the TestCkTileGemmABQuant suite name; presumably each
// registration lives in its own test binary — confirm this file is not linked
// together with another TU that registers the same suite.
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantSplitKDecodeTypes);
|
||||
|
||||
// ---- k_batch=2 ----------------------------------------------------------------
// Note: K=512 (= 2*K_Tile) is excluded because KRead=K_Tile=256, giving
// per_batch_num_loop=1 which the software-pipelined kernel cannot handle.

TYPED_TEST(TestCkTileGemmABQuant, SplitK2_MedK_BaseShape)
{
    // K=1024=4*256: standard decode shape; per-batch K=512 -> 2 K-tile iterations
    this->run_test_with_validation(32, 64, 1024, 2);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_WideN)
{
    // K=2048, larger N (multiple of N_Tile=64); per-batch K=1024 -> 4 iterations
    this->run_test_with_validation(32, 256, 2048, 2);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_TallM)
{
    // K=4096, larger M (multiple of M_Tile=16); per-batch K=2048 -> 8 iterations
    this->run_test_with_validation(64, 64, 4096, 2);
}
|
||||
|
||||
// ---- k_batch=3 ----------------------------------------------------------------
// Note: K=768 (= 3*K_Tile) excluded: per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK3_MedK_BaseShape)
{
    // K=1536=6*256; per-batch K=512 -> 2 K-tile iterations
    this->run_test_with_validation(32, 64, 1536, 3);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK3_LargeK_BaseShape)
{
    // K=3072=12*256; per-batch K=1024 -> 4 K-tile iterations
    this->run_test_with_validation(32, 64, 3072, 3);
}
|
||||
|
||||
// ---- k_batch=4 ----------------------------------------------------------------
// Note: K=1024 (= 4*K_Tile) excluded: per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK4_MedK_BaseShape)
{
    // K=2048=8*256; per-batch K=512 -> 2 K-tile iterations
    this->run_test_with_validation(32, 64, 2048, 4);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK4_LargeK_WideN)
{
    // K=4096, wider N; per-batch K=1024 -> 4 K-tile iterations
    this->run_test_with_validation(32, 128, 4096, 4);
}
|
||||
|
||||
// ---- k_batch=5 ----------------------------------------------------------------
// Non-power-of-two batch count. Note: K=1280 (= 5*K_Tile) excluded:
// per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK5_MedK_BaseShape)
{
    // K=2560=10*256; per-batch K=512 -> 2 K-tile iterations
    this->run_test_with_validation(32, 64, 2560, 5);
}
|
||||
|
||||
// ---- k_batch=6 ----------------------------------------------------------------
// Note: K=1536 (= 6*K_Tile) excluded: per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK6_LargeK_BaseShape)
{
    // K=3072=12*256; per-batch K=512 -> 2 K-tile iterations
    this->run_test_with_validation(32, 64, 3072, 6);
}
|
||||
|
||||
// ---- k_batch=8 ----------------------------------------------------------------
// Note: K=2048 (= 8*K_Tile) excluded: per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK8_LargeK_BaseShape)
{
    // K=4096=16*256; per-batch K=512 -> 2 K-tile iterations
    this->run_test_with_validation(32, 64, 4096, 8);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK8_LargeK_LargeMN)
{
    // K=4096, larger M and N (M=48 is a multiple of M_Tile=16, N=192 of N_Tile=64)
    this->run_test_with_validation(48, 192, 4096, 8);
}
|
||||
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
// Quantization mode under test: grouped quant scales on both the A and B operands.
using ABQuantGrouped =
    std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// Quant-group shapes (sequence order matches the kM/kN/kK naming used in the
// comments below): 1x1x128 = per-row groups of 128 along K; 1x128x128 = 2D
// groups of 128 along N by 128 along K.
using GroupSize1x1x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using GroupSize1x128x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for ABQuant split-K tests - Prefill shape
// GemmConfigPrefill: M_Tile=128, N_Tile=128, K_Tile=128, kPadK=false
// Constraints: M % 128 == 0, N % 128 == 0, K % (k_batch * 128) == 0
// Split-K alignment: KRead (a multiple of K_Tile=128) is divisible by the quant
// group kK=128 for both A and B, so every batch covers whole quant groups.
//
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
//               QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantSplitKPrefillTypes = ::testing::Types<
    // GroupSize 1x1x128 (kK=128 for both A and B, kN=1)
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPrefill, GroupSize1x1x128, GroupSize1x1x128, ColumnMajor>,
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigPrefill, GroupSize1x1x128, GroupSize1x1x128, ColumnMajor>,
    // GroupSize 1x128x128 for B (kK=128, kN=128)
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPrefill, GroupSize1x1x128, GroupSize1x128x128, ColumnMajor>,
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigPrefill, GroupSize1x1x128, GroupSize1x128x128, ColumnMajor>
>;
// clang-format on
|
||||
|
||||
// Test suite for ABQuant split-K Prefill.
// Registers the prefill-shape split-K type list with the shared ABQuant fixture.
// NOTE(review): this reuses the TestCkTileGemmABQuant suite name; presumably each
// registration lives in its own test binary — confirm this file is not linked
// together with another TU that registers the same suite.
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantSplitKPrefillTypes);
|
||||
|
||||
// ---- k_batch=2 ----------------------------------------------------------------
// Note: K=256 (= 2*K_Tile) excluded: KRead=K_Tile=128, per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK2_MedK_BaseShape)
{
    // K=1024=8*128; per-batch K=512 -> 4 K-tile iterations
    this->run_test_with_validation(128, 128, 1024, 2);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_WideN)
{
    // K=2048, wider N (multiple of N_Tile=128)
    this->run_test_with_validation(128, 256, 2048, 2);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_TallM)
{
    // K=4096, taller M (multiple of M_Tile=128)
    this->run_test_with_validation(256, 128, 4096, 2);
}
|
||||
|
||||
// ---- k_batch=3 ----------------------------------------------------------------
// Note: K=384 (= 3*K_Tile) excluded: per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK3_MedK_BaseShape)
{
    // K=768=6*128; per-batch K=256 -> 2 K-tile iterations (minimum supported)
    this->run_test_with_validation(128, 128, 768, 3);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK3_LargeK_BaseShape)
{
    // K=3072=24*128; per-batch K=1024 -> 8 K-tile iterations
    this->run_test_with_validation(128, 128, 3072, 3);
}
|
||||
|
||||
// ---- k_batch=4 ----------------------------------------------------------------
// Note: K=512 (= 4*K_Tile) excluded: per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK4_MedK_BaseShape)
{
    // K=2048=16*128; per-batch K=512 -> 4 K-tile iterations
    this->run_test_with_validation(128, 128, 2048, 4);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK4_LargeK_LargeMN)
{
    // K=4096, larger M and N
    this->run_test_with_validation(256, 256, 4096, 4);
}
|
||||
|
||||
// ---- k_batch=5 ----------------------------------------------------------------
// Non-power-of-two batch count. Note: K=640 (= 5*K_Tile) excluded:
// per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK5_MedK_BaseShape)
{
    // K=1280=10*128; per-batch K=256 -> 2 K-tile iterations (minimum supported)
    this->run_test_with_validation(128, 128, 1280, 5);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK5_LargeK_BaseShape)
{
    // K=2560=20*128; per-batch K=512 -> 4 K-tile iterations
    this->run_test_with_validation(128, 128, 2560, 5);
}
|
||||
|
||||
// ---- k_batch=6 ----------------------------------------------------------------
// Note: K=768 (= 6*K_Tile) excluded: per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK6_LargeK_BaseShape)
{
    // K=3072=24*128; per-batch K=512 -> 4 K-tile iterations
    this->run_test_with_validation(128, 128, 3072, 6);
}
|
||||
|
||||
// ---- k_batch=8 ----------------------------------------------------------------
// Note: K=1024 (= 8*K_Tile) excluded: per_batch_num_loop=1.

TYPED_TEST(TestCkTileGemmABQuant, SplitK8_MedK_BaseShape)
{
    // K=2048=16*128; per-batch K=256 -> 2 K-tile iterations (minimum supported)
    this->run_test_with_validation(128, 128, 2048, 8);
}

TYPED_TEST(TestCkTileGemmABQuant, SplitK8_LargeK_LargeMN)
{
    // K=4096, larger M and N
    this->run_test_with_validation(256, 256, 4096, 8);
}
|
||||
@@ -158,6 +158,10 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
// Prefill config with preshuffled B that additionally writes C transposed.
struct GemmConfigPreshuffleBPrefillTransposeC : public GemmConfigPreshuffleBPrefill
{
    static constexpr bool TransposeC = true;
};
|
||||
|
||||
struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill
|
||||
{
|
||||
@@ -170,14 +174,18 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
// Prefill config with preshuffled B that also preshuffles the B quant scales.
// TransposeC_ selects whether C is written transposed (default: false).
template <bool TransposeC_ = false>
struct GemmConfigPreshuffleBPreshuffleQuantPrefill : public GemmConfigPreshuffleBPrefill
{
    static constexpr bool BPreshuffleQuant = true;
    static constexpr bool TransposeC = TransposeC_;
};
|
||||
|
||||
// Decode config with preshuffled B that also preshuffles the B quant scales.
// TransposeC_ selects whether C is written transposed (default: false).
template <bool TransposeC_ = false>
struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
{
    static constexpr bool BPreshuffleQuant = true;
    static constexpr bool TransposeC = TransposeC_;
};
|
||||
|
||||
template <typename Tuple>
|
||||
@@ -980,7 +988,10 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
void run_test_with_validation(ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t k_batch = 1)
|
||||
{
|
||||
const ck_tile::index_t stride_A =
|
||||
ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{}));
|
||||
@@ -1091,6 +1102,13 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_bqk_bqn.data());
|
||||
}
|
||||
|
||||
// For split-K (k_batch > 1), the kernel uses atomic_add to accumulate partial results
|
||||
// into C. Zero the output buffer before launching so atomic additions start from zero.
|
||||
if(k_batch > 1)
|
||||
{
|
||||
c_m_n_dev_buf.SetZero();
|
||||
}
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
@@ -1098,7 +1116,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales)
|
||||
bq_bqk_bqn_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
k_batch, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
@@ -1136,12 +1154,12 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
// Calculate error tolerances (adjusted for split-K accumulation error)
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
K, k_batch, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
@@ -1151,7 +1169,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "ABQuantGrouped validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
<< ", K=" << K << ", k_batch=" << k_batch;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user