[rocm-libraries] ROCm/rocm-libraries#8554 (commit be9af54)

refactor(ck): mx gemm kernel unification

## Motivation

CK tile currently has two separate MX GEMM kernels for gfx950 and
gfx1250. This pull request refactors and modernizes the MX GEMM kernel
and example to use new scale tensor handling, improved kernel argument
structures, and updated pipeline and kernel APIs. The changes simplify
the interface and improve type safety.

JIRA ID ROCM-26313

## Technical Details

- Add support for gfx950 in MX GEMM kernel for gfx1250 and remove unused
kernel
 - Unify comp async pipeline for GEMM and MX GEMM
 - Unify eight waves pipeline for GEMM and MX GEMM
 - Move preshuffle MX GEMM pipeline to gemm ops and remove gemm_mx ops
 - Unify testing framework for MX GEMM
 - Add gfx950 tests for grouped MX GEMM

## Test Plan

 - `test_mx_gemm_async.cpp` for MX GEMM on gfx950
 - `test_mx_grouped_gemm_comp_async.cpp` for grouped MX GEMM on gfx950

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Enrico Degregori
2026-07-01 08:21:02 +00:00
committed by assistant-librarian[bot]
parent 604c56bc0e
commit d559ec00a8
60 changed files with 3703 additions and 5217 deletions

View File

