[CK_TILE] B matrix 2D block scale gemm (#3074)

* Refactor quant group size to be configurable for M/N/K, not just K

* add some asserts for configurations not implemented

* start setting of group size for N dimension

* enable 2d for reference quant gemm

* WIP: trying to figure out tile dstr and/or indexing for scale matrix

* WIP

* Fix handling of n dim blocks in tile windows etc

* remove commented code and enable all tests again

* fix formatting

* Add more specialized tile distributions

* Enable NWarps replication for bquant tile dstr

* fix formatting

* fix format

* Fix some issues from the merge

* fix formatting

* one more fix to tile dstr, and revert debug initialization

* Remove commented code

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* simplify conditions that are needed for tile distributions

* only enable the working group sizes in tests

* fix formatting

* Update tile distribution for 2D bquant

* add some documentation and 2d block scale example

* fix formatting

* Add in Changelog and restructure the quant 2d example

* fix CMake

* support the change for blockscale 2d

* fix the test file

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Sami Remes
2025-11-03 00:49:20 +00:00
committed by GitHub
parent 73f637894d
commit 16e85cf179
24 changed files with 476 additions and 363 deletions

View File

@@ -1,6 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using non-preshuffled configuration.
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
// This is currently done separately to avoid too verbose dispatching.
#include <cstring>
#include <iostream>
#include <ostream>
@@ -17,7 +22,7 @@ template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
typename QuantGroupSize,
ck_tile::QuantType QuantMode,
typename CDEElementWise>
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
@@ -57,11 +62,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmTraits,
ComputeDataType>;
// This example only supports BQuant (no AQuant)
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>>; // memory pipeline hardcoded
// for aquant
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -229,7 +235,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename QuantGroupSize,
ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
@@ -266,6 +272,41 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return 0;
}
// Forward declaration for dispatch function
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
int dispatch_by_data_type(const std::string& data_type,
const std::string& quant_mode,
const std::string& a_layout,
const std::string& b_layout,
int argc,
char* argv[]);
/// @brief Parse a quantization group size string of the form "MxNxK"
///        (e.g. "1x32x128"). A bare number (e.g. "128") is interpreted as the
///        K dimension with M = N = 1.
/// @param group_size_str user-supplied group size string
/// @return tuple of (M, N, K) group sizes
/// @throws std::runtime_error on malformed input or non-positive dimensions
std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
{
    // std::stoi throws std::invalid_argument / std::out_of_range (both derive
    // from std::logic_error) on non-numeric input; translate those into the
    // example's uniform error message instead of leaking a cryptic one.
    const auto to_dim = [](const std::string& s) {
        try
        {
            const int v = std::stoi(s);
            if(v <= 0)
            {
                throw std::runtime_error("Group size dimensions must be positive!");
            }
            return v;
        }
        catch(const std::logic_error&)
        {
            throw std::runtime_error(
                "Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
        }
    };

    const size_t first_x = group_size_str.find('x');
    if(first_x == std::string::npos)
    {
        // Single number provided, assume it's the K dimension
        return {1, 1, to_dim(group_size_str)};
    }

    const size_t second_x = group_size_str.find('x', first_x + 1);
    if(second_x == std::string::npos)
    {
        throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
    }

    const int m = to_dim(group_size_str.substr(0, first_x));
    const int n = to_dim(group_size_str.substr(first_x + 1, second_x - first_x - 1));
    const int k = to_dim(group_size_str.substr(second_x + 1));
    return {m, n, k};
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
@@ -273,139 +314,57 @@ int run_gemm_example(int argc, char* argv[])
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string quant_mode = arg_parser.get_str("quant_mode");
std::string group_size_str = arg_parser.get_str("group_size");
std::string quant_mode = arg_parser.get_str("quant_mode");
auto [m_group, n_group, k_group] = parse_group_size(group_size_str);
// Dispatch based on group size (M, N, K)
return dispatch_group_size_ct<GemmConfig>(m_group, n_group, k_group, [&](auto QGS_) {
using QuantGroupSize = decltype(QGS_);
return dispatch_by_data_type<GemmConfig, QuantGroupSize>(
data_type, quant_mode, a_layout, b_layout, argc, argv);
});
}
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
int dispatch_by_data_type(const std::string& data_type,
const std::string& quant_mode,
const std::string& a_layout,
const std::string& b_layout,
int argc,
char* argv[])
{
// This example ONLY supports BQuant for 2D block scale quantization
if(quant_mode != "bquant")
{
throw std::runtime_error("This example only supports BQuant! Use --quant_mode=bquant");
}
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
}
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
}
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
}
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
}
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8i4")
{
@@ -414,19 +373,11 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::half_t,
ck_tile::fp8_t>{});
if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'bquant'.");
}
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8i4")
{
@@ -435,19 +386,11 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::half_t,
ck_tile::bf8_t>{});
if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'bquant'.");
}
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
@@ -455,7 +398,27 @@ int run_gemm_example(int argc, char* argv[])
}
}
// Compile-time dispatch from the runtime (m, n, k) group sizes to a concrete
// ck_tile::QuantGroupShape type. `f` is a generic callable invoked with a
// default-constructed QuantGroupSize tag; its int return value is forwarded.
// Only the combinations listed in CK_TILE_SUPPORTED_QUANT_GROUPS are
// instantiated; any other size throws at runtime.
// NOTE(review): the GemmConfig template parameter is not used inside this
// function — presumably kept so callers spell the config explicitly; confirm.
template <template <typename> typename GemmConfig, typename F>
int dispatch_group_size_ct(int m, int n, int k, F&& f)
{
// This expands into a sequence of `if (m==M && n==N && k==K) { ... }`
#define DISPATCH_ONE(M, N, K) \
if(m == M && n == N && k == K) \
{ \
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<M, N, K>>; \
return f(QuantGroupSize{}); \
}
CK_TILE_SUPPORTED_QUANT_GROUPS(DISPATCH_ONE)
#undef DISPATCH_ONE
// No supported (M, N, K) combination matched the requested sizes.
throw std::runtime_error(
"Unsupported group size! Please add it to CK_TILE_SUPPORTED_QUANT_GROUPS(X).");
}
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
// Use non-preshuffled GemmConfig for 2D block scale support
return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
}

