[CK TILE GEMM] Refactor block_scale_gemm examples (#3181)

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Set quant group size to (1, 1, 64) for targets excluding gfx950, where warp tile size (16, 16, 128) is incompatible.

[ROCm/composable_kernel commit: 6fd8ddabe7]
This commit is contained in:
Cong Ma
2025-11-13 00:43:40 -07:00
committed by GitHub
parent 5c19f34cb4
commit fec8b3228b
14 changed files with 805 additions and 495 deletions

View File

@@ -6,8 +6,19 @@ endif()
# Extra compile options for all quant GEMM example targets; the -mllvm switch
# turns off the noalias-to-metadata conversion pass.
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
# Quant GEMM examples are only built for the GPU targets that support them.
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set(EXE_NAME tile_example_gemm_quant)
# Each quant flavour lives in its own translation unit to reduce build time
# (see the commit notes: "Split cpp file to reduce building time").
add_executable(${EXE_NAME} EXCLUDE_FROM_ALL
    gemm_quant.cpp
    gemm_aquant_quantgrouped.cpp
    gemm_bquant_quantgrouped_prefill_bf8i4.cpp
    gemm_bquant_quantgrouped_prefill_fp8i4.cpp
    gemm_bquant_quantgrouped_prefill_bf8.cpp
    gemm_bquant_quantgrouped_prefill_fp8.cpp
    gemm_bquant_quantgrouped_preshuffleb_prefill.cpp
    gemm_quant_rowcol.cpp
    gemm_quant_tensor.cpp
)
target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()

View File

@@ -40,23 +40,31 @@ This will result in an executable `build/bin/tile_example_gemm_quant_basic`
## example
```
args:
-b batch size (default:1)
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: C)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
-h Print help message (default:false)
-m m dimension (default:3840)
-n n dimension (default:4096)
-k k dimension (default:2048)
-a_layout A tensor data layout - Row or Column (default:R)
-b_layout B tensor data layout - Row or Column (default:C)
-bq_layout Bq tensor data layout - Row or Column (default:C)
-c_layout C tensor data layout - Row or Column (default:R)
-stride_a Tensor A stride (default:0)
-stride_q Tensor AQ stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
-warmup Number of iterations before benchmarking the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:1000)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-split_k SplitK value (default:1)
-device Device id that will be used to run the kernel (default:0)
-init 0:random, 1:linear, 2:constant(1) (default:0)
-flush_cache Flush cache before running the kernel (default:true)
-rotating_count Rotating count (default:1000)
-quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant)
-preshuffleb Enable preshuffle of tensor B (default:false)
-group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
```
Users need to select the correct mapping of config for each quant mode:

View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigQuant<T>;
// Registers every AQuant-grouped instance shipped by this example in `lut`,
// keyed by the hash of {prec, quant_mode, group_size} that gen_lut_key()
// computes. All entries share the fixed (1, 1, 128) quant group shape.
// NOTE(review): keys here are "fp8i4"/"bf8i4" while the CLI help text
// advertises "i4fp8"/"i4bf8" for AQuant — confirm which spelling is intended.
void aquant_quantgrouped_instance_factory(
    std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
    using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;

    lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] =
        [](const ck_tile::ArgParser& parser) {
            using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
                                                            ck_tile::fp8_t,
                                                            ck_tile::half_t,
                                                            float>{});
            return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
                                              TypeConfig,
                                              QuantGroupSize,
                                              ck_tile::QuantType::AQuantGrouped>(parser);
        };

    lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] =
        [](const ck_tile::ArgParser& parser) {
            using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
                                                            ck_tile::bf8_t,
                                                            ck_tile::half_t,
                                                            float>{});
            return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
                                              TypeConfig,
                                              QuantGroupSize,
                                              ck_tile::QuantType::AQuantGrouped>(parser);
        };

    lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] =
        [](const ck_tile::ArgParser& parser) {
            using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
                                                            ck_tile::fp8_t,
                                                            ck_tile::half_t,
                                                            ck_tile::fp8_t>{});
            return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
                                              TypeConfig,
                                              QuantGroupSize,
                                              ck_tile::QuantType::AQuantGrouped>(parser);
        };

    lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] =
        [](const ck_tile::ArgParser& parser) {
            using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
                                                            ck_tile::bf8_t,
                                                            ck_tile::half_t,
                                                            ck_tile::bf8_t>{});
            return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
                                              TypeConfig,
                                              QuantGroupSize,
                                              ck_tile::QuantType::AQuantGrouped>(parser);
        };
}

View File

@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
#ifndef CK_GFX950_SUPPORT
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
#endif
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
#ifndef CK_GFX950_SUPPORT
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
#endif
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}

View File

@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
#ifndef CK_GFX950_SUPPORT
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
#endif
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
#ifndef CK_GFX950_SUPPORT
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
#endif
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}

