[rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)

CK Tile MX GEMM Packing Improvement

## Motivation

Reduce the scale loading size and achieve better utilization of the MFMA
scale selection.

## Technical Details

Add packing of MX scales so that multiple e8m0 scale values are loaded together.

## Test Plan

Use the existing test cases.

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Thomas Ning
2026-03-17 18:58:56 +00:00
committed by assistant-librarian[bot]
parent 859acb5ae7
commit 5f90f69795
9 changed files with 399 additions and 130 deletions

View File

@@ -14,7 +14,56 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_v
return ck_tile::make_tuple(rtol, atol);
}
// Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t.
// Each int32_t contains MNPack * KPack e8m0_t values with a byte layout matching
// the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K.
// byte[ik * MNPack + imn] = e8m0 at strided (mn, k) position
// kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N])
template <ck_tile::index_t MNPack      = 2,
          ck_tile::index_t KPack       = 2,
          ck_tile::index_t XdlMNThread = 16,
          ck_tile::index_t XdlKThread  = 4>
auto packScalesMNxK(const ck_tile::HostTensor<ck_tile::e8m0_t>& src, bool kLast)
{
    auto src_lengths = src.get_lengths();
    const ck_tile::index_t MN      = kLast ? src_lengths[0] : src_lengths[1];
    const ck_tile::index_t K_scale = kLast ? src_lengths[1] : src_lengths[0];

    // The strided gather below indexes up to mn_group * XdlMNThread * MNPack +
    // (MNPack-1) * XdlMNThread + mn_lane; if the dimensions do not divide evenly
    // it would silently read out of bounds, so fail fast instead.
    assert(MN % (MNPack * XdlMNThread) == 0);
    assert(K_scale % (KPack * XdlKThread) == 0);

    const ck_tile::index_t MN_packed    = MN / MNPack;
    const ck_tile::index_t K_packed     = K_scale / KPack;
    const ck_tile::index_t total_packed = MN_packed * K_packed;

    // Output as flat vector of int32_t (row-major [MN/MNPack, K/32/KPack])
    std::vector<int32_t> packed(total_packed);
    for(ck_tile::index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
    {
        for(ck_tile::index_t packed_k = 0; packed_k < K_packed; packed_k++)
        {
            // Accumulate in an unsigned word: left-shifting a byte into the top
            // (sign) byte of a signed int32_t is not portable before C++20.
            uint32_t val = 0;
            // Decompose the packed coordinates into (thread lane, group) so the
            // gathered bytes match the XDL thread distribution on the GPU.
            ck_tile::index_t mn_lane  = packed_mn % XdlMNThread;
            ck_tile::index_t mn_group = packed_mn / XdlMNThread;
            ck_tile::index_t k_lane   = packed_k % XdlKThread;
            ck_tile::index_t k_group  = packed_k / XdlKThread;
            for(ck_tile::index_t ik = 0; ik < KPack; ik++)
            {
                for(ck_tile::index_t imn = 0; imn < MNPack; imn++)
                {
                    ck_tile::index_t byteIdx = ik * MNPack + imn;
                    ck_tile::index_t orig_mn =
                        mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
                    ck_tile::index_t orig_k =
                        k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
                    // kLast selects between [MN, K] (A) and [K, MN] (B) source layouts.
                    ck_tile::e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn);
                    val |= static_cast<uint32_t>(v.get()) << (byteIdx * 8);
                }
            }
            packed[packed_mn * K_packed + packed_k] = static_cast<int32_t>(val);
        }
    }
    return packed;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
@@ -101,21 +150,43 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
break;
}
// Device buffers for A, B, C, and scale tensors
// Compute effective XdlPack sizes based on GemmConfig tile dimensions
constexpr ck_tile::index_t MPerXdl_ = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t NPerXdl_ = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t KPerXdl_ = GemmConfig::K_Warp_Tile;
constexpr ck_tile::index_t MIterPerWarp_ = GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl_);
constexpr ck_tile::index_t NIterPerWarp_ = GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl_);
constexpr ck_tile::index_t KIterPerWarp_ = GemmConfig::K_Tile / KPerXdl_;
constexpr ck_tile::index_t MXdlPackEff = (MIterPerWarp_ >= 2 && MIterPerWarp_ % 2 == 0) ? 2 : 1;
constexpr ck_tile::index_t NXdlPackEff = (NIterPerWarp_ >= 2 && NIterPerWarp_ % 2 == 0) ? 2 : 1;
constexpr ck_tile::index_t KXdlPackEff = (KIterPerWarp_ >= 2 && KIterPerWarp_ % 2 == 0) ? 2 : 1;
// Pack scales: [M, K/32] e8m0_t → [M/MXdlPackEff, K/32/KXdlPackEff] int32_t
// Original unpacked tensors are kept for CPU reference validation
constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
auto scale_a_packed =
packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host, true);
auto scale_b_packed =
packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host, false);
// Device buffers for A, B, C, and packed scale tensors
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t));
ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t));
a_dev_buf.ToDevice(a_host.data());
b_dev_buf.ToDevice(b_host.data());
c_dev_buf.SetZero(); // Initialize C buffer to zero
scale_a_dev_buf.ToDevice(scale_a_host.data());
scale_b_dev_buf.ToDevice(scale_b_host.data());
c_dev_buf.SetZero();
scale_a_dev_buf.ToDevice(scale_a_packed.data());
scale_b_dev_buf.ToDevice(scale_b_packed.data());
// Scale pointers - use e8m0_t* directly
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>; // in blocks of 32 in K
// Scale pointers - point to packed int32_t data, kernel reinterprets as int32_t*
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));