now offsetting with M/MPerXdl to get scales

This commit is contained in:
Sami Remes
2026-02-05 17:31:32 +00:00
parent c4daaf2334
commit a8d48f9224
3 changed files with 54 additions and 38 deletions

View File

@@ -386,6 +386,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
// Support both persistent and non-persistent modes
do
{
if (get_block_id() == 0 && get_thread_id() == 0) printf("partition_idx: %d\n", partition_idx);
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);

View File

@@ -420,27 +420,24 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPack; // Packed int32s per block
static_assert(ScaleBlockSize == 32, "Scale block size must be 32 for MX format");
// Scale A DRAM Window: [MWarp * MPerXdl, ScaleKDimPerBlock]
// With strided packing: KXdlPack kIters share each int32 via OpSel
auto scale_a_dram_window = make_tile_window(
scale_a_window.get_bottom_tensor_view(),
make_tuple(number<MWarp * MPerXdl>{}, number<ScaleKDimPerBlock>{}),
scale_a_window.get_window_origin(),
// Scale tensor views and base origins for creating tile windows per iteration
const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view();
const auto& scale_b_tensor_view = scale_b_window.get_bottom_tensor_view();
auto scale_a_base_origin = scale_a_window.get_window_origin();
auto scale_b_base_origin = scale_b_window.get_window_origin();
// Create sample scale windows to determine tile types
auto scale_a_dram_window_sample = make_tile_window(
scale_a_tensor_view,
make_tuple(number<MPerXdl>{}, number<ScaleKDimPerBlock>{}),
scale_a_base_origin,
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
const auto scale_a_dram_step_m = amd_wave_read_first_lane(
scale_a_dram_window.get_load_offset(tuple<number<MWarp * MPerXdl>, number<0>>{}));
// Scale B DRAM Window: [ScaleKDimPerBlock, NWarp * NPerXdl]
// With strided packing: KXdlPack kIters share each int32 via OpSel
auto scale_b_dram_window = make_tile_window(
scale_b_window.get_bottom_tensor_view(),
make_tuple(number<ScaleKDimPerBlock>{}, number<NWarp * NPerXdl>{}),
scale_b_window.get_window_origin(),
auto scale_b_dram_window_sample = make_tile_window(
scale_b_tensor_view,
make_tuple(number<NPerXdl>{}, number<ScaleKDimPerBlock>{}),
scale_b_base_origin,
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
const auto scale_b_dram_step_n = amd_wave_read_first_lane(
scale_b_dram_window.get_load_offset(tuple<number<NWarp * NPerXdl>, number<0>>{}));
// this pipeline has a pair of LDS buffers per logical tile
// TODO: check for packed size - are these blocks too big?
@@ -561,8 +558,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter must be positive!");
// Load a sample scale tile to get the type after distribution
auto scale_a_sample = load_tile_with_offset(scale_a_dram_window, tuple<number<0>, number<0>>{});
auto scale_b_sample = load_tile_with_offset(scale_b_dram_window, tuple<number<0>, number<0>>{});
auto scale_a_sample = load_tile(scale_a_dram_window_sample);
auto scale_b_sample = load_tile(scale_b_dram_window_sample);
using ScaleTileElementA = remove_cvref_t<decltype(scale_a_sample)>;
using ScaleTileElementB = remove_cvref_t<decltype(scale_b_sample)>;
@@ -578,22 +575,40 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
// Helper function to load scales
auto load_scales_ = [&](auto& scale_a, auto& scale_b) {
// Load scales for each M/N iteration
// Create tile windows from scratch with correct origins for each iteration
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
// scale_a(mIter)(kPacked) = load_tile_with_offset(
// scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k);
// });
scale_a(mIter) = load_tile_with_offset(scale_a_dram_window, make_tuple(mIter * scale_a_dram_step_m, number<0>{}));
// Scale A: create window at origin {base_m + mIter * MPerXdl, base_k}
auto scale_a_origin = scale_a_base_origin;
scale_a_origin[number<0>{}] += mIter * MPerXdl;
auto scale_a_tile_window = make_tile_window(
scale_a_tensor_view,
make_tuple(number<MPerXdl>{}, number<ScaleKDimPerBlock>{}),
scale_a_origin,
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
scale_a(mIter) = load_tile(scale_a_tile_window);
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// Scale B viewed as [N, K], so N is first dimension
scale_b(nIter) = load_tile_with_offset(scale_b_dram_window, make_tuple(nIter * scale_b_dram_step_n, number<0>{}));
// Scale B: layout is [N, K], create window at origin {base_n + nIter * NPerXdl, base_k}
auto scale_b_origin = scale_b_base_origin;
scale_b_origin[number<0>{}] += nIter * NPerXdl;
auto scale_b_tile_window = make_tile_window(
scale_b_tensor_view,
make_tuple(number<NPerXdl>{}, number<ScaleKDimPerBlock>{}),
scale_b_origin,
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
scale_b(nIter) = load_tile(scale_b_tile_window);
});
// Advance to next KPerBlock
// Scale A: [M, K] -> advance in K (second dimension)
// Scale B: viewed as [N, K] -> advance in K (second dimension)
move_tile_window(scale_a_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
move_tile_window(scale_b_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
// Advance base origins to next KPerBlock
// Scale A: [M, K] -> advance in K (second dimension, index 1)
// Scale B: [N, K] -> advance in K (second dimension, index 1)
scale_a_base_origin[number<1>{}] += KPerBlock / ScaleBlockSize / KXdlPack;
scale_b_base_origin[number<1>{}] += KPerBlock / ScaleBlockSize / KXdlPack;
};
// constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {

View File

@@ -221,11 +221,11 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarp>, // repeat over NWarps
tuple<sequence<MWarp, MPerXdl>, // M dimension
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension
sequence<ScaleKDimPerBlock / K_Lane, K_Lane, 1>>, // K dimension
tuple<sequence<1, 0>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
sequence<0>>{});
sequence<2, 2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
sequence<0, 2>>{});
}
template <typename Problem>
@@ -251,11 +251,11 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
tuple<sequence<NWarp, NPerXdl>, // N dimension (first)
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension (second)
sequence<ScaleKDimPerBlock / K_Lane, K_Lane, 1>>, // K dimension (second)
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<1, 1>>, // which index
sequence<2>, // replicate N
sequence<0>>{});
sequence<2, 2>, // replicate N
sequence<0, 2>>{});
}
};
} // namespace ck_tile