View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<T>;
// Registers every preshuffled-B BQuant instance in `lut`, keyed by the hash
// of {prec, quant_mode, "preshuffleb", group_size} that gen_lut_key()
// computes. All entries share the fixed (1, 1, 128) quant group shape.
void bquant_quantgrouped_preshuffleb_instance_factory(
    std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
    using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;

    lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "1x1x128"})] =
        [](const ck_tile::ArgParser& parser) {
            using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
                                                            ck_tile::fp8_t,
                                                            ck_tile::half_t,
                                                            float>{});
            return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
                                              TypeConfig,
                                              QuantGroupSize,
                                              ck_tile::QuantType::BQuantGrouped>(parser);
        };

    lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "1x1x128"})] =
        [](const ck_tile::ArgParser& parser) {
            using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
                                                            ck_tile::bf8_t,
                                                            ck_tile::half_t,
                                                            float>{});
            return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
                                              TypeConfig,
                                              QuantGroupSize,
                                              ck_tile::QuantType::BQuantGrouped>(parser);
        };

    lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] =
        [](const ck_tile::ArgParser& parser) {
            using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
                                                            ck_tile::pk_int4_t,
                                                            ck_tile::half_t,
                                                            ck_tile::fp8_t>{});
            return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
                                              TypeConfig,
                                              QuantGroupSize,
                                              ck_tile::QuantType::BQuantGrouped>(parser);
        };

    lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] =
        [](const ck_tile::ArgParser& parser) {
            using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
                                                            ck_tile::pk_int4_t,
                                                            ck_tile::half_t,
                                                            ck_tile::bf8_t>{});
            return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
                                              TypeConfig,
                                              QuantGroupSize,
                                              ck_tile::QuantType::BQuantGrouped>(parser);
        };
}

View File

