mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
[rocm-libraries] ROCm/rocm-libraries#8554 (commit be9af54)
refactor(ck): mx gemm kernel unification

## Motivation

CK tile currently has two separate MX GEMM kernels for gfx950 and gfx1250. This pull request refactors and modernizes the MX GEMM kernel and example to use new scale tensor handling, improved kernel argument structures, and updated pipeline and kernel APIs. The changes simplify the interface and improve type safety.

JIRA ID ROCM-26313

## Technical Details

- Add support for gfx950 in MX GEMM kernel for gfx1250 and remove unused kernel
- Unify comp async pipeline for GEMM and MX GEMM
- Unify eight waves pipeline for GEMM and MX GEMM
- Move preshuffle MX GEMM pipeline to gemm ops and remove gemm_mx ops
- Unify testing framework for MX GEMM
- Add gfx950 tests for grouped MX GEMM

## Test Plan

- `test_mx_gemm_async.cpp` for MX GEMM on gfx950
- `test_mx_grouped_gemm_comp_async.cpp` for grouped MX GEMM on gfx950

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
604c56bc0e
commit
d559ec00a8
@@ -520,7 +520,7 @@ class GemmKernelBuilder:
|
||||
}
|
||||
elif self.kernel_name_prefix == "mx_gemm":
|
||||
pipeline_impl_map = {
|
||||
"comp_async": "ck_tile::MXGemmPipelineAgBgCrCompAsync",
|
||||
"comp_async": "ck_tile::GemmPipelineAgBgCrCompAsync",
|
||||
}
|
||||
base_pipeline_map = {}
|
||||
|
||||
@@ -581,9 +581,6 @@ class GemmKernelBuilder:
|
||||
instance_code += """#include <vector>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
"""
|
||||
elif self.kernel_name_prefix == "mx_gemm":
|
||||
instance_code += """#include "ck_tile/ops/gemm_mx.hpp"
|
||||
"""
|
||||
return instance_code
|
||||
|
||||
@@ -617,9 +614,7 @@ using CDataType = {get_dtype_string(c_type)};"""
|
||||
if self.kernel_name_prefix == "mx_gemm":
|
||||
instance_code += """
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
using MxGemmHostArgs = ck_tile::MXGemmKernelArgs<ScaleM, ScaleN, 1, 1, 0>;"""
|
||||
using MxGemmHostArgs = ck_tile::MxGemmHostArgs<1, 1, 0>;"""
|
||||
|
||||
if self.kernel_name_prefix == "gemm_multi_d":
|
||||
instance_code += f"""
|
||||
@@ -684,7 +679,7 @@ struct SelectedKernel {{
|
||||
static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"};
|
||||
static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"};
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2"] else "false"};"""
|
||||
static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2", "comp_async"] else "false"};"""
|
||||
|
||||
if self.kernel_name_prefix in [
|
||||
"gemm_universal",
|
||||
@@ -1069,17 +1064,17 @@ struct SelectedKernel {{
|
||||
instance_code += f"""
|
||||
|
||||
// Kernel type
|
||||
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::MxGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Kernel arguments
|
||||
auto kargs = args;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
if(!Kernel::Underlying::IsSupportedArgument(kargs)) {{
|
||||
if(!Kernel::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping mx gemm!");
|
||||
}}
|
||||
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
@@ -1245,7 +1240,15 @@ struct SelectedKernel {{
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
TransposeC>;
|
||||
TransposeC,
|
||||
1, // NumWaveGroups
|
||||
false, // FixedVectorSize_
|
||||
1, // VectorSizeC_
|
||||
1, // BlockedXDLNPerWarp
|
||||
false, // DoubleSmemBuffer_
|
||||
ADataType, // AComputeDataType
|
||||
BDataType, // BComputeDataType
|
||||
true>; // TilesPacked_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
|
||||
return instance_code
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm_mx.hpp"
|
||||
#include "gemm/gemm_profiler.hpp"
|
||||
#include "mx_gemm_benchmark.hpp"
|
||||
|
||||
@@ -49,15 +48,13 @@ class MXGemmProfiler : public GemmProfiler<MXGemmProfiler, GemmProblem, MxGemmHo
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
|
||||
|
||||
const ck_tile::index_t scale_k_size = gemm_problem.k_ / 32;
|
||||
const ck_tile::index_t stride_scale_a =
|
||||
ck_tile::get_default_stride(gemm_problem.m_, scale_k_size, 0, is_row_major(layout_a));
|
||||
const ck_tile::index_t stride_scale_b =
|
||||
ck_tile::get_default_stride(scale_k_size, gemm_problem.n_, 0, is_row_major(layout_b));
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_a_host(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, scale_k_size, stride_scale_a, is_row_major(layout_a)));
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(ck_tile::host_tensor_descriptor(
|
||||
scale_k_size, gemm_problem.n_, stride_scale_b, is_row_major(layout_b)));
|
||||
ck_tile::HostTensor<ScaleType> scale_a_host(
|
||||
{static_cast<std::size_t>(gemm_problem.m_), static_cast<std::size_t>(scale_k_size)},
|
||||
{static_cast<std::size_t>(scale_k_size), static_cast<std::size_t>(1)});
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(
|
||||
{static_cast<std::size_t>(gemm_problem.n_), static_cast<std::size_t>(scale_k_size)},
|
||||
{static_cast<std::size_t>(scale_k_size), static_cast<std::size_t>(1)});
|
||||
|
||||
if(setting_.init_method == 0)
|
||||
{
|
||||
@@ -109,31 +106,47 @@ class MXGemmProfiler : public GemmProfiler<MXGemmProfiler, GemmProblem, MxGemmHo
|
||||
constexpr ck_tile::index_t xdl_mn_thread = SelectedKernel::WarpTileM;
|
||||
constexpr ck_tile::index_t xdl_k_thread = 64 / xdl_mn_thread;
|
||||
|
||||
auto scale_a_packed =
|
||||
pack_mx_scales_mn_x_k<m_xdl_pack, k_xdl_pack, xdl_mn_thread, xdl_k_thread>(scale_a_host,
|
||||
true);
|
||||
auto scale_b_packed =
|
||||
pack_mx_scales_mn_x_k<n_xdl_pack, k_xdl_pack, xdl_mn_thread, xdl_k_thread>(scale_b_host,
|
||||
false);
|
||||
ck_tile::HostTensor<ScaleType> scale_a_shuffled(
|
||||
{static_cast<std::size_t>(gemm_problem.m_ / m_xdl_pack * 2),
|
||||
static_cast<std::size_t>(scale_k_size / k_xdl_pack * 2)},
|
||||
{static_cast<std::size_t>(scale_k_size / k_xdl_pack * 2), static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b_shuffled(
|
||||
{static_cast<std::size_t>(gemm_problem.n_ / n_xdl_pack * 2),
|
||||
static_cast<std::size_t>(scale_k_size / k_xdl_pack * 2)},
|
||||
{static_cast<std::size_t>(scale_k_size / k_xdl_pack * 2), static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::preShuffleScaleBuffer_gfx950<m_xdl_pack, k_xdl_pack, xdl_mn_thread, xdl_k_thread>(
|
||||
scale_a_host.mData.data(),
|
||||
scale_a_shuffled.mData.data(),
|
||||
gemm_problem.m_,
|
||||
scale_k_size,
|
||||
true);
|
||||
|
||||
ck_tile::preShuffleScaleBuffer_gfx950<n_xdl_pack, k_xdl_pack, xdl_mn_thread, xdl_k_thread>(
|
||||
scale_b_host.mData.data(),
|
||||
scale_b_shuffled.mData.data(),
|
||||
gemm_problem.n_,
|
||||
scale_k_size,
|
||||
true);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_packed.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_packed.data());
|
||||
|
||||
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
|
||||
|
||||
MxGemmHostArgs gemm_args({a_m_k_dev_buf.GetDeviceBuffer()},
|
||||
{scale_a_dev_buf.GetDeviceBuffer()},
|
||||
{b_k_n_dev_buf.GetDeviceBuffer()},
|
||||
{scale_b_dev_buf.GetDeviceBuffer()},
|
||||
{},
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
gemm_problem.split_k_,
|
||||
@@ -143,17 +156,26 @@ class MXGemmProfiler : public GemmProfiler<MXGemmProfiler, GemmProblem, MxGemmHo
|
||||
{gemm_problem.stride_a_},
|
||||
{gemm_problem.stride_b_},
|
||||
{},
|
||||
gemm_problem.stride_c_,
|
||||
scale_m,
|
||||
scale_n);
|
||||
gemm_problem.stride_c_);
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_result(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
|
||||
|
||||
if(setting_.verify)
|
||||
{
|
||||
// Host reference computation using reference_mx_gemm
|
||||
// reference_mx_gemm expects scale_a(M, K/ScaleBlockSize) and scale_b(K/ScaleBlockSize,
|
||||
// N) We need to create scale_b in (K/ScaleBlockSize, N) format for the reference
|
||||
ck_tile::HostTensor<ScaleType> scale_b_ref(
|
||||
{static_cast<std::size_t>(scale_k_size), static_cast<std::size_t>(gemm_problem.n_)},
|
||||
{static_cast<std::size_t>(1), static_cast<std::size_t>(scale_k_size)});
|
||||
// Copy scale_b data (our scale_b is (N, scale_k_size) row-major,
|
||||
// reference expects (scale_k_size, N) col-major, which is the same memory layout)
|
||||
std::copy(
|
||||
scale_b_host.mData.begin(), scale_b_host.mData.end(), scale_b_ref.mData.begin());
|
||||
|
||||
mx_gemm_host_reference(
|
||||
setting_.verify, a_m_k, b_k_n, c_m_n_host_result, scale_a_host, scale_b_host);
|
||||
setting_.verify, a_m_k, b_k_n, c_m_n_host_result, scale_a_host, scale_b_ref);
|
||||
}
|
||||
|
||||
for(auto& callable : callables)
|
||||
|
||||
Reference in New Issue
Block a user