mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
[rocm-libraries] ROCm/rocm-libraries#8554 (commit be9af54)
refactor(ck): mx gemm kernel unification ## Motivation CK tile currently has two separate MX GEMM kernels for gfx950 and gfx1250. This pull request refactors and modernizes the MX GEMM kernel and example to use new scale tensor handling, improved kernel argument structures, and updated pipeline and kernel APIs. The changes simplify the interface and improve type safety. JIRA ID ROCM-26313 ## Technical Details - Add support for gfx950 in MX GEMM kernel for gfx1250 and remove unused kernel - Unify comp async pipeline for GEMM and MX GEMM - Unify eight waves pipeline for GEMM and MX GEMM - Move preshuffle MX GEMM pipeline to gemm ops and remove gemm_mx ops - Unify testing framework for MX GEMM - Add gfx950 tests for grouped MX GEMM ## Test Plan - `test_mx_gemm_async.cpp` for MX GEMM on gfx950 - `test_mx_grouped_gemm_comp_async.cpp` for grouped MX GEMM on gfx950 ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
604c56bc0e
commit
d559ec00a8
@@ -3,6 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
|
||||
@@ -72,6 +74,7 @@ struct MxGemmKernel
|
||||
using BaseKernel::PersistentKernel;
|
||||
using typename BaseKernel::AsLayout;
|
||||
using typename BaseKernel::BsLayout;
|
||||
using typename BaseKernel::CLayout;
|
||||
using typename BaseKernel::DsLayout;
|
||||
|
||||
using typename BaseKernel::ADataType;
|
||||
@@ -91,6 +94,8 @@ struct MxGemmKernel
|
||||
using BaseKernel::APackedSize;
|
||||
using BaseKernel::BPackedSize;
|
||||
|
||||
using BaseKernel::I1;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename MxGemmPipeline::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename MxGemmPipeline::BElementWise>;
|
||||
|
||||
@@ -100,12 +105,48 @@ struct MxGemmKernel
|
||||
static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr int BlockScaleSize = MxGemmPipeline::ScaleBlockSize;
|
||||
static_assert(BlockScaleSize == 16 || BlockScaleSize == 32, "unsupported BlockScaleSize");
|
||||
// Scale tensor element type is always int32_t (4 packed e8m0 bytes).
|
||||
// For scale16, each thread needs 8 bytes = 2 int32_t elements.
|
||||
// For scale32, each thread needs 4 bytes = 1 int32_t element.
|
||||
static constexpr int ScalePackSize = 4;
|
||||
using ScalePtrType = const int32_t*;
|
||||
using ScalePtrType = const int32_t*;
|
||||
// Padding flags pulled from pipeline so the kernel can pad the (unscaled) C and scale views
|
||||
// consistently with the A/B views that the pipeline already pads via
|
||||
// Underlying::MakeA/BBlockWindows.
|
||||
static constexpr bool kPadM = MxGemmPipeline::kPadM;
|
||||
static constexpr bool kPadN = MxGemmPipeline::kPadN;
|
||||
static constexpr bool kPadK = MxGemmPipeline::kPadK;
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Compile-time padding-support invariants for the MX comp-async pipeline.
|
||||
//
|
||||
// - K padding is NOT supported: async_load_tile issues vector buffer reads whose
|
||||
// OOB check is per-vector-start, so a vector that straddles the K pad boundary
|
||||
// pulls in data from the adjacent row / next K tile rather than zero. The packed
|
||||
// scale tile has the same vector-load property. Until the async path learns how
|
||||
// to do per-element pad masking, we forbid kPadK at compile time.
|
||||
//
|
||||
// - kPadM / kPadN are supported only when the GEMM has at least one full block
|
||||
// along that dimension; the CShuffleEpilogue's LDS shuffle uses thread positions
|
||||
// that do not all participate when the entire dimension is smaller than a tile
|
||||
// (resulting in zeros being written into in-range output rows). The "entire
|
||||
// dimension < tile" case is rejected at runtime in IsSupportedArgument; we
|
||||
// cannot statically catch it because M and N are runtime values.
|
||||
// ------------------------------------------------------------------
|
||||
static_assert(!kPadK,
|
||||
"MX GEMM (comp-async pipeline): K padding (kPadK = true) is not supported. "
|
||||
"The async vector loads do not mask elements that straddle the K pad "
|
||||
"boundary, so partial K tiles produce silently wrong results. Choose K so "
|
||||
"that K is a multiple of KPerBlock * k_batch.");
|
||||
|
||||
// Single source of truth for the split-K atomic-add precondition, shared by the runtime
|
||||
// check in IsSupportedArgument and the atomic_add dispatch in operator(). Split-K
|
||||
// accumulates each k_id's partial C tile with atomic_add; the CShuffle epilogue can only
|
||||
// emit atomic_add for fp16/bf16 outputs when the C vector size is even. For an odd vector
|
||||
// size that combination is not instantiated, so such a config cannot run split-K. For all
|
||||
// shipped tile shapes GetVectorSizeC() is even, so this is defensive rather than reachable.
|
||||
static constexpr bool kSplitKAtomicAddSupported =
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 == 0 || !is_any_of<EDataType, fp16_t, bf16_t>::value;
|
||||
|
||||
static constexpr index_t MXdlPackEff = MxGemmPipeline::MXdlPackEff;
|
||||
static constexpr index_t NXdlPackEff = MxGemmPipeline::NXdlPackEff;
|
||||
static constexpr index_t KXdlPackEff = MxGemmPipeline::KXdlPackEff;
|
||||
|
||||
using KernelArgs = MxGemmKernelArgs<NumATensor, NumBTensor, NumDTensor>;
|
||||
|
||||
@@ -131,14 +172,57 @@ struct MxGemmKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
const bool log = ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING));
|
||||
|
||||
if(kargs.k_batch < 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("SplitK (k_batch > 1) is not supported for MX GEMM!");
|
||||
}
|
||||
if(log)
|
||||
CK_TILE_ERROR("MX GEMM: k_batch must be >= 1.");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Split-K derives this k_id's logical K start from the row-major SplitKBatchOffset
|
||||
// (as_k_split_offset[0]) to offset the packed-scale / flat-B windows; for column-major A
|
||||
// that field is stride-scaled, so split-K with non-row-major A is not yet supported.
|
||||
// (k_batch == 1 is unaffected -- the offset is 0 and unused.) When col-major A lands for
|
||||
// non-preshuffle, extend the split-K K-offset here instead of this reject.
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
if constexpr(!std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
if(log)
|
||||
CK_TILE_ERROR("MX GEMM: split-K (k_batch > 1) currently requires row-major A.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Scales are granular in K: each packed int32_t covers BlockScaleSize * KXdlPackEff
|
||||
// consecutive K elements. Every split-K boundary must land on that granularity so that
|
||||
// each split can compute a packed-scale K offset. K1 is the WarpTile K, which is a
|
||||
// multiple of that granularity for all shipped configs, but be defensive.
|
||||
constexpr index_t scale_granularity_k = BlockScaleSize * KXdlPackEff;
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
// splitk_batch_offset allocates K in units of K1 (warp-tile K). If K1 itself is
|
||||
// not a multiple of the scale granularity, split-K is not safe.
|
||||
constexpr index_t K1 = BlockGemmShape::WarpTile::at(number<2>{});
|
||||
static_assert(K1 % scale_granularity_k == 0,
|
||||
"MX GEMM: WarpTile K must be a multiple of BlockScaleSize * KXdlPack "
|
||||
"to support split-K.");
|
||||
// Defensive runtime check: K must split evenly along K1 boundaries so that each
|
||||
// k_id consumes a whole number of warp-tile K chunks (and therefore a whole
|
||||
// number of packed-scale K elements).
|
||||
if(kargs.K % (K1 * kargs.k_batch) != 0)
|
||||
{
|
||||
if(log)
|
||||
CK_TILE_ERROR("MX GEMM: with k_batch > 1, K must be a multiple of WarpTile_K * "
|
||||
"k_batch so that every split lands on a packed-scale boundary.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Delegate the remaining shape/vector-size checks to the universal kernel.
|
||||
return BaseKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
@@ -146,10 +230,14 @@ struct MxGemmKernel
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeScaleABlockWindow(const std::array<ScalePtrType, NumATensor>& as_scale_ptr,
|
||||
const KernelArgs& kargs,
|
||||
index_t block_idx_m)
|
||||
index_t block_idx_m,
|
||||
const index_t k_elem_offset = 0)
|
||||
{
|
||||
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, MThreadPerXdl);
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / ScalePackSize;
|
||||
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, MThreadPerXdl * MXdlPackEff);
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / KXdlPackEff;
|
||||
|
||||
// For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice.
|
||||
const index_t k_scale_offset = k_elem_offset / BlockScaleSize / KXdlPackEff;
|
||||
|
||||
// Scale16: descriptor order [packs_m, MThreadPerXdl, packs_k] -- K contiguous per M-row,
|
||||
// no pre-shuffle needed (natural row-major layout matches).
|
||||
@@ -184,14 +272,28 @@ struct MxGemmKernel
|
||||
return make_tensor_view<address_space_enum::global>(as_scale_ptr[i], scale_a_desc);
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
// Pad the scale view so partial trailing tiles along M are handled safely (OOB scale
|
||||
// loads return zero; with A also zero on the padded region the contribution is zero
|
||||
// regardless of scale value). kPadK is statically disabled, so K never actually pads.
|
||||
const auto& scale_a_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(
|
||||
scale_a_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock / MXdlPackEff>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPackEff>{}),
|
||||
sequence<kPadM, kPadK>{});
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& scale_a_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(
|
||||
scale_a_tensor_view[i],
|
||||
scale_a_pad_view[i],
|
||||
make_tuple(
|
||||
number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * ScalePackSize)>{}),
|
||||
{block_idx_m, 0});
|
||||
number<TilePartitioner::MPerBlock / MXdlPackEff>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPackEff)>{}),
|
||||
{block_idx_m / MXdlPackEff, k_scale_offset});
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
@@ -202,10 +304,14 @@ struct MxGemmKernel
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeScaleBBlockWindow(const std::array<ScalePtrType, NumBTensor>& bs_scale_ptr,
|
||||
const KernelArgs& kargs,
|
||||
index_t block_idx_n)
|
||||
index_t block_idx_n,
|
||||
const index_t k_elem_offset = 0)
|
||||
{
|
||||
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, NThreadPerXdl);
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / ScalePackSize;
|
||||
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, NThreadPerXdl * NXdlPackEff);
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / KXdlPackEff;
|
||||
|
||||
// For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice.
|
||||
const index_t k_scale_offset = k_elem_offset / BlockScaleSize / KXdlPackEff;
|
||||
|
||||
const auto scale_b_naive_desc = [&]() {
|
||||
if constexpr(BlockScaleSize == 16)
|
||||
@@ -236,33 +342,120 @@ struct MxGemmKernel
|
||||
return make_tensor_view<address_space_enum::global>(bs_scale_ptr[i], scale_b_desc);
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
// Pad the scale view so partial trailing tiles along N are handled safely (OOB scale
|
||||
// loads return zero; with B also zero on the padded region the contribution is zero
|
||||
// regardless of scale value). kPadK is statically disabled, so K never actually pads.
|
||||
const auto& scale_b_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(
|
||||
scale_b_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock / NXdlPackEff>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPackEff>{}),
|
||||
sequence<kPadN, kPadK>{});
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& scale_b_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(
|
||||
scale_b_tensor_view[i],
|
||||
scale_b_pad_view[i],
|
||||
make_tuple(
|
||||
number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * ScalePackSize)>{}),
|
||||
{block_idx_n, 0});
|
||||
number<TilePartitioner::NPerBlock / NXdlPackEff>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPackEff)>{}),
|
||||
{block_idx_n / NXdlPackEff, k_scale_offset});
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
return scale_b_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBFlatBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_n,
|
||||
const index_t k_elem_offset = 0)
|
||||
{
|
||||
static_assert(NumBTensor == 1, "MX GEMM preshuffle currently supports one B tensor");
|
||||
|
||||
constexpr index_t kKPerBlock = MxGemmPipeline::kKPerBlock;
|
||||
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
|
||||
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
|
||||
const index_t kFlatN = kargs.N / kNWarpTile;
|
||||
|
||||
const index_t k_flat_offset = (k_elem_offset / kKPerBlock) * flatKPerBlock;
|
||||
|
||||
auto b_flat_tensor_view = [&]() {
|
||||
static_assert(flatKPerBlock % MxGemmPipeline::GetVectorSizeB() == 0,
|
||||
"wrong! vector size for preshuffled B tensor");
|
||||
auto naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
|
||||
auto desc = transform_tensor_descriptor(
|
||||
naive_desc,
|
||||
make_tuple(make_pass_through_transform(kFlatN),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(bs_ptr[number<0>{}], desc);
|
||||
}();
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto) {
|
||||
return make_tile_window(b_flat_tensor_view,
|
||||
make_tuple(number<MxGemmPipeline::flatNPerWarp>{},
|
||||
number<MxGemmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)),
|
||||
static_cast<int>(k_flat_offset)});
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp>
|
||||
CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr,
|
||||
const KernelArgs& kargs,
|
||||
KernelArgs kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
const index_t block_idx_n,
|
||||
const index_t k_elem_offset = 0)
|
||||
{
|
||||
std::array<ScalePtrType, NumATensor> as_scale_ptr;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
as_scale_ptr[i] = reinterpret_cast<ScalePtrType>(kargs.as_scale_ptr[i]);
|
||||
});
|
||||
std::array<const ADataType*, NumATensor> as_ptr_;
|
||||
index_t block_idx_m_;
|
||||
// Large tensor support (when M is large, N and K are relatively small)
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
constexpr bool offset_ptrs_by_tile_coords =
|
||||
std::is_same_v<tensor_layout::gemm::RowMajor, ALayout> &&
|
||||
std::is_same_v<tensor_layout::gemm::RowMajor, CLayout> && !BaseKernel::ClusterLaunch;
|
||||
|
||||
if constexpr(offset_ptrs_by_tile_coords)
|
||||
{
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
as_ptr_[i] = as_ptr[i] + static_cast<std::ptrdiff_t>(block_idx_m) *
|
||||
kargs.stride_As[i] / APackedSize;
|
||||
});
|
||||
e_ptr += static_cast<std::ptrdiff_t>(block_idx_m) * kargs.stride_E;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
as_scale_ptr[i] = reinterpret_cast<ScalePtrType>(kargs.as_scale_ptr[i]) +
|
||||
static_cast<std::ptrdiff_t>(block_idx_m / MXdlPackEff) *
|
||||
(kargs.K / BlockScaleSize / KXdlPackEff);
|
||||
});
|
||||
|
||||
kargs.M = std::min(kargs.M - block_idx_m, TilePartitioner::MPerBlock);
|
||||
block_idx_m_ = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
as_scale_ptr[i] = reinterpret_cast<ScalePtrType>(kargs.as_scale_ptr[i]);
|
||||
});
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) { as_ptr_[i] = as_ptr[i]; });
|
||||
block_idx_m_ = block_idx_m;
|
||||
}
|
||||
|
||||
std::array<ScalePtrType, NumBTensor> bs_scale_ptr;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
@@ -272,18 +465,50 @@ struct MxGemmKernel
|
||||
// cluster launch pads grid to cluster boundaries; skip out-of-bound blocks
|
||||
if constexpr(BaseKernel::ClusterLaunch)
|
||||
{
|
||||
if(block_idx_m >= kargs.M || block_idx_n >= kargs.N)
|
||||
if(block_idx_m_ >= kargs.M || block_idx_n >= kargs.N)
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& as_block_window = BaseKernel::MakeABlockWindows(
|
||||
as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& bs_block_window = BaseKernel::MakeBBlockWindows(
|
||||
bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
// The preshuffle A async-load (MakeMX_AAsyncLoadBytesDramWindow) rebuilds the A
|
||||
// view with a packed descriptor, i.e. it assumes the leading (M) stride equals
|
||||
// the view's K extent. That only holds when the extent equals stride_A, which is
|
||||
// the case for k_batch == 1 (splitted_k == K) but NOT for split-K (splitted_k < K):
|
||||
// a packed extent of splitted_k would stride M by splitted_k instead of stride_A
|
||||
// and read the wrong rows (only row 0 lands correctly). Use the full K extent so
|
||||
// the packed M stride matches stride_A. The as_ptr K-offset already selects this
|
||||
// k_id's slice and num_loop bounds the blocks read, so reads stay within
|
||||
// [as_k_split_offset, as_k_split_offset + splitted_k) <= K (in-allocation).
|
||||
const auto& as_block_window = [&]() {
|
||||
if constexpr(MxGemmPipeline::Preshuffle)
|
||||
{
|
||||
return BaseKernel::MakeABlockWindows(as_ptr_, kargs, kargs.K, block_idx_m_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return BaseKernel::MakeABlockWindows(
|
||||
as_ptr_, kargs, splitk_batch_offset.splitted_k, block_idx_m_);
|
||||
}
|
||||
}();
|
||||
const auto& bs_block_window = [&]() {
|
||||
if constexpr(MxGemmPipeline::Preshuffle)
|
||||
{
|
||||
return MakeBFlatBlockWindows(bs_ptr, kargs, block_idx_n, k_elem_offset);
|
||||
}
|
||||
else
|
||||
{
|
||||
return BaseKernel::MakeBBlockWindows(
|
||||
bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
}
|
||||
}();
|
||||
const auto& ds_block_window =
|
||||
BaseKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& scale_a_block_window = MakeScaleABlockWindow(as_scale_ptr, kargs, block_idx_m);
|
||||
const auto& scale_b_block_window = MakeScaleBBlockWindow(bs_scale_ptr, kargs, block_idx_n);
|
||||
BaseKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m_, block_idx_n);
|
||||
|
||||
// Create scale block windows. For split-K (k_batch > 1), k_elem_offset advances the
|
||||
// scale origin into the correct packed-K slice for this k_id; otherwise it is zero.
|
||||
const auto& scale_a_block_window =
|
||||
MakeScaleABlockWindow(as_scale_ptr, kargs, block_idx_m_, k_elem_offset);
|
||||
const auto& scale_b_block_window =
|
||||
MakeScaleBBlockWindow(bs_scale_ptr, kargs, block_idx_n, k_elem_offset);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
@@ -297,10 +522,58 @@ struct MxGemmKernel
|
||||
num_loop,
|
||||
smem_ptr);
|
||||
|
||||
auto c_block_window = BaseKernel::template MakeCBlockWindows<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
// Dispatch epilogue: when k_batch > 1 each split accumulates a partial result into
|
||||
// the same C tile, so we need atomic add (universal_gemm_kernel pattern). The
|
||||
// fp16/bf16 even-vector-size precondition is captured once in kSplitKAtomicAddSupported
|
||||
// and also rejected up front in IsSupportedArgument.
|
||||
// if(k_batch == 1)
|
||||
auto c_block_window = BaseKernel::template MakeCBlockWindows<DstInMemOp>(
|
||||
e_ptr, kargs, block_idx_m_, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
RunGemm<memory_operation_enum::set>(as_ptr,
|
||||
bs_ptr,
|
||||
ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
block_idx_m,
|
||||
block_idx_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
// This k_id's logical K-element start. For row-major A, as_k_split_offset[0] is exactly
|
||||
// that offset, so reuse it rather than recomputing the split formula; the packed-scale
|
||||
// and flat-B K offsets are derived from it. Split-K with non-row-major A is rejected in
|
||||
// IsSupportedArgument; for k_batch == 1 this value is 0 and unused for any layout.
|
||||
const index_t k_elem_offset =
|
||||
amd_wave_read_first_lane(splitk_batch_offset.as_k_split_offset[number<0>{}]);
|
||||
RunGemm<memory_operation_enum::atomic_add>(as_ptr,
|
||||
bs_ptr,
|
||||
ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
block_idx_m,
|
||||
block_idx_n,
|
||||
k_elem_offset);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -116,7 +116,7 @@ struct MxGroupedGemmKernel
|
||||
using P_ = GemmPipeline;
|
||||
return concat('_', "mx_gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
||||
(UsePersistentKernel ? "Persistent" : "NonPersistent"),
|
||||
(NumDTensor_ == 2 ? "MultiD" : "NoMultiD"),
|
||||
|
||||
@@ -1280,8 +1280,18 @@ struct UniversalGemmKernel
|
||||
|
||||
std::array<const BDataType*, NumBTensor> bs_ptr;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
|
||||
splitk_batch_offset.bs_k_split_offset[i] / BPackedSize;
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
// The preshuffle (flat-B) path applies the per-split K offset to the flat
|
||||
// window origin in when creating the window; bs_k_split_offset is derived from
|
||||
// the logical B stride and would mis-offset the flat buffer.
|
||||
bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
|
||||
splitk_batch_offset.bs_k_split_offset[i] / BPackedSize;
|
||||
}
|
||||
});
|
||||
|
||||
// Calculate output offset from tile partitioner and apply to output pointer
|
||||
|
||||
Reference in New Issue
Block a user