Make CK TILE GEMM Aquant support block tile 128x128x128 (#3325)

* [CK TILE GEMM Quant] Rename GemmConfigBQuantPrefill to GemmConfigQuantPrefill in examples

* [CK TILE GEMM Quant] Update tile distribution of Aquant

* [CK TILE GEMM Quant] Update Aquant register offset calculation

* [CK TILE GEMM Quant] Reimplement Aquant register offset calculation

* [CK TILE GEMM Quant] Add more unit tests for Aquant

- Test M128xN128xK128

* [CK TILE GEMM Quant] Add more comments to GEMM Aquant
This commit is contained in:
Cong Ma
2025-12-01 16:04:37 -07:00
committed by GitHub
parent 7873f8fa13
commit 23fb253c4e
11 changed files with 58 additions and 46 deletions

View File

@@ -94,21 +94,20 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
// # of elements per thread
constexpr index_t X = XPerTile;
constexpr index_t Y0 = 1;
constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
constexpr index_t Y2 = MWarps;
constexpr index_t Y3 = WarpGemm::kM;
static_assert(Y3 >= WarpGemm::kM,
constexpr index_t YR = 1;
constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1;
constexpr index_t Y1 = MWarps;
constexpr index_t Y2 = WarpGemm::kM;
static_assert(Y2 >= WarpGemm::kM,
"Scales for all rows must be available within the warp.");
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
tuple<sequence<1, 0>, sequence<1, 1>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
tile_distribution_encoding<sequence<NWarps, YR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<1, 0>, sequence<0, 1>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
sequence<1, 2>,
sequence<1, 0>>{});
sequence<0, 0>>{});
}
}
};