diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 7358d4d749..c59363fc37 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -6,8 +6,17 @@ endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) - target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + set(EXE_NAME tile_example_gemm_quant) + add_executable(${EXE_NAME} EXCLUDE_FROM_ALL + gemm_quant.cpp + gemm_aquant_quantgrouped.cpp + gemm_bquant_quantgourped_prefill_bf8i4.cpp + gemm_bquant_quantgourped_prefill_fp8i4.cpp + gemm_bquant_quantgourped_prefill_bf8.cpp + gemm_bquant_quantgourped_prefill_fp8.cpp + gemm_bquant_quantgourped_preshuffleb_prefill.cpp + ) + target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp new file mode 100644 index 0000000000..5e83963d20 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// This example demonstrates 2D block scale quantization (N×K) for BQuant +// using non-preshuffled configuration. +// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example +// This is currently done separately to avoid too verbose dispatching. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuant; + +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings( + {"fp8", "aquant", "non-preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8", "aquant", "non-preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "aquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "aquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_bf8.cpp new file mode 100644 index 0000000000..11f5374ed7 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// This example demonstrates 2D block scale quantization (N×K) for BQuant +// using non-preshuffled configuration. +// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example +// This is currently done separately to avoid too verbose dispatching. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_bf8i4.cpp new file mode 100644 index 0000000000..94f24d2f5f --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_bf8i4.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// This example demonstrates 2D block scale quantization (N×K) for BQuant +// using non-preshuffled configuration. +// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example +// This is currently done separately to avoid too verbose dispatching. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_fp8.cpp new file mode 100644 index 0000000000..505cda1233 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// This example demonstrates 2D block scale quantization (N×K) for BQuant +// using non-preshuffled configuration. +// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example +// This is currently done separately to avoid too verbose dispatching. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_fp8i4.cpp new file mode 100644 index 0000000000..9687feec04 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_prefill_fp8i4.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// This example demonstrates 2D block scale quantization (N×K) for BQuant +// using non-preshuffled configuration. +// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example +// This is currently done separately to avoid too verbose dispatching. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_preshuffleb_prefill.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_preshuffleb_prefill.cpp new file mode 100644 index 0000000000..68f26ac3b2 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgourped_preshuffleb_prefill.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp new file mode 100644 index 0000000000..a4dc550b26 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// This example demonstrates 2D block scale quantization (N×K) for BQuant +// using non-preshuffled configuration. +// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example +// This is currently done separately to avoid too verbose dispatching. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" +#include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "gemm_utils.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("h", "false", "print help message") + .insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("bq_layout", "C", "Bq tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", + "fp8", + "data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "1000", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("device", "0", "device id that will be used to run the kernel, default 0") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("flush_cache", "true", "flush cache before running the kernel, default to true") + .insert("rotating_count", "1000", "rotating count, defaults to 1") + .insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol") + .insert("preshuffleb", "false", "Enable preshuffle tensor B, default false") + .insert("group_size", + "1x1x128", + "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut); + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result || arg_parser.get_bool("h")) + { + arg_parser.print(); + return -1; + } + + auto device_id = arg_parser.get_int("device"); + std::printf("Device ID: %d\n", device_id); + + hipError_t err = hipSetDevice(device_id); + if(err != hipSuccess) + { + std::cerr << "hipSetDevice failed with error: " << hipGetErrorString(err) << std::endl; + return -1; + } + + std::unordered_map> lut; + aquant_quantgrouped_instance_factory(lut); + bquant_quantgrouped_fp8_instance_factory(lut); + bquant_quantgrouped_bf8_instance_factory(lut); + bquant_quantgrouped_fp8i4_instance_factory(lut); + bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_instance_factory(lut); + + std::string data_type = arg_parser.get_str("prec"); + std::string quant_mode = arg_parser.get_str("quant_mode"); + std::string preshuffleb = + arg_parser.get_bool("preshuffleb") ? "preshuffleb" : "non-preshuffleb"; + std::string group_size_str = arg_parser.get_str("group_size"); + + auto key = hash_multiple_strings({data_type, quant_mode, preshuffleb, group_size_str}); + if(lut.find(key) != lut.end()) + { + return lut[key](arg_parser); + } + else + { + std::cerr + << "Error: Combination of prec, quant_mode, preshuffleb, and group_size not supported." + << " (prec: " << data_type << ", quant_mode: " << quant_mode + << ", preshuffleb: " << preshuffleb << ", group_size: " << group_size_str << ")" + << std::endl; + return -1; + } +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp deleted file mode 100644 index b22596537f..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ /dev/null @@ -1,424 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -// This example demonstrates 2D block scale quantization (N×K) for BQuant -// using non-preshuffled configuration. -// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example -// This is currently done separately to avoid too verbose dispatching. - -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core/config.hpp" -#include "ck_tile/host.hpp" -#include "gemm_utils.hpp" - -template -float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) -{ - static_assert(std::is_same_v); - using ComputeDataType = std::conditional_t; - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using GemmTraits = ck_tile::TileGemmQuantTraits; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - - // This example only supports BQuant (no AQuant) - // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3 - using BaseGemmPipeline = std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; - - const ck_tile::index_t K_split = - (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr bool transpose_c = false; - - // row-col and tensor quants use the regular pipeline, A/B quants use their own - using PipelineProblem = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmRowColTensorQuantPipelineProblem, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>>; - - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrMem, // memory pipeline hardcoded - // for aquant - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - GemmConfig::TiledMMAPermuteN>>; - using Kernel = - ck_tile::QuantGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << PipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - float ave_time = 0; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - ck_tile::RotatingMemWrapper - rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString( - hipMemsetAsync(args.c_ptr, - 0, - args.M * args.N * sizeof(typename TypeConfig::CDataType), - s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - - return ave_time; - }; - return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); -} - -#include "run_gemm_quant_example.inc" - -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if((QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) && - GemmConfig::PreshuffleB) - { - throw std::runtime_error( - "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); - } - - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for A."); - } - - return 0; -} - -// Forward declaration for dispatch function -template