@@ -7,148 +7,163 @@
namespace ck_tile {
/// @brief Pre-shuffle scale buffer for gfx1250 wmma mx scale instruction.
///
/// Reorganizes the scale data from row-major (MN x K) layout to the hardware-specific
/// layout expected by the gfx1250 wmma instruction.
///
/// @tparam ScaleType Scale data type (e.g., e8m0_t)
/// @tparam ScaleBlockSize The block size for microscaling (e.g., 32)
/// @tparam KStride Whether K is the fast-moving dimension
template <typename ScaleType, ck_tile::index_t ScaleBlockSize, bool KStride>
void preShuffleScaleBuffer_gfx1250(const ScaleType* src,
ScaleType* dst,
ck_tile::index_t MN,
ck_tile::index_t K)
{
static_assert((ScaleBlockSize == 32 || ScaleBlockSize == 16) && sizeof(ScaleType) == 1,
"wrong! only support 8-bit scale with ScaleBlockSize=32 or 16");
// ScaleBlockSize == 16: the natural row-major scale layout already matches the gfx1250
// wmma scale distribution (one e8m0 per 16 K-elements lands warp-aligned), so the
// device-side shuffle is the identity transform for all K.
if constexpr(ScaleBlockSize == 16)
{
for(ck_tile::long_index_t mn = 0; mn < MN; ++mn)
for(ck_tile::long_index_t k = 0; k < K; ++k)
{
if constexpr(KStride)
dst[mn * K + k] = src[mn * K + k];
else
dst[mn * K + k] = src[k * MN + mn];
}
return;
}
constexpr ck_tile::long_index_t MPerXdlops = 16;
constexpr ck_tile::long_index_t KPerXdlops = 128;
ck_tile::long_index_t MNPack = 2;
ck_tile::long_index_t KPack = 1;
ck_tile::long_index_t MNStep = MPerXdlops;
ck_tile::long_index_t KStep = KPerXdlops / ScaleBlockSize;
ck_tile::long_index_t K0 = K / KPack / KStep;
for(ck_tile::long_index_t mn = 0; mn < MN; ++mn)
{
ck_tile::long_index_t iMNRepeat = mn / (MNStep * MNPack);
ck_tile::long_index_t tempmn = mn % (MNStep * MNPack);
for(ck_tile::long_index_t k = 0; k < K; ++k)
{
ck_tile::long_index_t iKRepeat = k / (KStep * KPack);
ck_tile::long_index_t tempk = k % (KStep * KPack);
ck_tile::long_index_t outputIndex =
(iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) +
(iKRepeat * KStep * KPack) * (MNStep * MNPack) + tempmn * (KStep * KPack) + tempk;
if constexpr(KStride)
{
dst[outputIndex] = src[mn * K + k];
}
else
dst[outputIndex] = src[k * MN + mn];
}
}
}
// Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t
// Each int32_t contains MNPack * KPack e8m0_t values with byte layout matching
// the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K.
// byte[ik * MNPack + imn] = e8m0 at strided (mn, k) position
// kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N])
template <index_t MNPack = 2, index_t KPack = 2, index_t XdlMNThread = 16, index_t XdlKThread = 4>
auto packScalesMNxK(const HostTensor<e8m0_t>& src, const bool kLast)
template <ck_tile::index_t MNPack = 2,
ck_tile::index_t KPack = 2,
ck_tile::index_t XdlMNThread = 16,
ck_tile::index_t XdlKThread = 4,
typename ScaleType>
void preShuffleScaleBuffer_gfx950(const ScaleType* src,
ScaleType* packed,
ck_tile::index_t MN,
ck_tile::index_t K_scale,
bool kLast)
{
auto src_lengths = src.get_lengths();
const index_t MN = kLast ? src_lengths[0] : src_lengths[1];
const index_t K_scale = kLast ? src_lengths[1] : src_lengths[0];
const index_t MN_packed = MN / MNPack;
const index_t K_packed = K_scale / KPack;
const ck_tile::long_index_t MN_packed = MN / MNPack;
const ck_tile::long_index_t K_packed = K_scale / KPack;
constexpr ck_tile::long_index_t NumScalesPerDword = 4 / sizeof(ScaleType);
// Output as flat vector of int32_t (row-major [MN/MNPack, K/32/KPack])
HostTensor<int32_t> packed(HostTensorDescriptor(
{static_cast<std::size_t>(MN_packed), static_cast<std::size_t>(K_packed)},
{static_cast<std::size_t>(K_packed), static_cast<std::size_t>(1)}));
for(index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
for(ck_tile::long_index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
{
for(index_t packed_k = 0; packed_k < K_packed; packed_k++)
for(ck_tile::long_index_t packed_k = 0; packed_k < K_packed; packed_k++)
{
uint32_t val = 0;
index_t mn_lane = packed_mn % XdlMNThread;
index_t mn_group = packed_mn / XdlMNThread;
index_t k_lane = packed_k % XdlKThread;
index_t k_group = packed_k / XdlKThread;
for(index_t ik = 0; ik < KPack; ik++)
ck_tile::long_index_t mn_lane = packed_mn % XdlMNThread;
ck_tile::long_index_t mn_group = packed_mn / XdlMNThread;
ck_tile::long_index_t k_lane = packed_k % XdlKThread;
ck_tile::long_index_t k_group = packed_k / XdlKThread;
for(ck_tile::long_index_t ik = 0; ik < KPack; ik++)
{
for(index_t imn = 0; imn < MNPack; imn++)
for(ck_tile::long_index_t imn = 0; imn < MNPack; imn++)
{
index_t byteIdx = ik * MNPack + imn;
index_t orig_mn = mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
index_t orig_k = k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
ck_tile::long_index_t byteIdx = ik * MNPack + imn;
ck_tile::long_index_t orig_mn =
mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
ck_tile::long_index_t orig_k =
k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn);
val |= (static_cast<uint32_t>(v.get()) << (byteIdx * 8));
ck_tile::long_index_t inputIndex =
kLast ? orig_k + orig_mn * K_scale : orig_mn + orig_k * MN;
ScaleType v = src[inputIndex];
ck_tile::long_index_t outputIndex =
byteIdx + (packed_mn % XdlMNThread) * NumScalesPerDword +
packed_k * XdlMNThread * NumScalesPerDword +
(packed_mn / XdlMNThread) * XdlMNThread * NumScalesPerDword * K_packed;
packed[outputIndex] = v;
}
}
packed(packed_mn, packed_k) = static_cast<int32_t>(val);
}
}
return packed;
}
template <index_t XdlMNThread, typename dtype>
auto preShuffleScale(ck_tile::HostTensor<dtype>& src, const bool kLast)
template <ck_tile::index_t NWarp,
ck_tile::index_t NPerBlock,
ck_tile::index_t XdlMNThread,
typename ScaleType>
auto preShuffleScaleBufferPermuteN_gfx950(
const ScaleType* src, ScaleType* shuffled, ck_tile::index_t MN, ck_tile::index_t K, bool kLast)
{
auto src_lengths = src.get_lengths();
const index_t MN = kLast ? src_lengths[0] : src_lengths[1];
const index_t K = kLast ? src_lengths[1] : src_lengths[0];
constexpr index_t MNXdlPack = 2;
constexpr index_t KXdlPack = 2;
constexpr index_t XdlKThread = get_warp_size() / XdlMNThread;
const auto MNPadded = integer_least_multiple(MN, XdlMNThread * MNXdlPack);
HostTensor<dtype> shuffled(HostTensorDescriptor({static_cast<std::size_t>(MNPadded * K)},
{static_cast<std::size_t>(1)}));
constexpr ck_tile::long_index_t MNXdlPack = 2;
constexpr ck_tile::long_index_t KXdlPack = 2;
constexpr ck_tile::long_index_t NRepeat = NPerBlock / NWarp / XdlMNThread;
constexpr ck_tile::long_index_t XdlKThread = ck_tile::get_warp_size() / XdlMNThread;
if(K % (KXdlPack * XdlKThread) != 0)
{
throw std::runtime_error("wrong! K must be a multiple of (KXdlPack * XdlKThread)");
}
const ck_tile::long_index_t K0 = K / KXdlPack / XdlKThread;
const index_t K0 = K / KXdlPack / XdlKThread;
for(index_t n = 0; n < MNPadded; ++n)
for(ck_tile::long_index_t n = 0; n < MN; ++n)
{
for(index_t k = 0; k < K; ++k)
for(ck_tile::long_index_t k = 0; k < K; ++k)
{
const index_t n0 = n / (XdlMNThread * MNXdlPack);
const index_t tempn = n % (XdlMNThread * MNXdlPack);
const index_t n1 = tempn % XdlMNThread;
const index_t n2 = tempn / XdlMNThread;
const ck_tile::long_index_t n0 = n / NPerBlock;
const ck_tile::long_index_t tempn0 = n % NPerBlock;
const ck_tile::long_index_t n1 = tempn0 / (XdlMNThread * NRepeat);
const ck_tile::long_index_t tempn1 = tempn0 % (XdlMNThread * NRepeat);
const ck_tile::long_index_t n2 = tempn1 / (NRepeat);
const ck_tile::long_index_t tempn2 = tempn1 % (NRepeat);
const ck_tile::long_index_t n3 = tempn2 % MNXdlPack;
const ck_tile::long_index_t n4 = tempn2 / MNXdlPack;
const index_t k0 = k / (XdlKThread * KXdlPack);
const index_t tempk = k % (XdlKThread * KXdlPack);
const index_t k1 = tempk % XdlKThread;
const index_t k2 = tempk / XdlKThread;
const ck_tile::long_index_t k0 = k / (XdlKThread * KXdlPack);
const ck_tile::long_index_t tempk = k % (XdlKThread * KXdlPack);
const ck_tile::long_index_t k1 = tempk % XdlKThread;
const ck_tile::long_index_t k2 = tempk / XdlKThread;
const index_t outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread +
n1 * MNXdlPack * KXdlPack + k2 * MNXdlPack + n2;
if(n < MN)
{
shuffled(outputIndex) = kLast ? src(n, k) : src(k, n);
}
else
{
shuffled(outputIndex) = dtype{};
}
}
}
return shuffled;
}
template <index_t NWarp, index_t NPerBlock, index_t XdlMNThread, typename dtype>
auto preShuffleScalePermuteN(const HostTensor<dtype>& src, const bool kLast)
{
auto src_lengths = src.get_lengths();
const index_t MN = kLast ? src_lengths[0] : src_lengths[1];
const index_t K = kLast ? src_lengths[1] : src_lengths[0];
constexpr index_t MNXdlPack = 2;
constexpr index_t KXdlPack = 2;
constexpr index_t NRepeat = NPerBlock / NWarp / XdlMNThread;
constexpr index_t XdlKThread = get_warp_size() / XdlMNThread; // 4
const index_t MNPadded = integer_least_multiple(MN, NPerBlock);
HostTensor<dtype> shuffled(HostTensorDescriptor({static_cast<std::size_t>(MNPadded * K)},
{static_cast<std::size_t>(1)}));
if(K % (KXdlPack * XdlKThread) != 0)
{
throw std::runtime_error("wrong! K must be a multiple of (KXdlPack * XdlKThread)");
}
const index_t K0 = K / KXdlPack / XdlKThread;
for(index_t n = 0; n < MNPadded; ++n)
{
for(index_t k = 0; k < K; ++k)
{
const index_t n0 = n / NPerBlock;
const index_t tempn0 = n % NPerBlock;
const index_t n1 = tempn0 / (XdlMNThread * NRepeat);
const index_t tempn1 = tempn0 % (XdlMNThread * NRepeat);
const index_t n2 = tempn1 / (NRepeat);
const index_t tempn2 = tempn1 % (NRepeat);
const index_t n3 = tempn2 % MNXdlPack;
const index_t n4 = tempn2 / MNXdlPack;
const index_t k0 = k / (XdlKThread * KXdlPack);
const index_t tempk = k % (XdlKThread * KXdlPack);
const index_t k1 = tempk % XdlKThread;
const index_t k2 = tempk / XdlKThread;
const index_t outputIndex =
const ck_tile::long_index_t outputIndex =
n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 * NWarp *
(NRepeat / MNXdlPack) +
n1 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
@@ -156,13 +171,15 @@ auto preShuffleScalePermuteN(const HostTensor<dtype>& src, const bool kLast)
k1 * MNXdlPack * KXdlPack * XdlMNThread + k2 * MNXdlPack +
n4 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 * NWarp + n3;
ck_tile::long_index_t inputIndex = kLast ? k + n * K : n + k * MN;
if(n < MN)
{
shuffled(outputIndex) = kLast ? src(n, k) : src(k, n);
shuffled[outputIndex] = src[inputIndex];
}
else
{
shuffled(outputIndex) = dtype{};
shuffled[outputIndex] = ScaleType{};
}
}
}

View File

@@ -1184,4 +1184,131 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
return;
}
// GPU reference for MX (microscaling) GEMM with e8m0 block scales.
//
// This is the device counterpart of the host `reference_mx_gemm` above. It exists so the
// reference can be computed entirely on the GPU for large problems (e.g. M*N ~ 1e9) where the
// host reference is intractable and where copying the 39 GB of inputs back to host is not
// feasible. It is a faithful mirror of the host semantics:
// - per-element dot product over K, with each A/B element dequantized by its e8m0 block scale.
// - all addressing uses `long`; the existing `naive_gemm_kernel`/`blockwise_gemm_kernel` use
// `int` and silently overflow once M*N exceeds INT_MAX.
//
// Layout assumptions match the fp4 CompAsync grouped-GEMM test row (RowMajor A, ColumnMajor B,
// RowMajor C):
// - A is RowMajor: element (m,k) at linear offset m*K + k.
// - B is ColumnMajor: element (k,n) at linear offset n*K + k (K is the fast dimension).
// - C is RowMajor: element (m,n) at linear offset m*N + n.
// - scale_a is (M, num_scale_k) RowMajor: scale for (m, k) at m*num_scale_k + k/scale_block_size.
// - scale_b is (N, num_scale_k) RowMajor: scale for (k, n) at n*num_scale_k + k/scale_block_size.
template <typename ADataType,
typename BDataType,
typename AScaleDataType,
typename BScaleDataType,
typename AccDataType,
typename CDataType>
__global__ void reference_mx_gemm_kernel(const ADataType* __restrict__ a_ptr,
const BDataType* __restrict__ b_ptr,
const AScaleDataType* __restrict__ scale_a_ptr,
const BScaleDataType* __restrict__ scale_b_ptr,
CDataType* __restrict__ c_ptr,
long M,
long N,
long K,
long num_scale_k,
long scale_block_size)
{
const long total = M * N;
const long idx0 = static_cast<long>(blockIdx.x) * blockDim.x + threadIdx.x;
const long nthr = static_cast<long>(gridDim.x) * blockDim.x;
for(long idx = idx0; idx < total; idx += nthr)
{
const long m = idx / N;
const long n = idx % N;
AccDataType acc = 0;
for(long k = 0; k < K; ++k)
{
// --- A element (RowMajor) ---
AccDataType a_val;
const long a_lin = m * K + k;
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t a_f2 = pk_fp4_to_fp32x2(a_ptr[a_lin / 2], 1.0f);
a_val = type_convert<AccDataType>((a_lin % 2 == 0) ? a_f2.lo : a_f2.hi);
}
else
{
a_val = type_convert<AccDataType>(a_ptr[a_lin]);
}
const float a_sc =
type_convert<float>(scale_a_ptr[m * num_scale_k + k / scale_block_size]);
// --- B element (ColumnMajor, K fast) ---
AccDataType b_val;
const long b_lin = n * K + k;
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t b_f2 = pk_fp4_to_fp32x2(b_ptr[b_lin / 2], 1.0f);
b_val = type_convert<AccDataType>((b_lin % 2 == 0) ? b_f2.lo : b_f2.hi);
}
else
{
b_val = type_convert<AccDataType>(b_ptr[b_lin]);
}
const float b_sc =
type_convert<float>(scale_b_ptr[n * num_scale_k + k / scale_block_size]);
acc += (a_val * type_convert<AccDataType>(a_sc)) *
(b_val * type_convert<AccDataType>(b_sc));
}
c_ptr[m * N + n] = type_convert<CDataType>(acc);
}
}
template <typename ADataType,
typename BDataType,
typename AScaleDataType,
typename BScaleDataType,
typename AccDataType,
typename CDataType>
void reference_mx_gemm_gpu(const ADataType* a_ptr,
const BDataType* b_ptr,
const AScaleDataType* scale_a_ptr,
const BScaleDataType* scale_b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
index_t K,
index_t num_scale_k,
index_t scale_block_size,
hipStream_t stream = nullptr)
{
const long total = static_cast<long>(M) * N;
constexpr int threads = 256;
constexpr long max_blocks = 2097152; // grid-stride cap (~2M blocks)
const long needed = (total + threads - 1) / threads;
const long blocks = needed < max_blocks ? needed : max_blocks;
reference_mx_gemm_kernel<ADataType,
BDataType,
AScaleDataType,
BScaleDataType,
AccDataType,
CDataType>
<<<dim3(static_cast<unsigned>(blocks)), dim3(threads), 0, stream>>>(
a_ptr,
b_ptr,
scale_a_ptr,
scale_b_ptr,
c_ptr,
static_cast<long>(M),
static_cast<long>(N),
static_cast<long>(K),
static_cast<long>(num_scale_k),
static_cast<long>(scale_block_size));
hip_check_error(hipGetLastError());
}
} // namespace ck_tile