mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
[rocm-libraries] ROCm/rocm-libraries#8554 (commit be9af54)
refactor(ck): mx gemm kernel unification ## Motivation CK tile currently has two separate MX GEMM kernels for gfx950 and gfx1250. This pull request refactors and modernizes the MX GEMM kernel and example to use new scale tensor handling, improved kernel argument structures, and updated pipeline and kernel APIs. The changes simplify the interface and improve type safety. JIRA ID ROCM-26313 ## Technical Details - Add support for gfx950 in MX GEMM kernel for gfx1250 and remove unused kernel - Unify comp async pipeline for GEMM and MX GEMM - Unify eight waves pipeline for GEMM and MX GEMM - Move preshuffle MX GEMM pipeline to gemm ops and remove gemm_mx ops - Unify testing framework for MX GEMM - Add gfx950 tests for grouped MX GEMM ## Test Plan - `test_mx_gemm_async.cpp` for MX GEMM on gfx950 - `test_mx_grouped_gemm_comp_async.cpp` for grouped MX GEMM on gfx950 ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
604c56bc0e
commit
d559ec00a8
@@ -24,13 +24,13 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AScaleDataType,
|
||||
typename BScaleDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool UsePersistentKernel = false>
|
||||
float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::DeviceMem& b_dev_buf,
|
||||
@@ -42,36 +42,38 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
ScaleM scale_m,
|
||||
ScaleN scale_n,
|
||||
ck_tile::DeviceMem& scale_m,
|
||||
ck_tile::DeviceMem& scale_n,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
MXGemmHostArgs<ScaleM, ScaleN> args(a_dev_buf.GetDeviceBuffer(),
|
||||
b_dev_buf.GetDeviceBuffer(),
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
scale_m,
|
||||
scale_n);
|
||||
ck_tile::MxGemmHostArgs<1, 1, 0> args({static_cast<const void*>(a_dev_buf.GetDeviceBuffer())},
|
||||
{static_cast<const void*>(scale_m.GetDeviceBuffer())},
|
||||
{static_cast<const void*>(b_dev_buf.GetDeviceBuffer())},
|
||||
{static_cast<const void*>(scale_n.GetDeviceBuffer())},
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
{stride_A},
|
||||
{stride_B},
|
||||
{},
|
||||
stride_C);
|
||||
|
||||
// Simplified invocation - comp_async handles hot loop and tail internally
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_gemm_calc<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AScaleDataType,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
UsePersistentKernel,
|
||||
split_k_.value>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
@@ -9,49 +9,8 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_mx.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
|
||||
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
|
||||
{
|
||||
using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>;
|
||||
|
||||
MXGemmHostArgs(const void* a_ptr,
|
||||
const void* b_ptr,
|
||||
void* c_ptr_,
|
||||
ck_tile::index_t k_batch_,
|
||||
ck_tile::index_t M_,
|
||||
ck_tile::index_t N_,
|
||||
ck_tile::index_t K_,
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_,
|
||||
ScaleM scale_m_,
|
||||
ScaleN scale_n_)
|
||||
: Base({a_ptr},
|
||||
{b_ptr},
|
||||
{},
|
||||
c_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
{stride_A_},
|
||||
{stride_B_},
|
||||
{},
|
||||
stride_C_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
}
|
||||
|
||||
ScaleM scale_m;
|
||||
ScaleN scale_n;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
|
||||
struct MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
@@ -78,12 +37,15 @@ struct MxGemmConfig
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false; // comp_async uses double buffer
|
||||
static constexpr bool DoubleSmemBuffer = true; // comp_async uses double buffer
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr ck_tile::index_t BContiguousItemsPerAccess = 16;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
|
||||
using AScaleDataType = ck_tile::e8m0_t;
|
||||
using BScaleDataType = ck_tile::e8m0_t;
|
||||
};
|
||||
|
||||
struct MX_GemmConfigEightWaves : MxGemmConfig
|
||||
|
||||
@@ -5,10 +5,9 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mx_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
using is_row_major_t = ck_tile::bool_constant<
|
||||
@@ -17,16 +16,16 @@ using is_row_major_t = ck_tile::bool_constant<
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AScaleDataType,
|
||||
typename BScaleDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
bool Splitk>
|
||||
float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::stream_config& s)
|
||||
float mx_gemm_calc(const ck_tile::MxGemmHostArgs<1, 1, 0>& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -51,29 +50,38 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_gemm requires ADataType is a wider type than BDataType");
|
||||
|
||||
using MXPipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
using AComputeDataType = ADataType;
|
||||
using BComputeDataType = BDataType;
|
||||
|
||||
using MXPipelineProblem = ck_tile::MxGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AScaleDataType,
|
||||
BScaleDataType>;
|
||||
|
||||
// Use the MX GEMM Preshuffle pipeline or
|
||||
// the new MX comp_async pipeline with MX scaling support
|
||||
constexpr bool IsEightWave =
|
||||
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8;
|
||||
using MXGemmPipeline = std::conditional_t<
|
||||
GemmConfig::Preshuffle,
|
||||
ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1<MXPipelineProblem>,
|
||||
std::conditional_t<IsEightWave,
|
||||
ck_tile::MXGemmPipelineAgBgCrCompAsyncEightWaves<MXPipelineProblem>,
|
||||
ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>>>;
|
||||
ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves<MXPipelineProblem>,
|
||||
ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
constexpr ck_tile::index_t BlockedXDLNPerWarp = GemmConfig::Preshuffle ? 2 : 1;
|
||||
|
||||
using GemmEpilogue =
|
||||
std::conditional_t<GemmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
@@ -115,29 +123,16 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize_ (Default)
|
||||
1, // VectorSizeC_ (Default)
|
||||
ck_tile::MXEpilogueTraits<GemmConfig>::BlockedXDLNPerWarp,
|
||||
BlockedXDLNPerWarp,
|
||||
false, // DoubleSmemBuffer_ (Default)
|
||||
ComputeDataType, // AComputeDataType
|
||||
ComputeDataType, // BComputeDataType
|
||||
!GemmConfig::Preshuffle>>>; // TilesPacked_ (because of
|
||||
// packed scales)
|
||||
|
||||
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::MxGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(std::array<const void*, 1>{args.as_ptr},
|
||||
std::array<const void*, 1>{args.bs_ptr},
|
||||
std::array<const void*, 0>{},
|
||||
args.e_ptr,
|
||||
args.k_batch,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
std::array<ck_tile::index_t, 1>{args.stride_As},
|
||||
std::array<ck_tile::index_t, 1>{args.stride_Bs},
|
||||
std::array<ck_tile::index_t, 0>{},
|
||||
args.stride_E,
|
||||
args.scale_m,
|
||||
args.scale_n);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -145,8 +140,10 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
|
||||
"MX GEMM: unsupported shape/configuration (set CK_TILE_LOGGING=1 for details).");
|
||||
}
|
||||
|
||||
const auto kernel = ck_tile::make_kernel<Kernel::kBlockPerCu>(
|
||||
Kernel{}, Kernel::GridSize(kargs), Kernel::BlockSize(), 0, kargs);
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
const auto kernel = ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{}, Kernel::GridSize(args.M, args.N, args.k_batch), Kernel::BlockSize(), 0, kargs);
|
||||
|
||||
// For split-K (k_batch > 1) the kernel's epilogue uses atomic_add into C, so C must be
|
||||
// zeroed before every kernel launch -- not just once before the first warmup iteration.
|
||||
|
||||
@@ -16,6 +16,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_v
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AScaleDataType,
|
||||
typename BScaleDataType,
|
||||
typename AccDataType,
|
||||
typename GemmConfig,
|
||||
bool UsePersistentKernel,
|
||||
@@ -61,53 +63,48 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
|
||||
// Scale tensors - follow parent matrix layouts for optimal memory access
|
||||
// A scales: [M, K/32] with A's layout
|
||||
// B scales: [K/32, N] with B's layout
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
ck_tile::index_t scale_k_size = K / 32;
|
||||
|
||||
// Follow A/BLayout to get the layouts for the scale tensors
|
||||
ck_tile::index_t stride_scale_a =
|
||||
ck_tile::get_default_stride(M, scale_k_size, 0, is_row_major(ALayout{}));
|
||||
ck_tile::index_t stride_scale_b =
|
||||
ck_tile::get_default_stride(scale_k_size, N, 0, is_row_major(BLayout{}));
|
||||
ck_tile::HostTensor<AScaleDataType> scale_a_host(
|
||||
{static_cast<std::size_t>(M), static_cast<std::size_t>(scale_k_size)},
|
||||
{static_cast<std::size_t>(scale_k_size), static_cast<std::size_t>(1)});
|
||||
|
||||
// scale_b uses N as first dimension (col-major like B)
|
||||
ck_tile::HostTensor<BScaleDataType> scale_b_host(
|
||||
{static_cast<std::size_t>(N), static_cast<std::size_t>(scale_k_size)},
|
||||
{static_cast<std::size_t>(scale_k_size), static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_a_host(
|
||||
ck_tile::host_tensor_descriptor(M, scale_k_size, stride_scale_a, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(
|
||||
ck_tile::host_tensor_descriptor(scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
|
||||
// e8m0_t is basically an exponent of float32
|
||||
ck_tile::HostTensor<float> pow2(scales.get_lengths());
|
||||
ck_tile::FillUniformDistributionIntegerValue<float>{range_min, range_max, fill_seed(gen)}(
|
||||
pow2);
|
||||
scales.ForEach([&](auto& self, const auto& i) {
|
||||
self(i) = static_cast<ScaleType>(std::exp2(pow2(i)));
|
||||
});
|
||||
};
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// Initialize A, B, and scales to random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, fill_seed(gen)}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, fill_seed(gen)}(b_host);
|
||||
gen_scales(scale_a_host, -2, 2);
|
||||
gen_scales(scale_b_host, -2, 2);
|
||||
ck_tile::FillUniformScaleDistribution<AScaleDataType>{0.125f, 2.0f, fill_seed(gen)}(
|
||||
scale_a_host);
|
||||
ck_tile::FillUniformScaleDistribution<BScaleDataType>{0.125f, 2.0f, fill_seed(gen)}(
|
||||
scale_b_host);
|
||||
break;
|
||||
case 1:
|
||||
// Initialize A, B, and scales to 1.0
|
||||
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
|
||||
ck_tile::FillConstant<BDataType>{BDataType(1.f)}(b_host);
|
||||
gen_scales(scale_a_host, 0, 0);
|
||||
gen_scales(scale_b_host, 0, 0);
|
||||
ck_tile::FillUniformScaleDistribution<AScaleDataType>{1.0f, 1.0f, fill_seed(gen)}(
|
||||
scale_a_host);
|
||||
ck_tile::FillUniformScaleDistribution<BScaleDataType>{1.0f, 1.0f, fill_seed(gen)}(
|
||||
scale_b_host);
|
||||
break;
|
||||
case 2:
|
||||
// Initialize A and B with random values but with constant 1.0 scales
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, fill_seed(gen)}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, fill_seed(gen)}(b_host);
|
||||
gen_scales(scale_a_host, 0, 0);
|
||||
gen_scales(scale_b_host, 0, 0);
|
||||
ck_tile::FillUniformScaleDistribution<AScaleDataType>{1.0f, 1.0f, fill_seed(gen)}(
|
||||
scale_a_host);
|
||||
ck_tile::FillUniformScaleDistribution<BScaleDataType>{1.0f, 1.0f, fill_seed(gen)}(
|
||||
scale_b_host);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -128,12 +125,31 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
|
||||
constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
auto scale_a_packed =
|
||||
ck_tile::packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host,
|
||||
true);
|
||||
auto scale_b_packed =
|
||||
ck_tile::packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host,
|
||||
false);
|
||||
ck_tile::HostTensor<AScaleDataType> scale_a_shuffled(
|
||||
{static_cast<std::size_t>(M / MXdlPackEff * 2),
|
||||
static_cast<std::size_t>(scale_k_size / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(scale_k_size / KXdlPackEff * 2), static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::HostTensor<BScaleDataType> scale_b_shuffled(
|
||||
{static_cast<std::size_t>(N / NXdlPackEff * 2),
|
||||
static_cast<std::size_t>(scale_k_size / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(scale_k_size / KXdlPackEff * 2), static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::preShuffleScaleBuffer_gfx950<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_a_host.mData.data(), scale_a_shuffled.mData.data(), M, scale_k_size, true);
|
||||
|
||||
if constexpr(GemmConfig::Preshuffle && GemmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
ck_tile::preShuffleScaleBufferPermuteN_gfx950<GemmConfig::N_Warp,
|
||||
GemmConfig::N_Tile,
|
||||
XdlMNThread>(
|
||||
scale_b_host.mData.data(), scale_b_shuffled.mData.data(), N, scale_k_size, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::preShuffleScaleBuffer_gfx950<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_b_host.mData.data(), scale_b_shuffled.mData.data(), N, scale_k_size, true);
|
||||
}
|
||||
|
||||
const auto b_host_for_device = [&]() {
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
@@ -145,73 +161,29 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
|
||||
return b_host;
|
||||
}();
|
||||
|
||||
const auto scale_a_host_for_device = [&]() {
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
return ck_tile::preShuffleScale<GemmConfig::N_Warp_Tile>(scale_a_host, true);
|
||||
else
|
||||
return scale_a_packed;
|
||||
}();
|
||||
|
||||
constexpr ck_tile::index_t XdlNThread = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerBlock = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t NWarp = GemmConfig::N_Warp;
|
||||
|
||||
const auto scale_b_host_for_device = [&]() {
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
{
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN)
|
||||
return ck_tile::preShuffleScalePermuteN<NWarp, NPerBlock, XdlNThread>(scale_b_host,
|
||||
false);
|
||||
else
|
||||
return ck_tile::preShuffleScale<GemmConfig::N_Warp_Tile>(scale_b_host, false);
|
||||
}
|
||||
else
|
||||
return scale_b_packed;
|
||||
}();
|
||||
|
||||
const auto scale_a_device_bytes = [&]() {
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
return scale_a_host_for_device.get_element_space_size_in_bytes();
|
||||
else
|
||||
return scale_a_host_for_device.size() * sizeof(int32_t);
|
||||
}();
|
||||
|
||||
const auto scale_b_device_bytes = [&]() {
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
return scale_b_host_for_device.get_element_space_size_in_bytes();
|
||||
else
|
||||
return scale_b_host_for_device.size() * sizeof(int32_t);
|
||||
}();
|
||||
|
||||
// Device buffers for A, B, C, and packed scale tensors
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_dev_buf(b_host_for_device.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_device_bytes);
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_device_bytes);
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_dev_buf.ToDevice(b_host_for_device.data());
|
||||
c_dev_buf.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_host_for_device.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_host_for_device.data());
|
||||
|
||||
// Scale pointers - point to packed int32_t data, kernel reinterprets as int32_t*
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
|
||||
|
||||
float ave_time = invoke_mx_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AScaleDataType,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
UsePersistentKernel>(a_dev_buf,
|
||||
b_dev_buf,
|
||||
c_dev_buf,
|
||||
@@ -222,8 +194,8 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
scale_m,
|
||||
scale_n,
|
||||
scale_a_dev_buf,
|
||||
scale_b_dev_buf,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
@@ -240,9 +212,23 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::
|
||||
reference_mx_gemm<ADataType, BDataType, ScaleType, ScaleType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host);
|
||||
// Host reference computation using reference_mx_gemm
|
||||
// reference_mx_gemm expects scale_a(M, K/ScaleBlockSize) and scale_b(K/ScaleBlockSize, N)
|
||||
// We need to create scale_b in (K/ScaleBlockSize, N) format for the reference
|
||||
ck_tile::HostTensor<BScaleDataType> scale_b_ref(
|
||||
{static_cast<std::size_t>(scale_k_size), static_cast<std::size_t>(N)},
|
||||
{static_cast<std::size_t>(1), static_cast<std::size_t>(scale_k_size)});
|
||||
// Copy scale_b data (our scale_b is (N, scale_k_size) row-major,
|
||||
// reference expects (scale_k_size, N) col-major, which is the same memory layout)
|
||||
std::copy(scale_b_host.mData.begin(), scale_b_host.mData.end(), scale_b_ref.mData.begin());
|
||||
|
||||
ck_tile::reference_mx_gemm<ADataType,
|
||||
BDataType,
|
||||
AScaleDataType,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
@@ -285,6 +271,8 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
typename GemmConfig::AScaleDataType,
|
||||
typename GemmConfig::BScaleDataType,
|
||||
float,
|
||||
MXfp4_GemmConfig16_Preshuffle,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
@@ -293,6 +281,8 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
typename GemmConfig::AScaleDataType,
|
||||
typename GemmConfig::BScaleDataType,
|
||||
float,
|
||||
GemmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
@@ -304,6 +294,8 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
typename GemmConfig::AScaleDataType,
|
||||
typename GemmConfig::BScaleDataType,
|
||||
float,
|
||||
MXfp8_GemmConfig16_Preshuffle,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
@@ -312,6 +304,8 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
typename GemmConfig::AScaleDataType,
|
||||
typename GemmConfig::BScaleDataType,
|
||||
float,
|
||||
GemmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
|
||||
Reference in New Issue
Block a user