mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)
CK Tile MX GEMM Packing Improvement ## Motivation Reduce the scale loading size and achieve better utilization of MFMA scale selection. ## Technical Details Add packing of MX scales (e8m0_t values packed into int32_t words). ## Test Plan Use the existing test cases. ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
859acb5ae7
commit
5f90f69795
@@ -98,6 +98,30 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
static constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
// XdlPack: desired packing of e8m0_t scale values into int32_t
|
||||
static constexpr index_t MXdlPack = 2;
|
||||
static constexpr index_t NXdlPack = 2;
|
||||
static constexpr index_t KXdlPack = 2;
|
||||
|
||||
// Effective pack sizes: fall back to 1 when dimension is too small
|
||||
using BlockWarps_ = typename BlockGemmShape::BlockWarps;
|
||||
static constexpr index_t MPerBlock_ = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock_ = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock_ = BlockGemmShape::kK;
|
||||
static constexpr index_t MWarp_ = BlockWarps_::at(number<0>{});
|
||||
static constexpr index_t NWarp_ = BlockWarps_::at(number<1>{});
|
||||
static constexpr index_t KPerXdl_ = BlockGemmShape::WarpTile::at(number<2>{});
|
||||
static constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp_ * MThreadPerXdl);
|
||||
static constexpr index_t NIterPerWarp_ = NPerBlock_ / (NWarp_ * NThreadPerXdl);
|
||||
static constexpr index_t KIterPerWarp_ = KPerBlock_ / KPerXdl_;
|
||||
|
||||
static constexpr index_t MXdlPackEff =
|
||||
(MIterPerWarp_ >= MXdlPack && MIterPerWarp_ % MXdlPack == 0) ? MXdlPack : 1;
|
||||
static constexpr index_t NXdlPackEff =
|
||||
(NIterPerWarp_ >= NXdlPack && NIterPerWarp_ % NXdlPack == 0) ? NXdlPack : 1;
|
||||
static constexpr index_t KXdlPackEff =
|
||||
(KIterPerWarp_ >= KXdlPack && KIterPerWarp_ % KXdlPack == 0) ? KXdlPack : 1;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
@@ -245,7 +269,9 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
return c_block_window;
|
||||
}
|
||||
|
||||
// Create scale A block windows following the pattern of MakeABlockWindows
|
||||
// Create scale A block windows with packed int32_t layout
|
||||
// Host packs 2M x 2K e8m0_t values into one int32_t
|
||||
// Tensor view: [M/MXdlPack, K/32/KXdlPack] of int32_t
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE static auto MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs,
|
||||
const index_t i_m)
|
||||
@@ -253,28 +279,28 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = ScaleM::GranularityK;
|
||||
const auto scale_k_size = kargs.K / BlockScaleSize;
|
||||
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPackEff;
|
||||
const auto scale_m_packed = kargs.M / MXdlPackEff;
|
||||
|
||||
// A scale tensor view - layout [M, scale_k_size] with e8m0_t elements
|
||||
// Use e8m0_t directly without packing
|
||||
// A scale tensor view - layout [M/MXdlPackEff, K/32/KXdlPackEff] with int32_t elements
|
||||
const auto scale_a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const e8m0_t*>(scale_a.ptr),
|
||||
make_tuple(kargs.M, scale_k_size),
|
||||
make_tuple(scale_k_size, 1));
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr),
|
||||
make_tuple(scale_m_packed, scale_k_packed),
|
||||
make_tuple(scale_k_packed, 1));
|
||||
|
||||
// Create block window for scale A
|
||||
// K dimension: scale_k_size e8m0_t elements
|
||||
// i_m is element offset (iM * MPerBlock), not tile index
|
||||
auto scale_a_block_window =
|
||||
make_tile_window(scale_a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize>{}),
|
||||
{i_m, 0});
|
||||
// Tile window shape: [MPerBlock/MXdlPackEff, KPerBlock/32/KXdlPackEff]
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
scale_a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock / MXdlPackEff>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPackEff>{}),
|
||||
{i_m / MXdlPackEff, 0});
|
||||
|
||||
return scale_a_block_window;
|
||||
}
|
||||
|
||||
// Create scale B block windows following the pattern of MakeBBlockWindows
|
||||
// Create scale B block windows with packed int32_t layout
|
||||
// Host packs 2N x 2K e8m0_t values into one int32_t
|
||||
// Tensor view: [N/NXdlPack, K/32/KXdlPack] of int32_t
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE static auto MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs,
|
||||
const index_t i_n)
|
||||
@@ -282,23 +308,21 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = ScaleN::GranularityK;
|
||||
const auto scale_k_size = kargs.K / BlockScaleSize;
|
||||
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPackEff;
|
||||
const auto scale_n_packed = kargs.N / NXdlPackEff;
|
||||
|
||||
// B scale tensor view
|
||||
// Host stores as [K/32, N] col-major = [N, K/32] row-major from access perspective
|
||||
// B scale tensor view - [N/NXdlPackEff, K/32/KXdlPackEff] of int32_t
|
||||
const auto scale_b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const e8m0_t*>(scale_b.ptr),
|
||||
make_tuple(kargs.N, scale_k_size), // [N, K/32] for access
|
||||
make_tuple(scale_k_size, 1)); // stride to match col-major storage
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr),
|
||||
make_tuple(scale_n_packed, scale_k_packed),
|
||||
make_tuple(scale_k_packed, 1));
|
||||
|
||||
// Create block window for scale B
|
||||
// Tile window shape matches access pattern: [NPerBlock, KPerBlock/32]
|
||||
// i_n is element offset (iN * NPerBlock)
|
||||
auto scale_b_block_window =
|
||||
make_tile_window(scale_b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize>{}),
|
||||
{i_n, 0});
|
||||
// Tile window shape: [NPerBlock/NXdlPackEff, KPerBlock/32/KXdlPackEff]
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
scale_b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / NXdlPackEff>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPackEff>{}),
|
||||
{i_n / NXdlPackEff, 0});
|
||||
|
||||
return scale_b_block_window;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user