[rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)

CK Tile MX GEMM Packing Improvement

## Motivation

Reduce the scale loading size and improve the utilization of MFMA
scale selection.

## Technical Details

Add packing of the MX scales so that multiple e8m0 scale values are stored and loaded together.

## Test Plan

Use the existing test cases.

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Thomas Ning
2026-03-17 18:58:56 +00:00
committed by assistant-librarian[bot]
parent 859acb5ae7
commit 5f90f69795
9 changed files with 399 additions and 130 deletions

View File

@@ -45,6 +45,55 @@ class TestMxGemmUtil : public ::testing::Test
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
// Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t
// Each int32_t contains MNPack * KPack e8m0_t values with byte layout matching
// the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K.
// byte[ik * MNPack + imn] = e8m0 at strided (mn, k) position
// kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N])
//
// @param src   host tensor of e8m0_t scales; interpreted as [MN, K/32] when
//              kLast == true, or [K/32, MN] when kLast == false
// @param kLast selects which source dimension is MN (see above)
// @return std::vector<int32_t> of size (MN/MNPack) * (K/32/KPack), laid out
//         row-major in [packed_mn][packed_k] order
//
// NOTE(review): assumes MN % MNPack == 0, (K/32) % KPack == 0, and that the
// packed extents are compatible with XdlMNThread/XdlKThread so orig_mn/orig_k
// stay in bounds — confirm against the GemmConfig tile sizes used by callers.
template <ck_tile::index_t MNPack = 2,
          ck_tile::index_t KPack = 2,
          ck_tile::index_t XdlMNThread = 16,
          ck_tile::index_t XdlKThread = 4>
static auto packScalesMNxK(const ck_tile::HostTensor<ck_tile::e8m0_t>& src, bool kLast)
{
    // Guard the byte packing: more than 4 one-byte values per word would
    // shift past bit 31 (undefined behavior).
    static_assert(MNPack * KPack * 8 <= 32,
                  "MNPack * KPack e8m0 bytes must fit into a single int32_t");

    auto src_lengths                 = src.get_lengths();
    const ck_tile::index_t MN        = kLast ? src_lengths[0] : src_lengths[1];
    const ck_tile::index_t K_scale   = kLast ? src_lengths[1] : src_lengths[0];
    const ck_tile::index_t MN_packed = MN / MNPack;
    const ck_tile::index_t K_packed  = K_scale / KPack;

    std::vector<int32_t> packed(MN_packed * K_packed);
    for(ck_tile::index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
    {
        // Decompose the packed MN index into lane (thread) and group parts,
        // mirroring the XDL thread layout on the GPU side. Loop-invariant for
        // the inner packed_k loop, so hoisted out of it.
        const ck_tile::index_t mn_lane  = packed_mn % XdlMNThread;
        const ck_tile::index_t mn_group = packed_mn / XdlMNThread;
        for(ck_tile::index_t packed_k = 0; packed_k < K_packed; packed_k++)
        {
            const ck_tile::index_t k_lane  = packed_k % XdlKThread;
            const ck_tile::index_t k_group = packed_k / XdlKThread;
            // Accumulate in an unsigned word: shifting a byte into bits 24..31
            // of a signed int32_t would touch the sign bit (UB before C++14,
            // and a clang-tidy hicpp-signed-bitwise finding).
            uint32_t val = 0;
            for(ck_tile::index_t ik = 0; ik < KPack; ik++)
            {
                for(ck_tile::index_t imn = 0; imn < MNPack; imn++)
                {
                    const ck_tile::index_t byteIdx = ik * MNPack + imn;
                    // Source coordinates are strided: consecutive packed bytes
                    // are XdlMNThread apart in MN and XdlKThread apart in K.
                    const ck_tile::index_t orig_mn =
                        mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
                    const ck_tile::index_t orig_k =
                        k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
                    const ck_tile::e8m0_t v =
                        kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn);
                    val |= static_cast<uint32_t>(v.get()) << (byteIdx * 8);
                }
            }
            packed[packed_mn * K_packed + packed_k] = static_cast<int32_t>(val);
        }
    }
    return packed;
}
void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, int seed = 1234)
{
const ck_tile::index_t scale_k_size = K / 32;
@@ -75,17 +124,43 @@ class TestMxGemmUtil : public ::testing::Test
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
// Compute effective XdlPack sizes based on GemmConfig tile dimensions
constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t NPerXdl = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t KPerXdl = GemmConfig::K_Warp_Tile;
constexpr ck_tile::index_t MIterPerWarp =
GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl);
constexpr ck_tile::index_t NIterPerWarp =
GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl);
constexpr ck_tile::index_t KIterPerWarp = GemmConfig::K_Tile / KPerXdl;
constexpr ck_tile::index_t MXdlPackEff =
(MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1;
constexpr ck_tile::index_t NXdlPackEff =
(NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1;
constexpr ck_tile::index_t KXdlPackEff =
(KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1;
constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
// Pack scales into int32_t for GPU consumption
auto scale_a_packed =
packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host, true);
auto scale_b_packed =
packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host, false);
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t));
ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t));
a_dev_buf.ToDevice(a_host.data());
b_dev_buf.ToDevice(b_host.data());
c_dev_buf.SetZero();
scale_a_dev_buf.ToDevice(scale_a_host.data());
scale_b_dev_buf.ToDevice(scale_b_host.data());
scale_a_dev_buf.ToDevice(scale_a_packed.data());
scale_b_dev_buf.ToDevice(scale_b_packed.data());
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));