[rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)

CK Tile MX GEMM Packing Improvement

## Motivation

Reduce the scale-loading size and achieve better utilization of MFMA
scale selection.

## Technical Details

Add packing of MX scales: multiple e8m0_t scale values are pre-packed into each int32_t, and the appropriate byte is selected per MFMA call via OpSel.

## Test Plan

Use the existing test cases.

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Thomas Ning
2026-03-17 18:58:56 +00:00
committed by assistant-librarian[bot]
parent 859acb5ae7
commit 5f90f69795
9 changed files with 399 additions and 130 deletions

View File

@@ -249,14 +249,19 @@ struct BlockGemmARegBRegCRegV1
});
}
// C += A * B with MX scaling
// ScaleATensor: [MIterPerWarp, KIterPerWarp] -> int32_t
// ScaleBTensor: [NIterPerWarp, KIterPerWarp] -> int32_t
// C += A * B with MX scaling and packed-in-two (XdlPack) optimization
// Scale tensors contain pre-packed int32_t: each int32_t holds MXdlPack * KXdlPack e8m0_t
// values (for A) or NXdlPack * KXdlPack (for B), packed on the host.
// Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call.
// XdlPack template parameters default to 2; fall back to 1 when iteration count is too small.
template <typename CBlockTensor,
typename ABlockTensor,
typename BBlockTensor,
typename ScaleATensor,
typename ScaleBTensor>
typename ScaleBTensor,
index_t MXdlPack_ = 2,
index_t NXdlPack_ = 2,
index_t KXdlPack_ = 2>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor,
@@ -304,53 +309,88 @@ struct BlockGemmARegBRegCRegV1
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop with MX scaling:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// Effective XdlPack: fall back to 1 when iteration count is insufficient
constexpr index_t MXdlPack =
(MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1;
constexpr index_t NXdlPack =
(NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1;
constexpr index_t KXdlPack =
(KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1;
// get A scale for this M-K tile using get_y_sliced_thread_data
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
// hot loop with MX scaling and pre-packed int32_t scales:
// Outer loops iterate over pack groups (scale tile indices)
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
// Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t)
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
sequence<kIter, mIter, 0>{}, sequence<1, 1, 1>{});
const auto a_scale_e8m0 = scale_a_slice[number<0>{}];
const int32_t a_scale = static_cast<int32_t>(a_scale_e8m0.get());
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// get B scale for this N-K tile using get_y_sliced_thread_data
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
// Get pre-packed int32_t B scale
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
sequence<kIter, nIter, 0>{}, sequence<1, 1, 1>{});
const auto b_scale_e8m0 = scale_b_slice[number<0>{}];
const int32_t b_scale = static_cast<int32_t>(b_scale_e8m0.get());
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
// read C warp tensor from C block tensor
using c_iter_idx = std::
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// Inner loops: issue MFMAs within the pack group using OpSel
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
constexpr auto mIter = impack * MXdlPack + imxdl;
// warp GEMM with MX scaling
// Cast e8m0_t to int32_t, use OpSel=0 (least significant byte)
constexpr index_t kOpSel = 0; // Always use OpSel=0
WarpGemm{}.template operator()<kOpSel, kOpSel>(
c_warp_tensor, a_warp_tensor, b_warp_tensor, a_scale, b_scale);
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// OpSel for A: selects byte within packed int32_t
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto nIter = inpack * NXdlPack + inxdl;
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{},
b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// OpSel for B: selects byte within packed int32_t
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
// read C warp tensor from C block tensor
using c_iter_idx = std::conditional_t<TransposeC,
sequence<nIter, mIter>,
sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM with MX scaling using pre-packed scale and OpSel
WarpGemm{}.template operator()<kOpSelA, kOpSelB>(c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
a_scale_packed,
b_scale_packed);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
});
});
});