mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] B matrix 2D block scale gemm (#3074)
* Refactor quant group size to be configurable for M/N/K, not just K * add some asserts for configurations not implemented * start setting of group size for N dimension * enable 2d for reference quant gemm * WIP: trying to figure out tile dstr and/or indexing for scale matrix * WIP * Fix handling of n dim blocks in tile windows etc * remove commented code and enable all tests again * fix formatting * Add more specialized tile distributions * Enable NWarps replication for bquant tile dstr * fix formatting * fix format * Fix some issues from the merge * fix formatting * one more fix to tile dstr, and revert debug initialization * Remove commented code Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * simplify conditions that are needed for tile distributions * only enable the working group sizes in tests * fix formatting * Update tile distribution for 2D bquant * add some documentation and 2d block scale example * fix formatting * Add in Changlog and restructure the quant 2d example * fix CMake * support the change for blockscale 2d * fix the test file --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This example demonstrates 2D block scale quantization (N×K) for BQuant
|
||||
// using non-preshuffled configuration.
|
||||
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
|
||||
// This is currently done separately to avoid too verbose dispatching.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
@@ -17,7 +22,7 @@ template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise>
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
@@ -57,11 +62,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
// This example only supports BQuant (no AQuant)
|
||||
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>>; // memory pipeline hardcoded
|
||||
// for aquant
|
||||
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
@@ -229,7 +235,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
@@ -266,6 +272,41 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Forward declaration for dispatch function
|
||||
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
|
||||
int dispatch_by_data_type(const std::string& data_type,
|
||||
const std::string& quant_mode,
|
||||
const std::string& a_layout,
|
||||
const std::string& b_layout,
|
||||
int argc,
|
||||
char* argv[]);
|
||||
|
||||
// Helper function to parse group size string "MxNxK"
|
||||
std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
|
||||
{
|
||||
int m = 1, n = 1, k = 128;
|
||||
|
||||
size_t first_x = group_size_str.find('x');
|
||||
if(first_x == std::string::npos)
|
||||
{
|
||||
// Single number provided, assume it's the K dimension
|
||||
k = std::stoi(group_size_str);
|
||||
return {1, 1, k};
|
||||
}
|
||||
|
||||
size_t second_x = group_size_str.find('x', first_x + 1);
|
||||
if(second_x == std::string::npos)
|
||||
{
|
||||
throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
|
||||
}
|
||||
|
||||
m = std::stoi(group_size_str.substr(0, first_x));
|
||||
n = std::stoi(group_size_str.substr(first_x + 1, second_x - first_x - 1));
|
||||
k = std::stoi(group_size_str.substr(second_x + 1));
|
||||
|
||||
return {m, n, k};
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
@@ -273,139 +314,57 @@ int run_gemm_example(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
std::string group_size_str = arg_parser.get_str("group_size");
|
||||
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
auto [m_group, n_group, k_group] = parse_group_size(group_size_str);
|
||||
|
||||
// Dispatch based on group size (M, N, K)
|
||||
return dispatch_group_size_ct<GemmConfig>(m_group, n_group, k_group, [&](auto QGS_) {
|
||||
using QuantGroupSize = decltype(QGS_);
|
||||
return dispatch_by_data_type<GemmConfig, QuantGroupSize>(
|
||||
data_type, quant_mode, a_layout, b_layout, argc, argv);
|
||||
});
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
|
||||
int dispatch_by_data_type(const std::string& data_type,
|
||||
const std::string& quant_mode,
|
||||
const std::string& a_layout,
|
||||
const std::string& b_layout,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
// This example ONLY supports BQuant for 2D block scale quantization
|
||||
if(quant_mode != "bquant")
|
||||
{
|
||||
throw std::runtime_error("This example only supports BQuant! Use --quant_mode=bquant");
|
||||
}
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
|
||||
}
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'aquant'.");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'aquant'.");
|
||||
}
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8i4")
|
||||
{
|
||||
@@ -414,19 +373,11 @@ int run_gemm_example(int argc, char* argv[])
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
|
||||
if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'bquant'.");
|
||||
}
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8i4")
|
||||
{
|
||||
@@ -435,19 +386,11 @@ int run_gemm_example(int argc, char* argv[])
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
|
||||
if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'bquant'.");
|
||||
}
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -455,7 +398,27 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename> typename GemmConfig, typename F>
|
||||
int dispatch_group_size_ct(int m, int n, int k, F&& f)
|
||||
{
|
||||
// This expands into a sequence of `if (m==M && n==N && k==K) { ... }`
|
||||
#define DISPATCH_ONE(M, N, K) \
|
||||
if(m == M && n == N && k == K) \
|
||||
{ \
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<M, N, K>>; \
|
||||
return f(QuantGroupSize{}); \
|
||||
}
|
||||
|
||||
CK_TILE_SUPPORTED_QUANT_GROUPS(DISPATCH_ONE)
|
||||
|
||||
#undef DISPATCH_ONE
|
||||
|
||||
throw std::runtime_error(
|
||||
"Unsupported group size! Please add it to CK_TILE_SUPPORTED_QUANT_GROUPS(X).");
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
|
||||
// Use non-preshuffled GemmConfig for 2D block scale support
|
||||
return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
|
||||
}
|
||||
|
||||
@@ -11,6 +11,14 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
#define CK_TILE_SUPPORTED_QUANT_GROUPS(X) \
|
||||
X(1, 1, 64) /* 1D */ \
|
||||
X(1, 1, 128) /* 1D */ \
|
||||
X(1, 8, 128) /* 2D N=8 */ \
|
||||
X(1, 32, 128) /* 2D N=32 */ \
|
||||
X(1, 64, 128) /* 2D N=64 */ \
|
||||
X(1, 128, 128) /* 2D N=128 */
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -193,6 +201,22 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_ = ADataType_,
|
||||
typename CDataType_ = ADataType_,
|
||||
@@ -288,7 +312,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1")
|
||||
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol");
|
||||
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
|
||||
.insert("group_size",
|
||||
"1x1x128",
|
||||
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -14,7 +14,7 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
@@ -113,7 +113,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
@@ -146,7 +146,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if(K % QuantGroupSize != 0)
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode");
|
||||
@@ -155,13 +155,13 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::index_t AQK, BQK;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
AQK = K / QuantGroupSize; // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
@@ -357,7 +357,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize);
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
aq_dev_buf_ptr->ToDevice(aq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user