diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_2d_block.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_2d_block.cpp index 5638ead778..173b47ea5d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_2d_block.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_2d_block.cpp @@ -4,7 +4,7 @@ // This example demonstrates 2D block scale quantization (N×K) for BQuant // using non-preshuffled configuration. // NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example -// This is currently done separately to avoid too verbose dispatching. +// This is currently done separately to avoid too verbose dispatching. #include #include @@ -278,14 +278,14 @@ int dispatch_by_data_type(const std::string& data_type, const std::string& quant_mode, const std::string& a_layout, const std::string& b_layout, - int argc, - char* argv[]); + int argc, + char* argv[]); // Helper function to parse group size string "MxNxK" std::tuple parse_group_size(const std::string& group_size_str) { int m = 1, n = 1, k = 128; - + size_t first_x = group_size_str.find('x'); if(first_x == std::string::npos) { @@ -293,17 +293,17 @@ std::tuple parse_group_size(const std::string& group_size_str) k = std::stoi(group_size_str); return {1, 1, k}; } - + size_t second_x = group_size_str.find('x', first_x + 1); if(second_x == std::string::npos) { throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)"); } - + m = std::stoi(group_size_str.substr(0, first_x)); n = std::stoi(group_size_str.substr(first_x + 1, second_x - first_x - 1)); k = std::stoi(group_size_str.substr(second_x + 1)); - + return {m, n, k}; } @@ -330,8 +330,9 @@ int run_gemm_example(int argc, char* argv[]) }; // Dispatch for supported group sizes - // Note: This example uses non-preshuffled BQuant which supports both K-only and N×K quantization - + // Note: This example uses non-preshuffled BQuant which supports both K-only and N×K + // quantization + if(m_group == 1 && n_group == 1 && k_group == 64) { return dispatch_by_group_size.template operator()<1, 1, 64>(); @@ -371,14 +372,13 @@ int dispatch_by_data_type(const std::string& data_type, const std::string& quant_mode, const std::string& a_layout, const std::string& b_layout, - int argc, - char* argv[]) + int argc, + char* argv[]) { // This example ONLY supports BQuant for 2D block scale quantization if(quant_mode != "bquant") { - throw std::runtime_error( - "This example only supports BQuant! Use --quant_mode=bquant"); + throw std::runtime_error("This example only supports BQuant! Use --quant_mode=bquant"); } if(data_type == "fp8") diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 0e7f6740af..49aee71f7f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -305,7 +305,9 @@ auto create_args(int argc, char* argv[]) .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") .insert("rotating_count", "1000", "rotating count, defaults to 1") .insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol") - .insert("group_size", "1x1x128", "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); + .insert("group_size", + "1x1x128", + "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index d65021b177..43dbf95941 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -187,13 +187,13 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding static_assert(KWarps == 1); /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) - /// + /// /// This function determines the optimal thread distribution pattern for loading and applying /// quantization scales to the B matrix based on the quantization group size (XPerQ) relative /// to warp dimensions. /// /// Three distinct distribution patterns are handled: - /// + /// /// 1. Fine-grained quantization (XPerQ < WarpGemm::kN): /// - Multiple quantization groups exist within a single warp's N-dimension /// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp) @@ -216,12 +216,12 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding if constexpr(XPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp - constexpr index_t Y = YPerTile; // Full Y dimension of tile - constexpr index_t YR = 1; // No Y replication needed - constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim - constexpr index_t X1 = NWarps; // Number of warps in N-dim - constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp - constexpr index_t XR = XPerQ; // Elements per quantization group + constexpr index_t Y = YPerTile; // Full Y dimension of tile + constexpr index_t YR = 1; // No Y replication needed + constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim + constexpr index_t X1 = NWarps; // Number of warps in N-dim + constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp + constexpr index_t XR = XPerQ; // Elements per quantization group static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X."); @@ -236,9 +236,9 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding else if constexpr(XPerQ <= WarpGemm::kN * NWarps) { // Case 2: Medium-grained - one quantization scale per warp - constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor - constexpr auto X1 = NWarps / XR; // Warps per unique scale - constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension + constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor + constexpr auto X1 = NWarps / XR; // Warps per unique scale + constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension return make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>,