fix formatting

This commit is contained in:
Sami Remes
2025-10-31 20:16:52 +00:00
parent fe92102baf
commit 6f90564708
3 changed files with 27 additions and 25 deletions

View File

@@ -187,13 +187,13 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
static_assert(KWarps == 1);
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
///
/// This function determines the optimal thread distribution pattern for loading and applying
/// quantization scales to the B matrix based on the quantization group size (XPerQ) relative
/// to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
///
/// 1. Fine-grained quantization (XPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp)
@@ -216,12 +216,12 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
if constexpr(XPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
constexpr index_t Y = YPerTile; // Full Y dimension of tile
constexpr index_t YR = 1; // No Y replication needed
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t X1 = NWarps; // Number of warps in N-dim
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
constexpr index_t XR = XPerQ; // Elements per quantization group
constexpr index_t Y = YPerTile; // Full Y dimension of tile
constexpr index_t YR = 1; // No Y replication needed
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t X1 = NWarps; // Number of warps in N-dim
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
constexpr index_t XR = XPerQ; // Elements per quantization group
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
@@ -236,9 +236,9 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<X0, X1>>,