From 762398fb7c07da8d1d01edce8f50809ca1b08d73 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Tue, 3 Feb 2026 09:21:59 +0000 Subject: [PATCH] chore: split up example instances for abquant --- .../38_block_scale_gemm/CMakeLists.txt | 7 +- .../gemm_abquant_quantgrouped.cpp | 198 ------------------ .../gemm_abquant_quantgrouped_bf8.cpp | 48 +++++ .../gemm_abquant_quantgrouped_fp4.cpp | 33 +++ .../gemm_abquant_quantgrouped_fp8.cpp | 79 +++++++ ...m_abquant_quantgrouped_preshuffleb_bf8.cpp | 48 +++++ ...m_abquant_quantgrouped_preshuffleb_fp4.cpp | 33 +++ ...m_abquant_quantgrouped_preshuffleb_fp8.cpp | 49 +++++ 8 files changed, 296 insertions(+), 199 deletions(-) delete mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp8.cpp diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 13cbcc8b55..42e4c77fc8 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -13,7 +13,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") set(EXE_NAME tile_example_gemm_quant) add_executable(${EXE_NAME} gemm_quant.cpp - gemm_abquant_quantgrouped.cpp + gemm_abquant_quantgrouped_bf8.cpp + gemm_abquant_quantgrouped_fp4.cpp + gemm_abquant_quantgrouped_fp8.cpp + gemm_abquant_quantgrouped_preshuffleb_bf8.cpp + gemm_abquant_quantgrouped_preshuffleb_fp4.cpp + gemm_abquant_quantgrouped_preshuffleb_fp8.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp deleted file mode 100644 index e4e0503b5a..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "run_gemm_quant_example.inc" - -template -using GemmConfig = GemmConfigABQuantPrefill; - -template -using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; - -// template -// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; - -static auto _ = []() { - auto& lut = get_kernel_lut(); - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - return 0; -}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp new file mode 100644 index 0000000000..46d690ca2e --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigABQuantPrefill; + +template +using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; + +// template +// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp new file mode 100644 index 0000000000..7b411b59e4 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigABQuantPrefill; + +template +using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; + +// template +// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings( + {"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp new file mode 100644 index 0000000000..f2a3694c07 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp @@ -0,0 +1,79 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigABQuantPrefill; + +template +using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; + +// template +// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_bf8.cpp new file mode 100644 index 0000000000..ba4c8f743e --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_bf8.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigABQuantPrefill; + +template +using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; + +// template +// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp4.cpp new file mode 100644 index 0000000000..55523c151e --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp4.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigABQuantPrefill; + +template +using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; + +// template +// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings( + {"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp8.cpp new file mode 100644 index 0000000000..4454749a9d --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_fp8.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigABQuantPrefill; + +template +using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; + +// template +// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + + return 0; +}();