[rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)

CK Tile MX GEMM Packing Improvement

## Motivation

Reduce the scale loading size and make better use of MFMA scale
selection.

## Technical Details

Add packing of MX scales: multiple e8m0_t scale values are packed into one int32_t.

## Test Plan

Use the existing test cases.

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Thomas Ning
2026-03-17 18:58:56 +00:00
committed by assistant-librarian[bot]
parent 859acb5ae7
commit 5f90f69795
9 changed files with 399 additions and 130 deletions

View File

@@ -98,6 +98,30 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
static constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
// XdlPack: desired packing of e8m0_t scale values into int32_t
static constexpr index_t MXdlPack = 2;
static constexpr index_t NXdlPack = 2;
static constexpr index_t KXdlPack = 2;
// Effective pack sizes: fall back to 1 when the per-warp iteration count is
// smaller than the pack size or not a multiple of it
using BlockWarps_ = typename BlockGemmShape::BlockWarps;
static constexpr index_t MPerBlock_ = BlockGemmShape::kM;
static constexpr index_t NPerBlock_ = BlockGemmShape::kN;
static constexpr index_t KPerBlock_ = BlockGemmShape::kK;
static constexpr index_t MWarp_ = BlockWarps_::at(number<0>{});
static constexpr index_t NWarp_ = BlockWarps_::at(number<1>{});
static constexpr index_t KPerXdl_ = BlockGemmShape::WarpTile::at(number<2>{});
static constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp_ * MThreadPerXdl);
static constexpr index_t NIterPerWarp_ = NPerBlock_ / (NWarp_ * NThreadPerXdl);
static constexpr index_t KIterPerWarp_ = KPerBlock_ / KPerXdl_;
static constexpr index_t MXdlPackEff =
(MIterPerWarp_ >= MXdlPack && MIterPerWarp_ % MXdlPack == 0) ? MXdlPack : 1;
static constexpr index_t NXdlPackEff =
(NIterPerWarp_ >= NXdlPack && NIterPerWarp_ % NXdlPack == 0) ? NXdlPack : 1;
static constexpr index_t KXdlPackEff =
(KIterPerWarp_ >= KXdlPack && KIterPerWarp_ % KXdlPack == 0) ? KXdlPack : 1;
static constexpr int kBlockPerCu = 1;
static_assert(DsLayout::size() == DsDataType::size(),
@@ -245,7 +269,9 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
return c_block_window;
}
// Create scale A block windows following the pattern of MakeABlockWindows
// Create scale A block windows with packed int32_t layout
// Host packs 2M x 2K e8m0_t values into one int32_t
// Tensor view: [M/MXdlPack, K/32/KXdlPack] of int32_t
template <typename ScaleM, typename ScaleN>
CK_TILE_DEVICE static auto MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs,
const index_t i_m)
@@ -253,28 +279,28 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
auto scale_a = kargs.scale_m_ptr;
static constexpr int BlockScaleSize = ScaleM::GranularityK;
const auto scale_k_size = kargs.K / BlockScaleSize;
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPackEff;
const auto scale_m_packed = kargs.M / MXdlPackEff;
// A scale tensor view - layout [M, scale_k_size] with e8m0_t elements
// Use e8m0_t directly without packing
// A scale tensor view - layout [M/MXdlPackEff, K/32/KXdlPackEff] with int32_t elements
const auto scale_a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_a.ptr),
make_tuple(kargs.M, scale_k_size),
make_tuple(scale_k_size, 1));
reinterpret_cast<const int32_t*>(scale_a.ptr),
make_tuple(scale_m_packed, scale_k_packed),
make_tuple(scale_k_packed, 1));
// Create block window for scale A
// K dimension: scale_k_size e8m0_t elements
// i_m is element offset (iM * MPerBlock), not tile index
auto scale_a_block_window =
make_tile_window(scale_a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / BlockScaleSize>{}),
{i_m, 0});
// Tile window shape: [MPerBlock/MXdlPackEff, KPerBlock/32/KXdlPackEff]
auto scale_a_block_window = make_tile_window(
scale_a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock / MXdlPackEff>{},
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPackEff>{}),
{i_m / MXdlPackEff, 0});
return scale_a_block_window;
}
// Create scale B block windows following the pattern of MakeBBlockWindows
// Create scale B block windows with packed int32_t layout
// Host packs 2N x 2K e8m0_t values into one int32_t
// Tensor view: [N/NXdlPack, K/32/KXdlPack] of int32_t
template <typename ScaleM, typename ScaleN>
CK_TILE_DEVICE static auto MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs,
const index_t i_n)
@@ -282,23 +308,21 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
auto scale_b = kargs.scale_n_ptr;
static constexpr int BlockScaleSize = ScaleN::GranularityK;
const auto scale_k_size = kargs.K / BlockScaleSize;
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPackEff;
const auto scale_n_packed = kargs.N / NXdlPackEff;
// B scale tensor view
// Host stores as [K/32, N] col-major = [N, K/32] row-major from access perspective
// B scale tensor view - [N/NXdlPackEff, K/32/KXdlPackEff] of int32_t
const auto scale_b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_b.ptr),
make_tuple(kargs.N, scale_k_size), // [N, K/32] for access
make_tuple(scale_k_size, 1)); // stride to match col-major storage
reinterpret_cast<const int32_t*>(scale_b.ptr),
make_tuple(scale_n_packed, scale_k_packed),
make_tuple(scale_k_packed, 1));
// Create block window for scale B
// Tile window shape matches access pattern: [NPerBlock, KPerBlock/32]
// i_n is element offset (iN * NPerBlock)
auto scale_b_block_window =
make_tile_window(scale_b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / BlockScaleSize>{}),
{i_n, 0});
// Tile window shape: [NPerBlock/NXdlPackEff, KPerBlock/32/KXdlPackEff]
auto scale_b_block_window = make_tile_window(
scale_b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock / NXdlPackEff>{},
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPackEff>{}),
{i_n / NXdlPackEff, 0});
return scale_b_block_window;
}

