fix formatting

This commit is contained in:
Sami Remes
2025-10-31 20:16:52 +00:00
parent fe92102baf
commit 6f90564708
3 changed files with 27 additions and 25 deletions

View File

@@ -4,7 +4,7 @@
// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using non-preshuffled configuration.
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
// This is currently done separately to avoid too verbose dispatching.
// This is currently done separately to avoid too verbose dispatching.
#include <cstring>
#include <iostream>
@@ -278,14 +278,14 @@ int dispatch_by_data_type(const std::string& data_type,
const std::string& quant_mode,
const std::string& a_layout,
const std::string& b_layout,
int argc,
char* argv[]);
int argc,
char* argv[]);
// Helper function to parse group size string "MxNxK"
std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
{
int m = 1, n = 1, k = 128;
size_t first_x = group_size_str.find('x');
if(first_x == std::string::npos)
{
@@ -293,17 +293,17 @@ std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
k = std::stoi(group_size_str);
return {1, 1, k};
}
size_t second_x = group_size_str.find('x', first_x + 1);
if(second_x == std::string::npos)
{
throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
}
m = std::stoi(group_size_str.substr(0, first_x));
n = std::stoi(group_size_str.substr(first_x + 1, second_x - first_x - 1));
k = std::stoi(group_size_str.substr(second_x + 1));
return {m, n, k};
}
@@ -330,8 +330,9 @@ int run_gemm_example(int argc, char* argv[])
};
// Dispatch for supported group sizes
// Note: This example uses non-preshuffled BQuant which supports both K-only and N×K quantization
// Note: This example uses non-preshuffled BQuant which supports both K-only and N×K
// quantization
if(m_group == 1 && n_group == 1 && k_group == 64)
{
return dispatch_by_group_size.template operator()<1, 1, 64>();
@@ -371,14 +372,13 @@ int dispatch_by_data_type(const std::string& data_type,
const std::string& quant_mode,
const std::string& a_layout,
const std::string& b_layout,
int argc,
char* argv[])
int argc,
char* argv[])
{
// This example ONLY supports BQuant for 2D block scale quantization
if(quant_mode != "bquant")
{
throw std::runtime_error(
"This example only supports BQuant! Use --quant_mode=bquant");
throw std::runtime_error("This example only supports BQuant! Use --quant_mode=bquant");
}
if(data_type == "fp8")

View File

@@ -305,7 +305,9 @@ auto create_args(int argc, char* argv[])
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1000", "rotating count, defaults to 1")
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
.insert("group_size", "1x1x128", "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
.insert("group_size",
"1x1x128",
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -187,13 +187,13 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
static_assert(KWarps == 1);
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
///
/// This function determines the optimal thread distribution pattern for loading and applying
/// quantization scales to the B matrix based on the quantization group size (XPerQ) relative
/// to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
///
/// 1. Fine-grained quantization (XPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp)
@@ -216,12 +216,12 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
if constexpr(XPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
constexpr index_t Y = YPerTile; // Full Y dimension of tile
constexpr index_t YR = 1; // No Y replication needed
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t X1 = NWarps; // Number of warps in N-dim
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
constexpr index_t XR = XPerQ; // Elements per quantization group
constexpr index_t Y = YPerTile; // Full Y dimension of tile
constexpr index_t YR = 1; // No Y replication needed
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t X1 = NWarps; // Number of warps in N-dim
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
constexpr index_t XR = XPerQ; // Elements per quantization group
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
@@ -236,9 +236,9 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<X0, X1>>,