diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 97e719177f..b9f94b7491 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -14,7 +14,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") set(EXE_NAME tile_example_gemm_quant) add_executable(${EXE_NAME} gemm_quant.cpp - gemm_abquant_quantgrouped.cpp + gemm_abquant_quantgrouped_fp8.cpp + gemm_abquant_quantgrouped_fp4.cpp + gemm_abquant_quantgrouped_bf8.cpp + gemm_abquant_quantgrouped_preshuffleb_fp8.cpp + gemm_abquant_quantgrouped_preshuffleb_bf8.cpp gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp deleted file mode 100644 index 4ece442158..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "run_gemm_quant_example.inc" - -#if defined(CK_USE_GFX950) -template -using GemmConfig = GemmConfigEightWarps; -template -using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps; -#else -template -using GemmConfig = GemmConfigABQuantPrefill; -template -using GemmConfigPrefill = GemmConfigPreshuffleB_ABQuant_Prefill; -#endif - -static auto _ = []() { - auto& lut = get_kernel_lut(); - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize 
= ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return 
run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type< - GemmConfigPreshuffleB_ABQuant_Prefill, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - 
ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - return 0; -}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.h b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.h new file mode 100644 index 0000000000..2b4c381cdc --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.h @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "gemm_utils.hpp" + +#if defined(CK_USE_GFX950) +template +using GemmConfig = GemmConfigEightWarps; +template +using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps; +#else +template +using GemmConfig = GemmConfigABQuantPrefill; +template +using GemmConfigPrefill = GemmConfigPreshuffleB_ABQuant_Prefill; +#endif diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp new file mode 100644 index 0000000000..85ef4ddb89 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" +#include "gemm_abquant_quantgrouped.h" + +static auto _ = []() { + auto& lut = get_kernel_lut(); + + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp new file mode 100644 index 0000000000..fa515518be --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" +#include "gemm_abquant_quantgrouped.h" + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings( + {"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type< + GemmConfigPreshuffleB_ABQuant_Prefill, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp new file mode 100644 index 0000000000..fdf8b35b7e --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp @@ -0,0 +1,70 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" +#include "gemm_abquant_quantgrouped.h" + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + 
ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_bf8.cpp new file mode 100644 index 0000000000..667ea8efbf --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_bf8.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" +#include "gemm_abquant_quantgrouped.h" + +static auto _ = []() { + auto& lut = get_kernel_lut(); + + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp8.cpp new file mode 100644 index 0000000000..894eb13328 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp8.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its 
affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" +#include "gemm_abquant_quantgrouped.h" + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp index e854a3bcc8..41f77f4b4b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp @@ -4,41 +4,41 @@ #include "38_block_scale_gemm/gemm_utils.hpp" #include "run_gemm_quant_example.inc" -template +template using GemmConfigPreshuffleB_PreshuffleBQuant = - GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill; + GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill; static auto _ = []() { - auto& lut = get_kernel_lut(); - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "preshufflequant", - 
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings({"fp8", "abquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type< + GemmConfigPreshuffleB_PreshuffleBQuant, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", "abquant", "preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type< + GemmConfigPreshuffleB_PreshuffleBQuant, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; return 0; }(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index db3f4c6e17..fef00e993b 100644 
--- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -223,7 +223,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill static constexpr bool BPreshuffleQuant = true; }; -template +template struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill { static constexpr ck_tile::index_t M_Warp = 2; @@ -231,17 +231,17 @@ struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQua static constexpr ck_tile::index_t K_Warp = 1; static constexpr bool kPadK = false; - static constexpr bool TransposeC = true; + static constexpr bool TransposeC = TransposeC_; }; -template +template struct GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill - : public GemmConfigPreshuffleB_ABQuant_Prefill + : public GemmConfigPreshuffleB_ABQuant_Prefill { static constexpr bool BPreshuffleQuant = true; }; -template +template struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill { static constexpr ck_tile::index_t M_Tile = 16; @@ -249,7 +249,7 @@ struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuan static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); static constexpr bool kPadK = false; - static constexpr bool TransposeC = true; + static constexpr bool TransposeC = TransposeC_; }; template @@ -271,11 +271,11 @@ struct GemmConfigQuantPrefill : public GemmConfigBase // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; -template +template struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill { static constexpr bool kPadK = false; - static constexpr bool TransposeC = true; + static constexpr bool TransposeC = TransposeC_; }; // Used for A=16bit and B=8bit. 
The warp tile has KPack=16 @@ -296,8 +296,8 @@ struct GemmConfigMixedPrecision : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 64; }; -template -struct GemmConfigEightWarps : public GemmConfigABQuantPrefill +template +struct GemmConfigEightWarps : public GemmConfigABQuantPrefill { static constexpr ck_tile::index_t M_Warp = 4; static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong! @@ -308,12 +308,11 @@ struct GemmConfigEightWarps : public GemmConfigABQuantPrefill static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType) * K_Warp; static constexpr bool kPadK = false; - static constexpr bool TransposeC = true; static constexpr int kBlockPerCu = 1; }; -template -struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps +template +struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps { static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index da14f85c2c..85faf8b58b 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -235,8 +235,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 blocks = Kernel::BlockSize(); - // Split-K validation is handled by Kernel::IsSupportedArgument - // Split-K is only supported for BQuantGrouped without preshuffle + // Split-K validation is handled by Kernel::IsSupportedArgument. + // Split-K is supported for: + // - BQuantGrouped without preshuffle + // - ABQuantGrouped without APreshuffleQuant if(!Kernel::IsSupportedArgument(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); diff --git a/include/ck_tile/core/tensor/sweep_tile.hpp b/include/ck_tile/core/tensor/sweep_tile.hpp index e074fe4b14..1947ce0289 100644 --- a/include/ck_tile/core/tensor/sweep_tile.hpp +++ b/include/ck_tile/core/tensor/sweep_tile.hpp @@ -19,13 +19,13 @@ template > CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f) { - using DstrSpanImpl = typename remove_cvref_t::Impl; + using DstrSpan = remove_cvref_t; - if constexpr(DstrSpanImpl::size() == 0) // handle the 0-dim span case - f(detail::make_tile_distributed_index(sequence<>{})); - else - static_ford{}( - [&](auto dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)); }); + static_ford{}([&](auto dstr_idx_impl) { + constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl); + + f(dstr_idx); + }); } // unpacked span, this version support span with unpack(multi-arg) functor diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index 1d9512b7f7..32c53d2f18 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -15,7 +15,7 @@ namespace ck_tile { // B is block window on block distributed tensor. 
// C is block distributed tensor template -struct BlockGemmWeightPreshuffleABQuantARegBRegCReg +struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase { private: template @@ -121,6 +121,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg }; public: + using Base = BlockGemmQuantBase; using Traits = GemmTraits_; using Problem = remove_cvref_t; using BlockPolicy = remove_cvref_t; @@ -217,22 +218,6 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg }); }); }; - - auto q_block_tensor = aq_block_tensor; - constexpr bool SimpleDequant = - Traits::NQPerBlock == 1 && - AccTensor::get_distributed_spans()[I0].impl_.size() == 0; // c_transpose - if constexpr(SimpleDequant) - { - constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); - sweep_tile_span(aq_spans[I0], [&](auto im) { - sweep_tile_span(aq_spans[I1], [&](auto ik) { - q_block_tensor(make_tuple(im, ik)) *= - bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); - }); - }); - } - // hot loop: static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { zero_accumulators(); static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) { @@ -265,29 +250,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg } }); }); - static_for_product, number>{}([&](auto mIter, - auto nIter) { - if constexpr(SimpleDequant) - { - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - - constexpr auto block_idx_m = tile_distributed_index{}; - constexpr auto block_idx_kq = tile_distributed_index{}; - - static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { - auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; - const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref += acc_val * q_block_tensor(make_tuple(block_idx_m, block_idx_kq)); - }); - } - else - { - AQPickerCommon aq_picker( - aq_block_tensor); + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + AQPickerCommon 
aq_picker(aq_block_tensor); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { constexpr auto tbuf_offset = number{}, @@ -305,9 +270,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg return nIter * KPerBlockBQ + kQScale; } }(); - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float b_scale_reg_f = - aq_picker.template cvt_scale_to_fp32(scale_reg); + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_reg_f = Base::cvt_scale_to_fp32(scale_reg); static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { float a_scale_reg_f = aq_picker.template pick(); @@ -315,7 +279,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f; }); - } + }); }); }); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index d79bd31489..2c8b7031f5 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -291,66 +291,37 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase "C block tensor data type!"); constexpr auto warp_size = get_warp_size(); - // Start from AQ block tensor and then scale it using BQ; this represents - // the combined A/B quantization scales for the block. 
- auto q_block_tensor = aq_block_tensor; - constexpr bool SimpleDequant = - Traits::NQPerBlock == 1 && - CWarpTensor::get_distributed_spans()[I0{}].impl_.size() == 0; // c_transpose - if constexpr(SimpleDequant) - { - constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); - sweep_tile_span(aq_spans[I0{}], [&](auto im) { - sweep_tile_span(aq_spans[I1{}], [&](auto ik) { - q_block_tensor(make_tuple(im, ik)) *= - bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); - }); - }); - } - // hot loop: - static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { - static_for_product, number>{}([&](auto mIter, - auto nIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { CWarpTensor c_warp_tensor; - static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - if constexpr(kIterInQScale == 0) - { - c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); - } - else - { - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - } - }); + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = + a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - if 
constexpr(SimpleDequant) - { - constexpr auto cw_spans = CWarpTensor::get_distributed_spans(); - sweep_tile_span(cw_spans[I1{}], [&](auto in) { - constexpr auto block_idx_m = tile_distributed_index{}; - constexpr auto block_idx_n = detail::make_tile_distributed_index( - merge_sequences(sequence{}, in.impl_)); - constexpr auto block_idx_kq = tile_distributed_index{}; - constexpr auto empty_idx = tile_distributed_index<>{}; - c_block_tensor(make_tuple(block_idx_m, block_idx_n)) += - c_warp_tensor(make_tuple(empty_idx, in)) * - q_block_tensor(make_tuple(block_idx_m, block_idx_kq)); + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + if constexpr(kIterInQScale == 0) + { + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } }); - } - else - { + constexpr auto tbuf_offset = number{}, @@ -435,7 +406,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase b_scale_reg_f); }); } - } + }); }); }); } diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 62ac2115cc..7507ff58cc 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -448,18 +448,46 @@ struct QuantGemmKernel // offset = bq_group_offset bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset); } + + aq_group_offset = 0; + aq_k_split_offset = 0; + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant) + { + using AQuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + // Compute AQ K-group offset for this split-K batch. + // AQ tensor layout is RowMajor [M, QK_A] with stride [stride_AQ, 1]. 
+ // Advancing to column aq_group_offset means a pointer offset of aq_group_offset + // elements (column stride = 1). + const index_t k_offset_aq = amd_wave_read_first_lane(k_id * KRead); + aq_group_offset = amd_wave_read_first_lane(k_offset_aq / AQuantGroupSize::kK); + aq_k_split_offset = amd_wave_read_first_lane(aq_group_offset); + + // Compute BQ K-group offset for this split-K batch. + // BQ tensor layout is ColumnMajor [N/kN, K/kK] with stride [K/kK, 1] for + // ABQuantGrouped. Advancing to column bq_group_offset means a pointer offset of + // bq_group_offset elements (column stride = 1). + const index_t k_offset_bq = amd_wave_read_first_lane(k_id * KRead); + bq_group_offset = amd_wave_read_first_lane(k_offset_bq / BQuantGroupSize::kK); + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset); } else { bq_group_offset = 0; bq_k_split_offset = 0; + aq_group_offset = 0; + aq_k_split_offset = 0; } } index_t a_k_split_offset; index_t b_k_split_offset; - index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension) - index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride) + index_t aq_group_offset; // Logical offset in K-groups for AQ (K/kK dimension) + index_t aq_k_split_offset; // Memory pointer offset for AQ + index_t bq_group_offset; // Logical offset in K-groups for BQ (K/kK dimension) + index_t bq_k_split_offset; // Memory pointer offset for BQ (accounting for layout/stride) index_t splitted_k; }; @@ -532,7 +560,8 @@ struct QuantGemmKernel CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr, const QuantGemmKernelArgs& kargs, const index_t i_m, - const index_t i_n) + const index_t i_n, + const index_t aq_group_offset = 0) { // Step 1: Create tensor view for AQ const auto& aq_tensor_view = [&]() { @@ -615,11 +644,14 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant) { + // For split-K, aq_ptr is already offset by aq_k_split_offset elements. 
+ // The remaining K-groups from this offset position = QK_A - aq_group_offset. + const index_t remaining_qk_a = kargs.QK_A - aq_group_offset; if constexpr(std::is_same_v) { return make_naive_tensor_view( aq_ptr, - make_tuple(kargs.M, kargs.QK_A), + make_tuple(kargs.M, remaining_qk_a), make_tuple(kargs.stride_AQ, 1), number{}, number<1>{}); @@ -628,9 +660,8 @@ struct QuantGemmKernel { return make_naive_tensor_view( aq_ptr, - make_tuple(kargs.M, kargs.QK_A), + make_tuple(kargs.M, remaining_qk_a), make_tuple(1, kargs.stride_AQ), - number{}, number<1>{}); } @@ -1100,26 +1131,32 @@ struct QuantGemmKernel CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) { - // Split-K is supported for BQuantGrouped mode without preshuffle + // Split-K is supported for BQuantGrouped (without preshuffle) and + // ABQuantGrouped (without APreshuffleQuant) modes. if(kargs.k_batch != 1) { constexpr bool is_bquant_non_preshuffle = (kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant; - if constexpr(!is_bquant_non_preshuffle) + constexpr bool is_abquant_non_preshuffle = + (kQuantType == QuantType::ABQuantGrouped) && !APreshuffleQuant; + constexpr bool is_splitk_supported = + is_bquant_non_preshuffle || is_abquant_non_preshuffle; + + if constexpr(!is_splitk_supported) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { CK_TILE_ERROR("Conditions not met for Kbatch >1 ! 
" - "Split-K only supported for BQuantGrouped without preshuffle."); + "Split-K is supported for BQuantGrouped without preshuffle " + "and ABQuantGrouped without APreshuffleQuant."); } return false; } else { - using BQuantGroupSize = remove_cvref_t; - constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); - const index_t K_t = kargs.k_batch * K1; - const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); + const index_t K_t = kargs.k_batch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; // per-batch K read size constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; @@ -1137,22 +1174,67 @@ struct QuantGemmKernel return false; } - // Constraint 2: KRead must align with quantization group boundaries. - // Each split-K batch reads KRead consecutive K elements. If KRead is not - // a multiple of BQuantGroupSize::kK, the batch will span partial quantization - // groups, requiring split access to a quantization scale. This violates the - // atomic processing requirement where each batch must work with complete groups. - if(KRead % BQuantGroupSize::kK != 0) + // Constraint 2: KRead must align with B quantization group boundaries. + if constexpr(is_bquant_non_preshuffle || is_abquant_non_preshuffle) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + using BQuantGroupSize = remove_cvref_t; + if(KRead % BQuantGroupSize::kK != 0) { - CK_TILE_ERROR("Split-K batch size must be aligned with quantization group " - "size! KRead=" + - std::to_string(KRead) + - " is not divisible by BQuantGroupSize::kK=" + - std::to_string(BQuantGroupSize::kK)); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Split-K batch size must be aligned with B quantization group " + "size! 
KRead=" + + std::to_string(KRead) + + " is not divisible by BQuantGroupSize::kK=" + + std::to_string(BQuantGroupSize::kK)); + } + return false; + } + } + + // Constraint 3: KRead must align with A quantization group boundaries + // (only needed for ABQuantGrouped since AQ also indexes into K). + if constexpr(is_abquant_non_preshuffle) + { + using AQuantGroupSize = remove_cvref_t; + if(KRead % AQuantGroupSize::kK != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Split-K batch size must be aligned with A quantization group " + "size! KRead=" + + std::to_string(KRead) + + " is not divisible by AQuantGroupSize::kK=" + + std::to_string(AQuantGroupSize::kK)); + } + return false; + } + } + + // Constraint 4: per-batch K must span at least 2 K_Tile iterations. + // The software-pipelined GEMM kernels (CompV3 family) prefetch one tile + // ahead and require num_loop >= 2 per batch. When KRead == KPerBlock + // (i.e. per_batch_num_loop == 1) the prefetch would read the tile + // belonging to the next split-K batch, producing incorrect results. + { + const index_t per_batch_num_loop = + KRead / static_cast(TilePartitioner::KPerBlock); + if(per_batch_num_loop < 2) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Split-K requires at least 2 K-tile iterations per batch. " + "KRead=" + + std::to_string(KRead) + " < 2 * KPerBlock=" + + std::to_string(2 * + static_cast(TilePartitioner::KPerBlock)) + + ". Increase K or decrease k_batch."); + } + return false; } - return false; } } } @@ -1243,6 +1325,18 @@ struct QuantGemmKernel if constexpr(std::is_same_v) { + // For RowMajor C, M is the row dimension — check M alignment here because + // ALayout=RowMajor does not check M (it only checks K), leaving a gap for + // the RowMajorA + RowMajorC combination. 
+ if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) @@ -1315,7 +1409,10 @@ struct QuantGemmKernel MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + // Note: Pass aq_group_offset so the tensor view dimension reflects + // the remaining K-groups from the split-K offset position. + const auto& aq_block_window = MakeAQBlockWindow( + aq_ptr, kargs, block_idx_m, block_idx_n, splitk_batch_offset.aq_group_offset); // Note: Pass bq_group_offset so the tensor view dimension reflects // the remaining K-groups from the split-K offset position. const auto& bq_block_window = MakeBQBlockWindow( @@ -1445,7 +1542,10 @@ struct QuantGemmKernel static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); + // For ABQuantGrouped split-K, aq_ptr is offset by aq_k_split_offset elements to point + // to the start of this batch's AQ K-groups (aq_group_offset columns in RowMajor AQ). 
+ const AQDataType* aq_ptr = + static_cast(kargs.aq_ptr) + splitk_batch_offset.aq_k_split_offset; const BQDataType* bq_ptr = static_cast(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset; CDataType* c_ptr = static_cast(kargs.c_ptr); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index befe35ddc1..a7a64518b8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -108,14 +108,10 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName()); // clang-format on } - /** - * @tparam nloop The number of iterations in the hot loop, - * used to normalize scheduling costs. - */ + template CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { - static_assert(nloop > 0, "nloop must be greater than 0"); // Estimated number of VMEM vector loads for A per block: // total A bytes / (threads per block * vector width) constexpr index_t Aload_inst = @@ -138,13 +134,12 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe // Total VMEM load instructions (A + B + quant data) constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; // Approximate number of LDS reads per block - constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle / nloop; + constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle; // Approximate number of LDS writes per block // (e.g., writing A from VMEM into LDS once per A load) constexpr index_t ds_write_inst = Aload_inst; // Number of MFMA instructions per wave for one block tile: - constexpr index_t mfma_inst = - ((kMPerBlock / WG::kM) / nloop) * ((kNPerBlock / WG::kN) / nloop); + constexpr index_t mfma_inst = (kMPerBlock / 
WG::kM) * (kNPerBlock / WG::kN); // How often (in MFMA units) we should insert DS (LDS) operations. constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); // How often (in MFMA units) we should insert VMEM buffer loads. @@ -181,7 +176,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe } // Always mark some VALU work in the loop to reflect auxiliary scalar // or vector ALU instructions that coexist with MFMA (Blockscale calculation). - __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); // VALU + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU }); }); __builtin_amdgcn_sched_barrier(0); @@ -409,6 +404,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe // Prefetch A1 a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // initialize C @@ -437,7 +433,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe while(iCounter > 0) { __builtin_amdgcn_sched_barrier(0); - // Prefill A(2i+1) ds_write + // Prefill A(2i+1) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_pong, a_block_tile_tmp); @@ -465,14 +461,10 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // prefetch Q(2i+1) aq_block_tile_2 = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile_2 = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step); - - // Preload A(2i+1) ds_read static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; @@ -494,8 +486,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe }); 
}); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // prefetch Q(2i+1) aq_block_tile = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile = load_tile(bq_copy_dram_window); @@ -517,7 +507,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile_2, bq_block_tile_2, a_warp_windows_pong); - // Preload A(2i+2) ds_read + static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; @@ -557,7 +547,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile, bq_block_tile, a_warp_windows_ping); - // Preload A ds_read + static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index cdeae9b38d..ede2665c76 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -81,6 +81,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # ABQuant split-K tests + add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode + test_gemm_quant_abquant_splitk_decode.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_splitk_prefill + test_gemm_quant_abquant_splitk_prefill.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base test_gemm_quant_abquant_a4w4_base.cpp ) @@ -268,7 +279,14 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") 
test_tile_gemm_quant_abquant_base test_tile_gemm_quant_abquant_padding test_tile_gemm_quant_abquant_preshuffle + test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant test_tile_gemm_quant_abquant_preshuffleQuant + test_tile_gemm_quant_abquant_a4w4_base + test_tile_gemm_quant_abquant_a4w4_padding + test_tile_gemm_quant_abquant_a4w4_preshuffle + # ABQuant split-K tests + test_tile_gemm_quant_abquant_splitk_decode + test_tile_gemm_quant_abquant_splitk_prefill # BQuant tests test_tile_gemm_quant_bquant_1d_128 test_tile_gemm_quant_bquant_1d_64 @@ -276,6 +294,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") test_tile_gemm_quant_bquant_2d_medium_n test_tile_gemm_quant_bquant_2d_large_n test_tile_gemm_quant_bquant_transpose + # BQuant split-K tests + test_tile_gemm_quant_bquant_splitk_decode + test_tile_gemm_quant_bquant_splitk_prefill # BQuant preshuffle tests test_tile_gemm_quant_bquant_preshuffle_decode_1d test_tile_gemm_quant_bquant_preshuffle_prefill_1d diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp index 6e3e95fccf..2524f7887f 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp @@ -28,7 +28,7 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout> // clang-format off using ABQuantTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + // 1D BScales; PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) std::tuple, std::tuple, std::tuple, @@ -36,12 +36,13 @@ using ABQuantTypes = ::testing::Types< std::tuple, std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + // 2D B-scales; PreshuffleQuant = false && TransposeC = true (RCR layout with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + 
std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp index 793c9bd1df..a317a413ce 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp @@ -28,9 +28,11 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout> // clang-format off using ABQuantPreshuffleBTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + // 1D B-scales; PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) std::tuple, - std::tuple + + /// 2D B-scales; PreshuffleQuant = false && TransposeC = true (RCR layout with RowMajor AQ) + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp index f061c7dd47..0b845ac16d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp @@ -28,8 +28,8 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout> // clang-format off using ABQuantPreshuffleQuantTypes = ::testing::Types< - std::tuple, - std::tuple + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_decode.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_decode.cpp new file mode 100644 index 0000000000..7732779d7a --- /dev/null +++ 
b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_decode.cpp @@ -0,0 +1,126 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize1x1x128 = ck_tile::QuantGroupShape>; +using GroupSize1x128x128 = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant split-K tests - Decode shape +// GemmConfigDecode: M_Tile=16, N_Tile=64, K_Tile=256, kPadK=false +// Constraints: M % 16 == 0, N % 64 == 0, K % (k_batch * 256) == 0 +// +// Tuple format: +// clang-format off +using ABQuantSplitKDecodeTypes = ::testing::Types< + // GroupSize 1x1x128 (kK=128 for both A and B, kN=1) + std::tuple, + std::tuple, + // GroupSize 1x128x128 for B (kK=128, kN=128) + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for ABQuant split-K Decode +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantSplitKDecodeTypes); + +// ---- k_batch=2 ---------------------------------------------------------------- +// Note: K=512 (= 2*K_Tile) is excluded because KRead=K_Tile=256, giving +// per_batch_num_loop=1 which the software-pipelined kernel cannot handle. 
+
+TYPED_TEST(TestCkTileGemmABQuant, SplitK2_MedK_BaseShape)
+{
+    // K=1024=4*256: standard decode shape
+    this->run_test_with_validation(32, 64, 1024, 2);
+}
+
+TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_WideN)
+{
+    // K=2048, larger N (multiple of N_Tile=64)
+    this->run_test_with_validation(32, 256, 2048, 2);
+}
+
+TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_TallM)
+{
+    // K=4096, larger M (multiple of M_Tile=16)
+    this->run_test_with_validation(64, 64, 4096, 2);
+}
+
+// ---- k_batch=3 ----------------------------------------------------------------
+// Note: K=768 (= 3*K_Tile) excluded: per_batch_num_loop=1.
+
+TYPED_TEST(TestCkTileGemmABQuant, SplitK3_MedK_BaseShape)
+{
+    // K=1536=6*256
+    this->run_test_with_validation(32, 64, 1536, 3);
+}
+
+TYPED_TEST(TestCkTileGemmABQuant, SplitK3_LargeK_BaseShape)
+{
+    // K=3072=12*256
+    this->run_test_with_validation(32, 64, 3072, 3);
+}
+
+// ---- k_batch=4 ----------------------------------------------------------------
+// Note: K=1024 (= 4*K_Tile) excluded: per_batch_num_loop=1.
+
+TYPED_TEST(TestCkTileGemmABQuant, SplitK4_MedK_BaseShape)
+{
+    // K=2048=8*256
+    this->run_test_with_validation(32, 64, 2048, 4);
+}
+
+TYPED_TEST(TestCkTileGemmABQuant, SplitK4_LargeK_WideN)
+{
+    // K=4096, wider N
+    this->run_test_with_validation(32, 128, 4096, 4);
+}
+
+// ---- k_batch=5 ----------------------------------------------------------------
+// Note: K=1280 (= 5*K_Tile) excluded: per_batch_num_loop=1.
+
+TYPED_TEST(TestCkTileGemmABQuant, SplitK5_MedK_BaseShape)
+{
+    // K=2560=10*256
+    this->run_test_with_validation(32, 64, 2560, 5);
+}
+
+// ---- k_batch=6 ----------------------------------------------------------------
+// Note: K=1536 (= 6*K_Tile) excluded: per_batch_num_loop=1.
+ +TYPED_TEST(TestCkTileGemmABQuant, SplitK6_LargeK_BaseShape) +{ + // K=3072=12*256 + this->run_test_with_validation(32, 64, 3072, 6); +} + +// ---- k_batch=8 ---------------------------------------------------------------- +// Note: K=2048 (= 8*K_Tile) excluded: per_batch_num_loop=1. + +TYPED_TEST(TestCkTileGemmABQuant, SplitK8_LargeK_BaseShape) +{ + // K=4096=16*256 + this->run_test_with_validation(32, 64, 4096, 8); +} + +TYPED_TEST(TestCkTileGemmABQuant, SplitK8_LargeK_LargeMN) +{ + // K=4096, larger M and N + this->run_test_with_validation(48, 192, 4096, 8); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_prefill.cpp new file mode 100644 index 0000000000..f746983d06 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_prefill.cpp @@ -0,0 +1,131 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize1x1x128 = ck_tile::QuantGroupShape>; +using GroupSize1x128x128 = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant split-K tests - Prefill shape +// GemmConfigPrefill: M_Tile=128, N_Tile=128, K_Tile=128, kPadK=false +// Constraints: M % 128 == 0, N % 128 == 0, K % (k_batch * 128) == 0 +// +// Tuple format: +// clang-format off +using ABQuantSplitKPrefillTypes = ::testing::Types< + // GroupSize 1x1x128 (kK=128 for both A and B, kN=1) + std::tuple, + std::tuple, + // GroupSize 1x128x128 for B (kK=128, kN=128) + std::tuple, + std::tuple +>; +// 
clang-format on + +// Test suite for ABQuant split-K Prefill +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantSplitKPrefillTypes); + +// ---- k_batch=2 ---------------------------------------------------------------- +// Note: K=256 (= 2*K_Tile) excluded: KRead=K_Tile=128, per_batch_num_loop=1. + +TYPED_TEST(TestCkTileGemmABQuant, SplitK2_MedK_BaseShape) +{ + // K=1024=8*128 + this->run_test_with_validation(128, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_WideN) +{ + // K=2048, wider N + this->run_test_with_validation(128, 256, 2048, 2); +} + +TYPED_TEST(TestCkTileGemmABQuant, SplitK2_LargeK_TallM) +{ + // K=4096, taller M + this->run_test_with_validation(256, 128, 4096, 2); +} + +// ---- k_batch=3 ---------------------------------------------------------------- +// Note: K=384 (= 3*K_Tile) excluded: per_batch_num_loop=1. + +TYPED_TEST(TestCkTileGemmABQuant, SplitK3_MedK_BaseShape) +{ + // K=768=6*128 + this->run_test_with_validation(128, 128, 768, 3); +} + +TYPED_TEST(TestCkTileGemmABQuant, SplitK3_LargeK_BaseShape) +{ + // K=3072=24*128 + this->run_test_with_validation(128, 128, 3072, 3); +} + +// ---- k_batch=4 ---------------------------------------------------------------- +// Note: K=512 (= 4*K_Tile) excluded: per_batch_num_loop=1. + +TYPED_TEST(TestCkTileGemmABQuant, SplitK4_MedK_BaseShape) +{ + // K=2048=16*128 + this->run_test_with_validation(128, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmABQuant, SplitK4_LargeK_LargeMN) +{ + // K=4096, larger M and N + this->run_test_with_validation(256, 256, 4096, 4); +} + +// ---- k_batch=5 ---------------------------------------------------------------- +// Note: K=640 (= 5*K_Tile) excluded: per_batch_num_loop=1. 
+ +TYPED_TEST(TestCkTileGemmABQuant, SplitK5_MedK_BaseShape) +{ + // K=1280=10*128 + this->run_test_with_validation(128, 128, 1280, 5); +} + +TYPED_TEST(TestCkTileGemmABQuant, SplitK5_LargeK_BaseShape) +{ + // K=2560=20*128 + this->run_test_with_validation(128, 128, 2560, 5); +} + +// ---- k_batch=6 ---------------------------------------------------------------- +// Note: K=768 (= 6*K_Tile) excluded: per_batch_num_loop=1. + +TYPED_TEST(TestCkTileGemmABQuant, SplitK6_LargeK_BaseShape) +{ + // K=3072=24*128 + this->run_test_with_validation(128, 128, 3072, 6); +} + +// ---- k_batch=8 ---------------------------------------------------------------- +// Note: K=1024 (= 8*K_Tile) excluded: per_batch_num_loop=1. + +TYPED_TEST(TestCkTileGemmABQuant, SplitK8_MedK_BaseShape) +{ + // K=2048=16*128 + this->run_test_with_validation(128, 128, 2048, 8); +} + +TYPED_TEST(TestCkTileGemmABQuant, SplitK8_LargeK_LargeMN) +{ + // K=4096, larger M and N + this->run_test_with_validation(256, 256, 4096, 8); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 5a26034182..1266fa5889 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -158,6 +158,10 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; }; +struct GemmConfigPreshuffleBPrefillTransposeC : public GemmConfigPreshuffleBPrefill +{ + static constexpr bool TransposeC = true; +}; struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill { @@ -170,14 +174,18 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; +template struct GemmConfigPreshuffleBPreshuffleQuantPrefill : public GemmConfigPreshuffleBPrefill { static constexpr bool BPreshuffleQuant = true; + 
static constexpr bool TransposeC = TransposeC_; }; +template struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode { static constexpr bool BPreshuffleQuant = true; + static constexpr bool TransposeC = TransposeC_; }; template @@ -980,7 +988,10 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBaseis_row_major(ALayout{})); @@ -1091,6 +1102,13 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase 1), the kernel uses atomic_add to accumulate partial results + // into C. Zero the output buffer before launching so atomic additions start from zero. + if(k_batch > 1) + { + c_m_n_dev_buf.SetZero(); + } + // Create args for kernel execution ck_tile::QuantGemmHostArgs args{ a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr @@ -1098,7 +1116,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBaseis_row_major(CLayout{}))); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data()); - // Calculate error tolerances + // Calculate error tolerances (adjusted for split-K accumulation error) const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); const auto rtol_atol = this->template calculate_rtol_atol( - K, 1, max_accumulated_value); + K, k_batch, max_accumulated_value); // Validate results bool pass = ck_tile::check_err(c_m_n_dev_result, @@ -1151,7 +1169,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase{})); EXPECT_TRUE(pass) << "ABQuantGrouped validation failed with M=" << M << ", N=" << N - << ", K=" << K; + << ", K=" << K << ", k_batch=" << k_batch; if(!pass) {