@@ -0,0 +1,130 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
#include "gemm_utils.hpp"
// Builds the CLI option table for the quant GEMM example and parses argv.
// Returns (parse_ok, parser): parse_ok is false when the arguments are
// malformed; callers also check the "h" flag to print help.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("h", "false", "Print help message")
        .insert("m", "3840", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("k", "2048", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row or Column")
        .insert("b_layout", "C", "B tensor data layout - Row or Column")
        .insert("bq_layout", "C", "Bq tensor data layout - Row or Column")
        .insert("c_layout", "R", "C tensor data layout - Row or Column")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_q", "0", "Tensor AQ stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "1", "0: No validation, 1: Validation on CPU, 2: Validation on GPU")
        // NOTE(review): this help text advertises i4fp8/i4bf8 for AQuant, but
        // the aquant instance factory registers "fp8i4"/"bf8i4" keys —
        // confirm which spelling is intended.
        .insert("prec",
                "fp8",
                "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
                "or bf8i4")
        .insert("warmup", "50", "Number of iterations before benchmarking the kernel")
        .insert("repeat", "1000", "Number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("split_k", "1", "SplitK value")
        .insert("device", "0", "Device id that will be used to run the kernel")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
        .insert("flush_cache", "true", "Flush cache before running the kernel")
        .insert("rotating_count", "1000", "Rotating count")
        .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
        .insert("preshuffleb", "false", "Enable preshuffle of tensor B")
        .insert("group_size",
                "1x1x128",
                "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Derives the dispatch-table key for the parsed arguments. The key is the
// hash of {prec, quant_mode}, extended with the preshuffle flag for bquant
// and with the group size for the grouped quant modes.
auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
{
    const std::string quant_mode = arg_parser.get_str("quant_mode");

    std::vector<std::string> key_parts{arg_parser.get_str("prec"), quant_mode};

    if(quant_mode == "bquant")
    {
        key_parts.emplace_back(arg_parser.get_bool("preshuffleb") ? "preshuffleb"
                                                                  : "non-preshuffleb");
    }

    // NOTE: rowcol and tensor pipeline do not use group size
    const bool uses_group_size = quant_mode != "rowcol" && quant_mode != "tensor";
    if(uses_group_size)
    {
        key_parts.push_back(arg_parser.get_str("group_size"));
    }

    return hash_multiple_strings(key_parts);
}
void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void quant_rowcol_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void quant_tensor_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// Example driver: parses CLI options, selects the GPU device, fills the
// dispatch table from every instance factory, and runs the instance matching
// the requested {prec, quant_mode[, preshuffleb][, group_size]} combination.
// Returns the selected runner's exit code, or -1 on bad arguments, help
// request, or an unsupported combination.
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result || arg_parser.get_bool("h"))
    {
        arg_parser.print();
        return -1;
    }
    const auto device_id = arg_parser.get_int("device");
    std::cout << "Device ID: " << device_id << std::endl;
    ck_tile::hip_check_error(hipSetDevice(device_id));
    // Each factory registers the combinations its translation unit supports.
    std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
    aquant_quantgrouped_instance_factory(lut);
    bquant_quantgrouped_fp8_instance_factory(lut);
    bquant_quantgrouped_bf8_instance_factory(lut);
    bquant_quantgrouped_fp8i4_instance_factory(lut);
    bquant_quantgrouped_bf8i4_instance_factory(lut);
    bquant_quantgrouped_preshuffleb_instance_factory(lut);
    quant_rowcol_instance_factory(lut);
    quant_tensor_instance_factory(lut);
    const auto key = gen_lut_key(arg_parser);
    // Single lookup via iterator (the previous find() + operator[] hashed the
    // key twice).
    if(const auto it = lut.find(key); it != lut.end())
    {
        return it->second(arg_parser);
    }
    std::cerr << "Error: Combination of prec, quant_mode, preshuffleb, and group_size not supported."
              << std::endl;
    return -1;
}

View File

@@ -1,428 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using non-preshuffled configuration.
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
// This is currently done separately to avoid too verbose dispatching.
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
// Instantiates and launches one quantized GEMM kernel and returns the
// measured average runtime (as reported by the ck_tile launch helpers).
// Template parameters select the tile configuration (GemmConfig), element
// types (TypeConfig), operand layouts, quantization group shape and quant
// flavour (QuantMode); `args` carries problem sizes/pointers and `s` the
// stream/timing configuration.
// Throws std::runtime_error when k_batch != 1 or the arguments are rejected
// by the instantiated kernel.
template <typename GemmConfig,
          typename TypeConfig,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename QuantGroupSize,
          ck_tile::QuantType QuantMode,
          typename CDEElementWise>
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
    // Only row-major C output is instantiated here.
    static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
    // For AQuant/RowCol quant the compute type follows B's data type;
    // otherwise it follows A's.
    using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
                                                   QuantMode == ck_tile::QuantType::RowColQuant,
                                               typename TypeConfig::BDataType,
                                               typename TypeConfig::ADataType>;
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
        ck_tile::
            sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
    using TilePartitioner = ck_tile::GemmTile1DPartitioner<GemmShape>;
    using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
                                                    GemmConfig::kPadN,
                                                    GemmConfig::kPadK,
                                                    GemmConfig::PreshuffleQuant,
                                                    GemmConfig::PreshuffleB,
                                                    ALayout,
                                                    BLayout,
                                                    CLayout,
                                                    QuantMode,
                                                    ALayout, // for AQLayout
                                                    BLayout, // for BQLayout
                                                    false,
                                                    GemmConfig::DoubleSmemBuffer>;
    using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
                                                                 typename TypeConfig::BDataType,
                                                                 typename TypeConfig::AccDataType,
                                                                 GemmShape,
                                                                 GemmTraits,
                                                                 ComputeDataType>;
    // This example only supports BQuant (no AQuant)
    // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
    using BaseGemmPipeline = std::conditional_t<
        GemmConfig::PreshuffleB == true,
        ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
        ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
    // K rounded up to a whole number of K tiles, used to derive the main-loop
    // trip count and tail handling.
    const ck_tile::index_t K_split =
        (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
    // Run is instantiated once per (hot-loop, tail-number) combination by
    // BaseGemmPipeline::TailHandler below; the runtime values arrive here as
    // integral-constant-like objects.
    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v = tail_number_.value;
        constexpr bool transpose_c = false;
        // row-col and tensor quants use the regular pipeline, A/B quants use their own
        using PipelineProblem = std::conditional_t<
            QuantMode == ck_tile::QuantType::RowColQuant ||
                QuantMode == ck_tile::QuantType::TensorQuant,
            ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
                                                          typename TypeConfig::BDataType,
                                                          typename TypeConfig::AccDataType,
                                                          typename TypeConfig::AccDataType,
                                                          GemmShape,
                                                          GemmTraits,
                                                          transpose_c,
                                                          ComputeDataType,
                                                          GemmConfig::Scheduler,
                                                          has_hot_loop_v,
                                                          tail_number_v>,
            std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
                               ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
                                                                  typename TypeConfig::QDataType,
                                                                  typename TypeConfig::BDataType,
                                                                  typename TypeConfig::AccDataType,
                                                                  GemmShape,
                                                                  GemmTraits,
                                                                  QuantGroupSize,
                                                                  transpose_c,
                                                                  ComputeDataType,
                                                                  GemmConfig::Scheduler,
                                                                  has_hot_loop_v,
                                                                  tail_number_v>,
                               ck_tile::GemmBQuantPipelineProblem<typename TypeConfig::ADataType,
                                                                  typename TypeConfig::BDataType,
                                                                  typename TypeConfig::QDataType,
                                                                  typename TypeConfig::AccDataType,
                                                                  GemmShape,
                                                                  GemmTraits,
                                                                  QuantGroupSize,
                                                                  ComputeDataType,
                                                                  GemmConfig::Scheduler,
                                                                  has_hot_loop_v,
                                                                  tail_number_v>>>;
        using GemmPipeline = std::conditional_t<
            QuantMode == ck_tile::QuantType::RowColQuant ||
                QuantMode == ck_tile::QuantType::TensorQuant,
            ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
            std::conditional_t<
                QuantMode == ck_tile::QuantType::AQuantGrouped,
                ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>, // memory pipeline hardcoded
                                                                      // for aquant
                std::conditional_t<GemmConfig::PreshuffleB == true,
                                   ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
                                   ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
                                             typename TypeConfig::BDataType,
                                             ck_tile::tuple<>,
                                             typename TypeConfig::AccDataType,
                                             typename TypeConfig::CDataType,
                                             ck_tile::tuple<>,
                                             CLayout,
                                             CDEElementWise,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             GemmConfig::M_Warp,
                                             GemmConfig::N_Warp,
                                             GemmConfig::M_Warp_Tile,
                                             GemmConfig::N_Warp_Tile,
                                             GemmConfig::K_Warp_Tile,
                                             transpose_c,
                                             ck_tile::memory_operation_enum::set,
                                             1,
                                             false,
                                             1,
                                             GemmConfig::TiledMMAPermuteN>>;
        using Kernel =
            ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
        auto kargs = Kernel::MakeKernelArgs(args);
        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
        const dim3 blocks = Kernel::BlockSize();
        if(args.k_batch != 1)
        {
            throw std::runtime_error("split-k is not supported yet!");
        }
        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }
        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                      << "shape: " << GemmShape::GetName() << '\n'
                      << "problem: " << PipelineProblem::GetName() << '\n'
                      << "pipeline: " << GemmPipeline::GetName() << '\n'
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }
        float ave_time = 0;
        if(s.flush_cache_)
        {
            // Benchmark with cache flushing: rotate through copies of the A/B
            // buffers and flush the instruction cache before each timed launch
            // so every iteration starts cold.
            std::cout << "Flushing cache..." << std::endl;
            ck_tile::HostTensor<typename TypeConfig::ADataType> a_m(ck_tile::host_tensor_descriptor(
                args.M, args.K, args.stride_A, is_row_major(ALayout{})));
            ck_tile::HostTensor<typename TypeConfig::BDataType> b_n(ck_tile::host_tensor_descriptor(
                args.K, args.N, args.stride_B, is_row_major(BLayout{})));
            auto size_a_buffer = a_m.get_element_space_size_in_bytes();
            auto size_b_buffer = b_n.get_element_space_size_in_bytes();
            ck_tile::RotatingMemWrapper<typename TypeConfig::ADataType,
                                        typename TypeConfig::BDataType>
                rotating_mem(
                    kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
            rotating_mem.Print();
            auto run_flush_cache = [&]() {
                // flush icache
                ck_tile::flush_icache();
                // rotating mem
                rotating_mem.Next();
                // clear c mem
                if(args.k_batch > 1)
                    hipGetErrorString(
                        hipMemsetAsync(args.c_ptr,
                                       0,
                                       args.M * args.N * sizeof(typename TypeConfig::CDataType),
                                       s.stream_id_));
            };
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        return ave_time;
    };
    // TailHandler maps the runtime (has_hot_loop, tail_num) pair onto the
    // compile-time constants Run expects.
    return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
#include "run_gemm_quant_example.inc"
// Validates the mode/config combination and the requested layouts, then runs
// the example with the only layout combination instantiated here
// (a_layout == "R", b_layout == "C").
// Throws std::runtime_error for unsupported mode/layout/data-type requests.
template <typename GemmConfig,
          typename TypeConfig,
          typename QuantGroupSize,
          ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
    // Preshuffled B is explicitly rejected for AQuant and RowCol quant.
    if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
        QuantMode == ck_tile::QuantType::RowColQuant) &&
       GemmConfig::PreshuffleB)
    {
        throw std::runtime_error(
            "Preshuffling weight matrix is not supported for AQuant or RowColQuant");
    }
    if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
                 std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
                 std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
    {
        if(a_layout == "R" && b_layout == "C")
        {
            // NOTE(review): the five layout tags presumably map to
            // (A, AQ, B, BQ, C) — confirm against run_gemm_example_with_layouts.
            return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize, QuantMode>(
                argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
        }
        else
        {
            throw std::runtime_error("Unsupported memory layout for the input matrices!");
        }
    }
    else
    {
        throw std::runtime_error("Unsupported data type for A.");
    }
    return 0; // unreachable: every branch above returns or throws
}
// Forward declaration for dispatch function
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
int dispatch_by_data_type(const std::string& data_type,
const std::string& quant_mode,
const std::string& a_layout,
const std::string& b_layout,
int argc,
char* argv[]);
// Parses a quantization group-size string of the form "MxNxK"
// (e.g. "1x32x128") into an (M, N, K) tuple. A single number with no 'x'
// is interpreted as the K dimension with M = N = 1.
// Throws std::runtime_error on a malformed "MxN" string and propagates
// std::invalid_argument / std::out_of_range from std::stoi for non-numeric
// components. (The old dead initializers m=1, n=1, k=128 — always
// overwritten before use — are removed.)
std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
{
    const size_t first_x = group_size_str.find('x');
    if(first_x == std::string::npos)
    {
        // Single number provided, assume it's the K dimension
        return {1, 1, std::stoi(group_size_str)};
    }
    const size_t second_x = group_size_str.find('x', first_x + 1);
    if(second_x == std::string::npos)
    {
        throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
    }
    const int m = std::stoi(group_size_str.substr(0, first_x));
    const int n = std::stoi(group_size_str.substr(first_x + 1, second_x - first_x - 1));
    const int k = std::stoi(group_size_str.substr(second_x + 1));
    return {m, n, k};
}
// Parses the CLI arguments and dispatches: first on the runtime group size
// (mapped to a compile-time QuantGroupShape), then on the data type.
// Returns -1 when argument parsing fails.
// NOTE(review): dispatch_group_size_ct is defined later in this file with no
// forward declaration visible here — presumably declared via the included
// .inc file; verify.
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;
    std::string data_type = arg_parser.get_str("prec");
    std::string a_layout = arg_parser.get_str("a_layout");
    std::string b_layout = arg_parser.get_str("b_layout");
    std::string quant_mode = arg_parser.get_str("quant_mode");
    std::string group_size_str = arg_parser.get_str("group_size");
    auto [m_group, n_group, k_group] = parse_group_size(group_size_str);
    // Dispatch based on group size (M, N, K)
    return dispatch_group_size_ct<GemmConfig>(m_group, n_group, k_group, [&](auto QGS_) {
        using QuantGroupSize = decltype(QGS_);
        return dispatch_by_data_type<GemmConfig, QuantGroupSize>(
            data_type, quant_mode, a_layout, b_layout, argc, argv);
    });
}
// Maps the runtime `prec` string onto a concrete TypeConfig and runs the
// corresponding BQuant instance. This 2D block scale example accepts only
// quant_mode == "bquant"; anything else throws.
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
int dispatch_by_data_type(const std::string& data_type,
                          const std::string& quant_mode,
                          const std::string& a_layout,
                          const std::string& b_layout,
                          int argc,
                          char* argv[])
{
    // This example ONLY supports BQuant for 2D block scale quantization
    if(quant_mode != "bquant")
    {
        throw std::runtime_error("This example only supports BQuant! Use --quant_mode=bquant");
    }
    if(data_type == "fp8")
    {
        using TypeConfig =
            decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
        return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
                                          TypeConfig,
                                          QuantGroupSize,
                                          ck_tile::QuantType::BQuantGrouped>(
            a_layout, b_layout, argc, argv);
    }
    else if(data_type == "bf8")
    {
        using TypeConfig =
            decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
        return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
                                          TypeConfig,
                                          QuantGroupSize,
                                          ck_tile::QuantType::BQuantGrouped>(
            a_layout, b_layout, argc, argv);
    }
    else if(data_type == "fp8i4")
    {
        using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
                                                        ck_tile::pk_int4_t,
                                                        ck_tile::half_t,
                                                        ck_tile::fp8_t>{});
        return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
                                          TypeConfig,
                                          QuantGroupSize,
                                          ck_tile::QuantType::BQuantGrouped>(
            a_layout, b_layout, argc, argv);
    }
    else if(data_type == "bf8i4")
    {
        using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
                                                        ck_tile::pk_int4_t,
                                                        ck_tile::half_t,
                                                        ck_tile::bf8_t>{});
        return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
                                          TypeConfig,
                                          QuantGroupSize,
                                          ck_tile::QuantType::BQuantGrouped>(
            a_layout, b_layout, argc, argv);
    }
    else
    {
        throw std::runtime_error("Unsupported data type for this operation !!!");
    }
}
template <template <typename> typename GemmConfig, typename F>
int dispatch_group_size_ct(int m, int n, int k, F&& f)
{
// This expands into a sequence of `if (m==M && n==N && k==K) { ... }`
#define DISPATCH_ONE(M, N, K) \
if(m == M && n == N && k == K) \
{ \
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<M, N, K>>; \
return f(QuantGroupSize{}); \
}
CK_TILE_SUPPORTED_QUANT_GROUPS(DISPATCH_ONE)
#undef DISPATCH_ONE
throw std::runtime_error(
"Unsupported group size! Please add it to CK_TILE_SUPPORTED_QUANT_GROUPS(X).");
}
int main(int argc, char* argv[])
{
    // run_gemm_example reports success as a non-zero value; translate that
    // into the conventional process exit status (0 == success).
#if CK_TILE_USE_WMMA
    const int pass = run_gemm_example<GemmConfigBQuantPrefill_Wmma>(argc, argv);
#else
    // Non-preshuffled GemmConfig is required for 2D block scale support.
    const int pass = run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
#endif
    return pass ? 0 : 1;
}

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"
// This translation unit reuses the generic quant GEMM tile configuration for
// the row/column-scale instances.
template <typename T>
using GemmConfig = GemmConfigQuant<T>;
// Registers the row/column-scale quantized GEMM instances into `lut`, keyed by
// the combined hash of the (--prec, --quant_mode) strings.
void quant_rowcol_instance_factory(
    std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
    // NOTE: the rowcol pipeline does not use QuantGroupSize; (1, 1, 1) is a
    // placeholder that only satisfies the template parameter.
    using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
    using Half           = ck_tile::half_t;

    // fp8 operands with per-row/per-column scales.
    lut[hash_multiple_strings({"fp8", "rowcol"})] = [](const ck_tile::ArgParser& parser) {
        using TypeConfig = GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, Half, float>;
        return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
                                          TypeConfig,
                                          QuantGroupSize,
                                          ck_tile::QuantType::RowColQuant>(parser);
    };
    // bf8 operands with per-row/per-column scales.
    lut[hash_multiple_strings({"bf8", "rowcol"})] = [](const ck_tile::ArgParser& parser) {
        using TypeConfig = GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, Half, float>;
        return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
                                          TypeConfig,
                                          QuantGroupSize,
                                          ck_tile::QuantType::RowColQuant>(parser);
    };
}

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"
// This translation unit reuses the generic quant GEMM tile configuration for
// the tensor-scale instances.
template <typename T>
using GemmConfig = GemmConfigQuant<T>;
// Registers the tensor-scale (single scale per tensor) quantized GEMM
// instances into `lut`, keyed by the combined hash of the
// (--prec, --quant_mode) strings.
void quant_tensor_instance_factory(
    std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
    // NOTE: the tensor pipeline does not use QuantGroupSize; (1, 1, 1) is a
    // placeholder that only satisfies the template parameter.
    using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
    using Half           = ck_tile::half_t;

    // fp8 operands with a whole-tensor scale.
    lut[hash_multiple_strings({"fp8", "tensor"})] = [](const ck_tile::ArgParser& parser) {
        using TypeConfig = GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, Half, float>;
        return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
                                          TypeConfig,
                                          QuantGroupSize,
                                          ck_tile::QuantType::TensorQuant>(parser);
    };
    // bf8 operands with a whole-tensor scale.
    lut[hash_multiple_strings({"bf8", "tensor"})] = [](const ck_tile::ArgParser& parser) {
        using TypeConfig = GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, Half, float>;
        return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
                                          TypeConfig,
                                          QuantGroupSize,
                                          ck_tile::QuantType::TensorQuant>(parser);
    };
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -11,12 +11,18 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#define CK_TILE_SUPPORTED_QUANT_GROUPS(X) \
X(1, 1, 64) /* 1D */ \
X(1, 1, 128) /* 1D */ \
X(1, 8, 128) /* 2D N=8 */ \
X(1, 32, 128) /* 2D N=32 */ \
X(1, 64, 128) /* 2D N=64 */
// Folds the std::hash values of all input strings into a single
// order-sensitive digest (boost::hash_combine-style mixing with the golden
// ratio constant and bit shifts). An empty input yields 0.
inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
{
    const std::hash<std::string> hasher{};
    size_t seed = 0;
    for(std::size_t i = 0; i < inputs.size(); ++i)
    {
        // hash_combine: the constant and shifts give good distribution and
        // make the result depend on element order.
        seed ^= hasher(inputs[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
    return seed;
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -293,37 +299,3 @@ struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
// Builds the command-line schema shared by the quant GEMM examples and parses
// argv against it.
//
// Returns {parse_succeeded, parser}. Callers read individual options off the
// parser via get_str/get_int.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3840", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("k", "2048", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "C", "B tensor data layout - Column by default")
        .insert("bq_layout", "C", "Bq tensor data layout - Column by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_q", "0", "Tensor AQ stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec",
                "fp8",
                "data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "1000", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("split_k", "1", "splitK value")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
        .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
        .insert("rotating_count", "1000", "rotating count, defaults to 1")
        // Fix: the registered default is "bquant", so the help text must not
        // claim aquant is the default.
        .insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor or rowcol")
        .insert("group_size",
                "1x1x128",
                "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

View File

@@ -1,11 +1,233 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstring>
#include <iostream>
#include <ostream>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
#include "gemm_utils.hpp"
// Builds the complete quantized-GEMM kernel type (tile partitioner, pipeline,
// epilogue) for one combination of tile config / precision / layouts / quant
// mode, launches it with the given host arguments, and returns the average
// kernel time reported by the ck_tile launch helpers.
//
// args : problem sizes, strides, device pointers and k_batch.
// s    : stream / benchmarking configuration (timing, cache flushing, logging).
//
// Throws std::runtime_error if k_batch != 1 (split-k unsupported) or if the
// kernel rejects the arguments.
template <typename GemmConfig,
          typename TypeConfig,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename QuantGroupSize,
          ck_tile::QuantType QuantMode,
          typename CDEElementWise>
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
    // The epilogue below is only instantiated for row-major C.
    static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);

    // AQuant/RowCol pipelines compute in B's data type, all others in A's.
    using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
                                                   QuantMode == ck_tile::QuantType::RowColQuant,
                                               typename TypeConfig::BDataType,
                                               typename TypeConfig::ADataType>;

    // Block / warp / warp-tile extents taken from the GemmConfig.
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
        ck_tile::
            sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
    using TilePartitioner = ck_tile::GemmTile1DPartitioner<GemmShape>;

    using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
                                                    GemmConfig::kPadN,
                                                    GemmConfig::kPadK,
                                                    GemmConfig::PreshuffleQuant,
                                                    GemmConfig::PreshuffleB,
                                                    ALayout,
                                                    BLayout,
                                                    CLayout,
                                                    QuantMode,
                                                    ALayout, // for AQLayout
                                                    BLayout, // for BQLayout
                                                    false,
                                                    GemmConfig::DoubleSmemBuffer>;
    using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
                                                                 typename TypeConfig::BDataType,
                                                                 typename TypeConfig::AccDataType,
                                                                 GemmShape,
                                                                 GemmTraits,
                                                                 ComputeDataType>;
    // This example only supports BQuant (no AQuant)
    // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
    // NOTE(review): the base pipeline is used below only to derive the
    // hot-loop/tail-number classification and for TailHandler.
    using BaseGemmPipeline = std::conditional_t<
        GemmConfig::PreshuffleB == true,
        ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
        ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;

    // Round K up to a multiple of the K tile before computing the loop count.
    const ck_tile::index_t K_split =
        (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    // Instantiates the kernel for one (hot-loop, tail-number) combination and
    // launches it; invoked by TailHandler with the matching compile-time tags.
    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v  = tail_number_.value;
        constexpr bool transpose_c    = false;

        // row-col and tensor quants use the regular pipeline, A/B quants use their own
        using PipelineProblem = std::conditional_t<
            QuantMode == ck_tile::QuantType::RowColQuant ||
                QuantMode == ck_tile::QuantType::TensorQuant,
            ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
                                                          typename TypeConfig::BDataType,
                                                          typename TypeConfig::AccDataType,
                                                          typename TypeConfig::AccDataType,
                                                          GemmShape,
                                                          GemmTraits,
                                                          transpose_c,
                                                          ComputeDataType,
                                                          GemmConfig::Scheduler,
                                                          has_hot_loop_v,
                                                          tail_number_v>,
            std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
                               ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
                                                                  typename TypeConfig::QDataType,
                                                                  typename TypeConfig::BDataType,
                                                                  typename TypeConfig::AccDataType,
                                                                  GemmShape,
                                                                  GemmTraits,
                                                                  QuantGroupSize,
                                                                  transpose_c,
                                                                  ComputeDataType,
                                                                  GemmConfig::Scheduler,
                                                                  has_hot_loop_v,
                                                                  tail_number_v>,
                               ck_tile::GemmBQuantPipelineProblem<typename TypeConfig::ADataType,
                                                                  typename TypeConfig::BDataType,
                                                                  typename TypeConfig::QDataType,
                                                                  typename TypeConfig::AccDataType,
                                                                  GemmShape,
                                                                  GemmTraits,
                                                                  QuantGroupSize,
                                                                  ComputeDataType,
                                                                  GemmConfig::Scheduler,
                                                                  has_hot_loop_v,
                                                                  tail_number_v>>>;

        // Pipeline selection mirrors the problem selection above; BQuant
        // additionally switches on whether B is preshuffled.
        using GemmPipeline = std::conditional_t<
            QuantMode == ck_tile::QuantType::RowColQuant ||
                QuantMode == ck_tile::QuantType::TensorQuant,
            ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
            std::conditional_t<
                QuantMode == ck_tile::QuantType::AQuantGrouped,
                ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
                std::conditional_t<GemmConfig::PreshuffleB == true,
                                   ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
                                   ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;

        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
                                             typename TypeConfig::BDataType,
                                             ck_tile::tuple<>,
                                             typename TypeConfig::AccDataType,
                                             typename TypeConfig::CDataType,
                                             ck_tile::tuple<>,
                                             CLayout,
                                             CDEElementWise,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             GemmConfig::M_Warp,
                                             GemmConfig::N_Warp,
                                             GemmConfig::M_Warp_Tile,
                                             GemmConfig::N_Warp_Tile,
                                             GemmConfig::K_Warp_Tile,
                                             transpose_c,
                                             ck_tile::memory_operation_enum::set,
                                             1,
                                             false,
                                             1,
                                             GemmConfig::TiledMMAPermuteN>>;
        using Kernel =
            ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
        auto kargs = Kernel::MakeKernelArgs(args);

        const dim3 grids  = Kernel::GridSize(args.M, args.N, args.k_batch);
        const dim3 blocks = Kernel::BlockSize();

        if(args.k_batch != 1)
        {
            throw std::runtime_error("split-k is not supported yet!");
        }

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                      << "shape: " << GemmShape::GetName() << '\n'
                      << "problem: " << PipelineProblem::GetName() << '\n'
                      << "pipeline: " << GemmPipeline::GetName() << '\n'
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        float ave_time = 0;

        if(s.flush_cache_)
        {
            // Benchmark with i-cache flushing and rotating A/B buffers between
            // iterations so timing is not flattered by warm caches.
            std::cout << "Flushing cache..." << std::endl;
            ck_tile::HostTensor<typename TypeConfig::ADataType> a_m(ck_tile::host_tensor_descriptor(
                args.M, args.K, args.stride_A, is_row_major(ALayout{})));
            ck_tile::HostTensor<typename TypeConfig::BDataType> b_n(ck_tile::host_tensor_descriptor(
                args.K, args.N, args.stride_B, is_row_major(BLayout{})));

            auto size_a_buffer = a_m.get_element_space_size_in_bytes();
            auto size_b_buffer = b_n.get_element_space_size_in_bytes();

            ck_tile::RotatingMemWrapper<typename TypeConfig::ADataType,
                                        typename TypeConfig::BDataType>
                rotating_mem(
                    kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
            rotating_mem.Print();

            auto run_flush_cache = [&]() {
                // flush icache
                ck_tile::flush_icache();
                // rotating mem
                rotating_mem.Next();
                // clear c mem
                if(args.k_batch > 1)
                    hipGetErrorString(
                        hipMemsetAsync(args.c_ptr,
                                       0,
                                       args.M * args.N * sizeof(typename TypeConfig::CDataType),
                                       s.stream_id_));
            };
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        return ave_time;
    };

    // Dispatch Run with the compile-time tags matching the runtime loop shape.
    return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
template <typename GemmConfig,
typename TypeConfig,
@@ -120,18 +342,13 @@ template <typename GemmConfig,
typename BLayout,
typename BQLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
const BQLayout bq_layout = BQLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using ADataType = typename TypeConfig::ADataType;
using AQDataType = typename TypeConfig::QDataType;
using BDataType = typename TypeConfig::BDataType;
@@ -522,3 +739,45 @@ int run_gemm_example_with_layouts(int argc,
return pass;
}
// Validates the quant-mode / config / precision / layout combination for one
// instantiation and forwards to the layout-specialized runner. Only row-major
// A with column-major B is wired up in this example.
template <typename GemmConfig,
          typename TypeConfig,
          typename QuantGroupSize,
          ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    // Preshuffled weights are incompatible with A-side and row/col scaling.
    constexpr bool quant_forbids_preshuffle = QuantMode == ck_tile::QuantType::AQuantGrouped ||
                                              QuantMode == ck_tile::QuantType::RowColQuant;
    if(quant_forbids_preshuffle && GemmConfig::PreshuffleB)
    {
        throw std::runtime_error(
            "Preshuffling weight matrix is not supported for AQuant or RowColQuant");
    }

    using ADataType = typename TypeConfig::ADataType;
    if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
                 std::is_same_v<ADataType, ck_tile::fp8_t> ||
                 std::is_same_v<ADataType, ck_tile::bf8_t>)
    {
        const std::string a_layout = arg_parser.get_str("a_layout");
        const std::string b_layout = arg_parser.get_str("b_layout");
        if(a_layout != "R" || b_layout != "C")
        {
            throw std::runtime_error("Unsupported memory layout for the input matrices!");
        }
        return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize, QuantMode>(
            arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
    }
    else
    {
        throw std::runtime_error("Unsupported data type for A.");
    }
    return 0;
}