From 76e50bb65b2d02e163c75a951c2f7820556b1256 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Thu, 13 Nov 2025 08:15:17 +0000 Subject: [PATCH] Merge commit 'fb41a7b73be5b686611e3bc75668cb8025252d8d' into develop --- example/ck_tile/03_gemm/run_gemm_example.inc | 64 +-- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 43 -- .../38_block_scale_gemm/CMakeLists.txt | 15 +- example/ck_tile/38_block_scale_gemm/README.md | 42 +- .../gemm_aquant_quantgrouped.cpp | 53 +++ .../gemm_bquant_quantgrouped_prefill_bf8.cpp | 47 ++ ...gemm_bquant_quantgrouped_prefill_bf8i4.cpp | 49 ++ .../gemm_bquant_quantgrouped_prefill_fp8.cpp | 47 ++ ...gemm_bquant_quantgrouped_prefill_fp8i4.cpp | 49 ++ ...quant_quantgrouped_preshuffleb_prefill.cpp | 53 +++ .../38_block_scale_gemm/gemm_quant.cpp | 130 ++++++ .../38_block_scale_gemm/gemm_quant_basic.cpp | 428 ------------------ .../38_block_scale_gemm/gemm_quant_rowcol.cpp | 30 ++ .../38_block_scale_gemm/gemm_quant_tensor.cpp | 30 ++ .../38_block_scale_gemm/gemm_utils.hpp | 54 +-- .../run_gemm_quant_example.inc | 273 ++++++++++- .../test_gemm_pipeline_util.hpp | 43 +- 17 files changed, 807 insertions(+), 643 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant.cpp delete mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 703ab810d8..1c57a03c97 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -2,6 +2,7 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/host/permute_pk_int4.hpp" +#include "ck_tile/host/tensor_shuffle_utils.hpp" template static constexpr inline auto is_row_major(Layout layout_) @@ -172,69 +173,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, return ave_time; } -template -auto shuffle_b(const ck_tile::HostTensor& t) -{ - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - - if(ck_tile::is_gfx12_supported()) - { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - kABK0PerLane, - divisor, - kABK1PerLane}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); - } - else - { - int divisor = 1; - if(ck_tile::is_gfx11_supported()) - { - divisor = 1; - } - else - { - assert(is_wave32() == false); - divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - } - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } -} - -template -auto shuffle_b_permuteN(const ck_tile::HostTensor& t) -{ - assert(t.get_lengths().size() == 2); - - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, - NRepeat, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); -} - template bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, const ck_tile::HostTensor& c_m_n_ref, diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 049957cbfd..9b14efb561 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -323,49 +323,6 @@ inline std::size_t get_workspace_size(const std::vector& gem return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>); } -template -auto shuffle_b(const ck_tile::HostTensor& t) -{ - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - - if(ck_tile::is_gfx12_supported()) - { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - kABK0PerLane, - divisor, - kABK1PerLane}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); - } - else - { - int divisor = 1; - if(ck_tile::is_gfx11_supported()) - { - divisor = 1; - } - else - { - assert(is_wave32() == false); - divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - } - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } -} - template +using GemmConfig = GemmConfigQuant; + +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp new file mode 100644 index 0000000000..cb9f8b62cf --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp new file mode 100644 index 0000000000..33ae3bc4a9 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp new file mode 100644 index 0000000000..526c35b081 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp new file mode 100644 index 0000000000..4b2a8efb14 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp new file mode 100644 index 0000000000..d9591bb588 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp new file mode 100644 index 0000000000..a35f867f5d --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" +#include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "gemm_utils.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("h", "false", "Print help message") + .insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row or Column") + .insert("b_layout", "C", "B tensor data layout - Row or Column") + .insert("bq_layout", "C", "Bq tensor data layout - Row or Column") + .insert("c_layout", "R", "C tensor data layout - Row or Column") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0: No validation, 1: Validation on CPU, 2: Validation on GPU") + .insert("prec", + "fp8", + "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " + "or bf8i4") + .insert("warmup", "50", "Number of iterations before benchmarking the kernel") + .insert("repeat", "1000", "Number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "SplitK value") + .insert("device", "0", "Device id that will be used to run the kernel") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("flush_cache", "true", "Flush cache before running the kernel") + .insert("rotating_count", "1000", "Rotating count") + .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol") + .insert("preshuffleb", "false", "Enable preshuffle of tensor B") + .insert("group_size", + "1x1x128", + "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +auto gen_lut_key(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + std::string quant_mode = arg_parser.get_str("quant_mode"); + + std::vector params = {data_type, quant_mode}; + + if(quant_mode == "bquant") + { + std::string preshuffleb = + arg_parser.get_bool("preshuffleb") ? "preshuffleb" : "non-preshuffleb"; + params.push_back(preshuffleb); + } + if(quant_mode != "rowcol" && quant_mode != "tensor") + { + // NOTE: rowcol and tensor pipeline do not use group size + std::string group_size_str = arg_parser.get_str("group_size"); + params.push_back(group_size_str); + } + + return hash_multiple_strings(params); +} + +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut); +void quant_rowcol_instance_factory( + std::unordered_map>& lut); +void quant_tensor_instance_factory( + std::unordered_map>& lut); + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result || arg_parser.get_bool("h")) + { + arg_parser.print(); + return -1; + } + + auto device_id = arg_parser.get_int("device"); + std::cout << "Device ID: " << device_id << std::endl; + ck_tile::hip_check_error(hipSetDevice(device_id)); + + std::unordered_map> lut; + aquant_quantgrouped_instance_factory(lut); + bquant_quantgrouped_fp8_instance_factory(lut); + bquant_quantgrouped_bf8_instance_factory(lut); + bquant_quantgrouped_fp8i4_instance_factory(lut); + bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_instance_factory(lut); + quant_rowcol_instance_factory(lut); + quant_tensor_instance_factory(lut); + + auto key = gen_lut_key(arg_parser); + + if(lut.find(key) != lut.end()) + { + return lut[key](arg_parser); + } + else + { + std::cerr + << "Error: Combination of prec, quant_mode, preshuffleb, and group_size not supported." + << std::endl; + return -1; + } +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp deleted file mode 100644 index d605a2b780..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ /dev/null @@ -1,428 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -// This example demonstrates 2D block scale quantization (N×K) for BQuant -// using non-preshuffled configuration. -// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example -// This is currently done separately to avoid too verbose dispatching. - -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core/config.hpp" -#include "ck_tile/host.hpp" -#include "gemm_utils.hpp" - -template -float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) -{ - static_assert(std::is_same_v); - using ComputeDataType = std::conditional_t; - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using GemmTraits = ck_tile::TileGemmQuantTraits; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - - // This example only supports BQuant (no AQuant) - // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3 - using BaseGemmPipeline = std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; - - const ck_tile::index_t K_split = - (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr bool transpose_c = false; - - // row-col and tensor quants use the regular pipeline, A/B quants use their own - using PipelineProblem = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmRowColTensorQuantPipelineProblem, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>>; - - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrMem, // memory pipeline hardcoded - // for aquant - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - GemmConfig::TiledMMAPermuteN>>; - using Kernel = - ck_tile::QuantGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << PipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - float ave_time = 0; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - ck_tile::RotatingMemWrapper - rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString( - hipMemsetAsync(args.c_ptr, - 0, - args.M * args.N * sizeof(typename TypeConfig::CDataType), - s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - - return ave_time; - }; - return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); -} - -#include "run_gemm_quant_example.inc" - -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if((QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) && - GemmConfig::PreshuffleB) - { - throw std::runtime_error( - "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); - } - - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for A."); - } - - return 0; -} - -// Forward declaration for dispatch function -template