View File

@@ -315,14 +315,36 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
},
number<BsLayout::size()>{});
////////////// MX Scale windows /////////////////
////////////// MX Scale windows (pre-packed int32_t) /////////////////
// Get WarpGemm configuration
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t MWarp = BlockWarps::at(I0{});
constexpr index_t NWarp = BlockWarps::at(I1{});
// Calculate scale dimensions: KPerBlock elements need KPerBlock/32 e8m0_t scales
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize;
// Compute effective XdlPack sizes (fall back to 1 when the iteration count is
// smaller than the pack size or not a multiple of it)
constexpr index_t MPerXdl = WarpTile::at(I0{});
constexpr index_t NPerXdl = WarpTile::at(I1{});
constexpr index_t KPerXdl = WarpTile::at(I2{});
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
constexpr index_t MXdlPackEff =
(MIterPerWarp >= Policy::MXdlPack && MIterPerWarp % Policy::MXdlPack == 0)
? Policy::MXdlPack
: 1;
constexpr index_t NXdlPackEff =
(NIterPerWarp >= Policy::NXdlPack && NIterPerWarp % Policy::NXdlPack == 0)
? Policy::NXdlPack
: 1;
constexpr index_t KXdlPackEff =
(KIterPerWarp >= Policy::KXdlPack && KIterPerWarp % Policy::KXdlPack == 0)
? Policy::KXdlPack
: 1;
// Packed scale dimensions
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPackEff;
// Scale tensor views and base origins for creating tile windows per iteration
const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view();
@@ -330,18 +352,18 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
auto scale_a_base_origin = scale_a_window.get_window_origin();
auto scale_b_base_origin = scale_b_window.get_window_origin();
// Create sample scale windows to determine tile types
auto scale_a_dram_window =
make_tile_window(scale_a_tensor_view,
make_tuple(number<MPerBlock>{}, number<ScaleKDimPerBlock>{}),
scale_a_base_origin,
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
// Create scale windows with packed int32_t dimensions
auto scale_a_dram_window = make_tile_window(
scale_a_tensor_view,
make_tuple(number<MPerBlock / MXdlPackEff>{}, number<ScaleKDimPerBlock>{}),
scale_a_base_origin,
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
auto scale_b_dram_window =
make_tile_window(scale_b_tensor_view,
make_tuple(number<NPerBlock>{}, number<ScaleKDimPerBlock>{}),
scale_b_base_origin,
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
auto scale_b_dram_window = make_tile_window(
scale_b_tensor_view,
make_tuple(number<NPerBlock / NXdlPackEff>{}, number<ScaleKDimPerBlock>{}),
scale_b_base_origin,
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
// this pipeline has a pair of LDS buffers per logical tile
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
@@ -427,8 +449,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
"SmemSizeB size is wrong!");
////////////// MX Scale register tiles (ping-pong buffers) /////////////////
// No packing needed - each thread gets e8m0_t elements directly
// Each thread will cast e8m0_t to int32_t for WarpGemm with OpSel=0
// Scales are pre-packed int32_t: each int32_t holds 2M/N x 2K e8m0_t values
// Block GEMM uses OpSel (0-3) to select the right byte per MFMA call
using ScaleATileType = decltype(load_tile(scale_a_dram_window));
using ScaleBTileType = decltype(load_tile(scale_b_dram_window));

View File

@@ -131,7 +131,15 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
// MX Scale tile distributions for loading from global memory
// XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension
// Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales
// Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales
static constexpr int MXdlPack = 2;
static constexpr int NXdlPack = 2;
static constexpr int KXdlPack = 2;
// MX Scale tile distributions for loading pre-packed int32_t from global memory
// Packed layout: [M/MXdlPack, K/32/KXdlPack] of int32_t
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
@@ -145,21 +153,29 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 64/16 = 4 threads in K dimension
constexpr index_t K_Lane = get_warp_size() / MPerXdl;
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
constexpr index_t KPerLane = KPerXdl / BlockScaleSize / K_Lane;
// Effective pack sizes: fall back to 1 when the iteration count is smaller
// than the pack size or not a multiple of it
constexpr index_t MXdlPackEff =
(MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1;
constexpr index_t KXdlPackEff =
(KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff;
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<NWarp>, // repeat over MWarps
tuple<sequence<MIterPerWarp, MWarp, MPerXdl>, // M dimension (first)
sequence<KIterPerWarp, K_Lane, KPerLane>>, // K dimension (second)
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2, 1, 2>, // <KIterPerWarp, MIterPerWarp, KPerLane>
sequence<0, 0, 2>>{});
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp_packed, MWarp, MPerXdl>,
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2, 1, 2>,
sequence<0, 0, 2>>{});
}
template <typename Problem>
@@ -169,27 +185,35 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t MWarp = BlockWarps::at(number<0>{});
constexpr index_t NWarp = BlockWarps::at(number<1>{});
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t MWarp = BlockWarps::at(number<0>{});
constexpr index_t NWarp = BlockWarps::at(number<1>{});
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K_Lane = get_warp_size() / NPerXdl;
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
constexpr index_t KPerLane = KPerXdl / BlockScaleSize / K_Lane;
// Effective pack sizes: fall back to 1 when the iteration count is smaller
// than the pack size or not a multiple of it
constexpr index_t NXdlPackEff =
(NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1;
constexpr index_t KXdlPackEff =
(KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff;
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<MWarp>, // repeat over MWarps
tuple<sequence<NIterPerWarp, NWarp, NPerXdl>, // N dimension (first)
sequence<KIterPerWarp, K_Lane, KPerLane>>, // K dimension (second)
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2, 1, 2>, // <KIterPerWarp, NIterPerWarp, KPerLane>
sequence<0, 0, 2>>{});
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp_packed, NWarp, NPerXdl>,
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2, 1, 2>,
sequence<0, 0, 2>>{});
}
};
} // namespace ck_tile