mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-04 22:27:42 +00:00
[rocm-libraries] ROCm/rocm-libraries#8554 (commit be9af54)
refactor(ck): mx gemm kernel unification ## Motivation CK tile currently has two separate MX GEMM kernels for gfx950 and gfx1250. This pull request refactors and modernizes the MX GEMM kernel and example to use new scale tensor handling, improved kernel argument structures, and updated pipeline and kernel APIs. The changes simplify the interface and improve type safety. JIRA ID ROCM-26313 ## Technical Details - Add support for gfx950 in MX GEMM kernel for gfx1250 and remove unused kernel - Unify comp async pipeline for GEMM and MX GEMM - Unify eight waves pipeline for GEMM and MX GEMM - Move preshuffle MX GEMM pipeline to gemm ops and remove gemm_mx ops - Unify testing framework for MX GEMM - Add gfx950 tests for grouped MX GEMM ## Test Plan - `test_mx_gemm_async.cpp` for MX GEMM on gfx950 - `test_mx_grouped_gemm_comp_async.cpp` for grouped MX GEMM on gfx950 ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
604c56bc0e
commit
d559ec00a8
@@ -1184,4 +1184,131 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
|
||||
return;
|
||||
}
|
||||
|
||||
// GPU reference for MX (microscaling) GEMM with e8m0 block scales.
|
||||
//
|
||||
// This is the device counterpart of the host `reference_mx_gemm` above. It exists so the
|
||||
// reference can be computed entirely on the GPU for large problems (e.g. M*N ~ 1e9) where the
|
||||
// host reference is intractable and where copying the 39 GB of inputs back to host is not
|
||||
// feasible. It is a faithful mirror of the host semantics:
|
||||
// - per-element dot product over K, with each A/B element dequantized by its e8m0 block scale.
|
||||
// - all addressing uses `long`; the existing `naive_gemm_kernel`/`blockwise_gemm_kernel` use
|
||||
// `int` and silently overflow once M*N exceeds INT_MAX.
|
||||
//
|
||||
// Layout assumptions match the fp4 CompAsync grouped-GEMM test row (RowMajor A, ColumnMajor B,
|
||||
// RowMajor C):
|
||||
// - A is RowMajor: element (m,k) at linear offset m*K + k.
|
||||
// - B is ColumnMajor: element (k,n) at linear offset n*K + k (K is the fast dimension).
|
||||
// - C is RowMajor: element (m,n) at linear offset m*N + n.
|
||||
// - scale_a is (M, num_scale_k) RowMajor: scale for (m, k) at m*num_scale_k + k/scale_block_size.
|
||||
// - scale_b is (N, num_scale_k) RowMajor: scale for (k, n) at n*num_scale_k + k/scale_block_size.
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AScaleDataType,
|
||||
typename BScaleDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
__global__ void reference_mx_gemm_kernel(const ADataType* __restrict__ a_ptr,
|
||||
const BDataType* __restrict__ b_ptr,
|
||||
const AScaleDataType* __restrict__ scale_a_ptr,
|
||||
const BScaleDataType* __restrict__ scale_b_ptr,
|
||||
CDataType* __restrict__ c_ptr,
|
||||
long M,
|
||||
long N,
|
||||
long K,
|
||||
long num_scale_k,
|
||||
long scale_block_size)
|
||||
{
|
||||
const long total = M * N;
|
||||
const long idx0 = static_cast<long>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const long nthr = static_cast<long>(gridDim.x) * blockDim.x;
|
||||
|
||||
for(long idx = idx0; idx < total; idx += nthr)
|
||||
{
|
||||
const long m = idx / N;
|
||||
const long n = idx % N;
|
||||
|
||||
AccDataType acc = 0;
|
||||
for(long k = 0; k < K; ++k)
|
||||
{
|
||||
// --- A element (RowMajor) ---
|
||||
AccDataType a_val;
|
||||
const long a_lin = m * K + k;
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t a_f2 = pk_fp4_to_fp32x2(a_ptr[a_lin / 2], 1.0f);
|
||||
a_val = type_convert<AccDataType>((a_lin % 2 == 0) ? a_f2.lo : a_f2.hi);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_val = type_convert<AccDataType>(a_ptr[a_lin]);
|
||||
}
|
||||
const float a_sc =
|
||||
type_convert<float>(scale_a_ptr[m * num_scale_k + k / scale_block_size]);
|
||||
|
||||
// --- B element (ColumnMajor, K fast) ---
|
||||
AccDataType b_val;
|
||||
const long b_lin = n * K + k;
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t b_f2 = pk_fp4_to_fp32x2(b_ptr[b_lin / 2], 1.0f);
|
||||
b_val = type_convert<AccDataType>((b_lin % 2 == 0) ? b_f2.lo : b_f2.hi);
|
||||
}
|
||||
else
|
||||
{
|
||||
b_val = type_convert<AccDataType>(b_ptr[b_lin]);
|
||||
}
|
||||
const float b_sc =
|
||||
type_convert<float>(scale_b_ptr[n * num_scale_k + k / scale_block_size]);
|
||||
|
||||
acc += (a_val * type_convert<AccDataType>(a_sc)) *
|
||||
(b_val * type_convert<AccDataType>(b_sc));
|
||||
}
|
||||
c_ptr[m * N + n] = type_convert<CDataType>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AScaleDataType,
|
||||
typename BScaleDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
void reference_mx_gemm_gpu(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AScaleDataType* scale_a_ptr,
|
||||
const BScaleDataType* scale_b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t num_scale_k,
|
||||
index_t scale_block_size,
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const long total = static_cast<long>(M) * N;
|
||||
constexpr int threads = 256;
|
||||
constexpr long max_blocks = 2097152; // grid-stride cap (~2M blocks)
|
||||
const long needed = (total + threads - 1) / threads;
|
||||
const long blocks = needed < max_blocks ? needed : max_blocks;
|
||||
|
||||
reference_mx_gemm_kernel<ADataType,
|
||||
BDataType,
|
||||
AScaleDataType,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CDataType>
|
||||
<<<dim3(static_cast<unsigned>(blocks)), dim3(threads), 0, stream>>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
scale_a_ptr,
|
||||
scale_b_ptr,
|
||||
c_ptr,
|
||||
static_cast<long>(M),
|
||||
static_cast<long>(N),
|
||||
static_cast<long>(K),
|
||||
static_cast<long>(num_scale_k),
|
||||
static_cast<long>(scale_block_size));
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user