mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)
CK Tile MX GEMM Packing Improvement ## Motivation Reduce the scale loading size and also has better utilization of MFMA scale selection. ## Technical Details Add up the packing of mx scales. ## Test Plan Use the existing test cases. ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
859acb5ae7
commit
5f90f69795
@@ -315,14 +315,36 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
|
||||
////////////// MX Scale windows /////////////////
|
||||
////////////// MX Scale windows (pre-packed int32_t) /////////////////
|
||||
// Get WarpGemm configuration
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
constexpr index_t MWarp = BlockWarps::at(I0{});
|
||||
constexpr index_t NWarp = BlockWarps::at(I1{});
|
||||
|
||||
// Calculate scale dimensions: KPerBlock elements need KPerBlock/32 e8m0_t scales
|
||||
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize;
|
||||
// Compute effective XdlPack sizes (fall back to 1 when iter count < pack)
|
||||
constexpr index_t MPerXdl = WarpTile::at(I0{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(I1{});
|
||||
constexpr index_t KPerXdl = WarpTile::at(I2{});
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
|
||||
|
||||
constexpr index_t MXdlPackEff =
|
||||
(MIterPerWarp >= Policy::MXdlPack && MIterPerWarp % Policy::MXdlPack == 0)
|
||||
? Policy::MXdlPack
|
||||
: 1;
|
||||
constexpr index_t NXdlPackEff =
|
||||
(NIterPerWarp >= Policy::NXdlPack && NIterPerWarp % Policy::NXdlPack == 0)
|
||||
? Policy::NXdlPack
|
||||
: 1;
|
||||
constexpr index_t KXdlPackEff =
|
||||
(KIterPerWarp >= Policy::KXdlPack && KIterPerWarp % Policy::KXdlPack == 0)
|
||||
? Policy::KXdlPack
|
||||
: 1;
|
||||
|
||||
// Packed scale dimensions
|
||||
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPackEff;
|
||||
|
||||
// Scale tensor views and base origins for creating tile windows per iteration
|
||||
const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view();
|
||||
@@ -330,18 +352,18 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
auto scale_a_base_origin = scale_a_window.get_window_origin();
|
||||
auto scale_b_base_origin = scale_b_window.get_window_origin();
|
||||
|
||||
// Create sample scale windows to determine tile types
|
||||
auto scale_a_dram_window =
|
||||
make_tile_window(scale_a_tensor_view,
|
||||
make_tuple(number<MPerBlock>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_a_base_origin,
|
||||
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
|
||||
// Create scale windows with packed int32_t dimensions
|
||||
auto scale_a_dram_window = make_tile_window(
|
||||
scale_a_tensor_view,
|
||||
make_tuple(number<MPerBlock / MXdlPackEff>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_a_base_origin,
|
||||
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
|
||||
|
||||
auto scale_b_dram_window =
|
||||
make_tile_window(scale_b_tensor_view,
|
||||
make_tuple(number<NPerBlock>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_b_base_origin,
|
||||
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
|
||||
auto scale_b_dram_window = make_tile_window(
|
||||
scale_b_tensor_view,
|
||||
make_tuple(number<NPerBlock / NXdlPackEff>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_b_base_origin,
|
||||
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
|
||||
|
||||
// this pipeline has a pair of LDS buffers per logical tile
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
@@ -427,8 +449,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
"SmemSizeB size is wrong!");
|
||||
|
||||
////////////// MX Scale register tiles (ping-pong buffers) /////////////////
|
||||
// No packing needed - each thread gets e8m0_t elements directly
|
||||
// Each thread will cast e8m0_t to int32_t for WarpGemm with OpSel=0
|
||||
// Scales are pre-packed int32_t: each int32_t holds 2M/N x 2K e8m0_t values
|
||||
// Block GEMM uses OpSel (0-3) to select the right byte per MFMA call
|
||||
|
||||
using ScaleATileType = decltype(load_tile(scale_a_dram_window));
|
||||
using ScaleBTileType = decltype(load_tile(scale_b_dram_window));
|
||||
|
||||
@@ -131,7 +131,15 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// MX Scale tile distributions for loading from global memory
|
||||
// XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension
|
||||
// Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales
|
||||
// Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
static constexpr int KXdlPack = 2;
|
||||
|
||||
// MX Scale tile distributions for loading pre-packed int32_t from global memory
|
||||
// Packed layout: [M/MXdlPack, K/32/KXdlPack] of int32_t
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
|
||||
{
|
||||
@@ -145,21 +153,29 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 64/16 = 4 threads in K dimension
|
||||
constexpr index_t K_Lane = get_warp_size() / MPerXdl;
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
|
||||
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
|
||||
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
|
||||
constexpr index_t KPerLane = KPerXdl / BlockScaleSize / K_Lane;
|
||||
|
||||
// Effective pack sizes: fall back to 1 when iteration count < pack size
|
||||
constexpr index_t MXdlPackEff =
|
||||
(MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1;
|
||||
constexpr index_t KXdlPackEff =
|
||||
(KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
|
||||
|
||||
constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff;
|
||||
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<NWarp>, // repeat over MWarps
|
||||
tuple<sequence<MIterPerWarp, MWarp, MPerXdl>, // M dimension (first)
|
||||
sequence<KIterPerWarp, K_Lane, KPerLane>>, // K dimension (second)
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>, // <KIterPerWarp, MIterPerWarp, KPerLane>
|
||||
sequence<0, 0, 2>>{});
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp_packed, MWarp, MPerXdl>,
|
||||
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -169,27 +185,35 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t K_Lane = get_warp_size() / NPerXdl;
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
|
||||
|
||||
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
|
||||
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
|
||||
constexpr index_t KPerLane = KPerXdl / BlockScaleSize / K_Lane;
|
||||
|
||||
// Effective pack sizes: fall back to 1 when iteration count < pack size
|
||||
constexpr index_t NXdlPackEff =
|
||||
(NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1;
|
||||
constexpr index_t KXdlPackEff =
|
||||
(KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
|
||||
|
||||
constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff;
|
||||
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarp>, // repeat over MWarps
|
||||
tuple<sequence<NIterPerWarp, NWarp, NPerXdl>, // N dimension (first)
|
||||
sequence<KIterPerWarp, K_Lane, KPerLane>>, // K dimension (second)
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>, // <KIterPerWarp, NIterPerWarp, KPerLane>
|
||||
sequence<0, 0, 2>>{});
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp_packed, NWarp, NPerXdl>,
|
||||
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user