View File

@@ -11,6 +11,14 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#define CK_TILE_SUPPORTED_QUANT_GROUPS(X) \
X(1, 1, 64) /* 1D */ \
X(1, 1, 128) /* 1D */ \
X(1, 8, 128) /* 2D N=8 */ \
X(1, 32, 128) /* 2D N=32 */ \
X(1, 64, 128) /* 2D N=64 */ \
X(1, 128, 128) /* 2D N=128 */
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
@@ -193,6 +201,22 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
// GEMM tile configuration for non-preshuffled BQuant, used by the 2D block
// scale example. Remaining settings are inherited from GemmConfigBase.
template <typename PrecType>
struct GemmConfigBQuantPrefill : public GemmConfigBase
{
// Per-workgroup tile extents. K_Tile scales inversely with the element size,
// so the tile spans a fixed 128 bytes along K regardless of PrecType.
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
// Warp layout within a workgroup: 1 x 4 x 1 warps along M x N x K.
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// Per-warp MMA tile: 16 x 16 along M x N; the K extent is derived from the
// element type and M_Warp_Tile by get_k_warp_tile.
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
@@ -288,7 +312,10 @@ auto create_args(int argc, char* argv[])
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1000", "rotating count, defaults to 1")
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol");
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
.insert("group_size",
"1x1x128",
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -14,7 +14,7 @@ template <typename GemmConfig,
typename BLayout,
typename BQLayout,
typename CLayout,
uint32_t QuantGroupSize,
typename QuantGroupSize,
ck_tile::QuantType QuantMode,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
@@ -113,7 +113,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename QuantGroupSize,
ck_tile::QuantType QuantMode,
typename ALayout,
typename AQLayout,
@@ -146,7 +146,7 @@ int run_gemm_example_with_layouts(int argc,
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped)
{
if(K % QuantGroupSize != 0)
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode");
@@ -155,13 +155,13 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t AQK, BQK;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
AQK = K / QuantGroupSize; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
AQK = 0; // No A quantization
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
@@ -357,7 +357,7 @@ int run_gemm_example_with_layouts(int argc,
if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize);
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
aq_dev_buf_ptr->ToDevice(aq_shuffle_host.data());
}
else