mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
now offsetting with M/MPerXdl to get scales
This commit is contained in:
@@ -386,6 +386,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
// Support both persistent and non-persistent modes
|
||||
do
|
||||
{
|
||||
if (get_block_id() == 0 && get_thread_id() == 0) printf("partition_idx: %d\n", partition_idx);
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
|
||||
@@ -420,27 +420,24 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPack; // Packed int32s per block
|
||||
static_assert(ScaleBlockSize == 32, "Scale block size must be 32 for MX format");
|
||||
|
||||
// Scale A DRAM Window: [MWarp * MPerXdl, ScaleKDimPerBlock]
|
||||
// With strided packing: KXdlPack kIters share each int32 via OpSel
|
||||
auto scale_a_dram_window = make_tile_window(
|
||||
scale_a_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<MWarp * MPerXdl>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_a_window.get_window_origin(),
|
||||
// Scale tensor views and base origins for creating tile windows per iteration
|
||||
const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view();
|
||||
const auto& scale_b_tensor_view = scale_b_window.get_bottom_tensor_view();
|
||||
auto scale_a_base_origin = scale_a_window.get_window_origin();
|
||||
auto scale_b_base_origin = scale_b_window.get_window_origin();
|
||||
|
||||
// Create sample scale windows to determine tile types
|
||||
auto scale_a_dram_window_sample = make_tile_window(
|
||||
scale_a_tensor_view,
|
||||
make_tuple(number<MPerXdl>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_a_base_origin,
|
||||
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
|
||||
|
||||
const auto scale_a_dram_step_m = amd_wave_read_first_lane(
|
||||
scale_a_dram_window.get_load_offset(tuple<number<MWarp * MPerXdl>, number<0>>{}));
|
||||
|
||||
// Scale B DRAM Window: [ScaleKDimPerBlock, NWarp * NPerXdl]
|
||||
// With strided packing: KXdlPack kIters share each int32 via OpSel
|
||||
auto scale_b_dram_window = make_tile_window(
|
||||
scale_b_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<ScaleKDimPerBlock>{}, number<NWarp * NPerXdl>{}),
|
||||
scale_b_window.get_window_origin(),
|
||||
auto scale_b_dram_window_sample = make_tile_window(
|
||||
scale_b_tensor_view,
|
||||
make_tuple(number<NPerXdl>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_b_base_origin,
|
||||
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
|
||||
|
||||
const auto scale_b_dram_step_n = amd_wave_read_first_lane(
|
||||
scale_b_dram_window.get_load_offset(tuple<number<NWarp * NPerXdl>, number<0>>{}));
|
||||
|
||||
// this pipeline has a pair of LDS buffers per logical tile
|
||||
// TODO: check for packed size - are these blocks too big?
|
||||
@@ -561,8 +558,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter must be positive!");
|
||||
|
||||
// Load a sample scale tile to get the type after distribution
|
||||
auto scale_a_sample = load_tile_with_offset(scale_a_dram_window, tuple<number<0>, number<0>>{});
|
||||
auto scale_b_sample = load_tile_with_offset(scale_b_dram_window, tuple<number<0>, number<0>>{});
|
||||
auto scale_a_sample = load_tile(scale_a_dram_window_sample);
|
||||
auto scale_b_sample = load_tile(scale_b_dram_window_sample);
|
||||
|
||||
using ScaleTileElementA = remove_cvref_t<decltype(scale_a_sample)>;
|
||||
using ScaleTileElementB = remove_cvref_t<decltype(scale_b_sample)>;
|
||||
@@ -578,22 +575,40 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
// Helper function to load scales
|
||||
auto load_scales_ = [&](auto& scale_a, auto& scale_b) {
|
||||
// Load scales for each M/N iteration
|
||||
// Create tile windows from scratch with correct origins for each iteration
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
|
||||
// scale_a(mIter)(kPacked) = load_tile_with_offset(
|
||||
// scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k);
|
||||
// });
|
||||
scale_a(mIter) = load_tile_with_offset(scale_a_dram_window, make_tuple(mIter * scale_a_dram_step_m, number<0>{}));
|
||||
// Scale A: create window at origin {base_m + mIter * MPerXdl, base_k}
|
||||
auto scale_a_origin = scale_a_base_origin;
|
||||
scale_a_origin[number<0>{}] += mIter * MPerXdl;
|
||||
|
||||
auto scale_a_tile_window = make_tile_window(
|
||||
scale_a_tensor_view,
|
||||
make_tuple(number<MPerXdl>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_a_origin,
|
||||
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
|
||||
|
||||
scale_a(mIter) = load_tile(scale_a_tile_window);
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Scale B viewed as [N, K], so N is first dimension
|
||||
scale_b(nIter) = load_tile_with_offset(scale_b_dram_window, make_tuple(nIter * scale_b_dram_step_n, number<0>{}));
|
||||
// Scale B: layout is [N, K], create window at origin {base_n + nIter * NPerXdl, base_k}
|
||||
auto scale_b_origin = scale_b_base_origin;
|
||||
scale_b_origin[number<0>{}] += nIter * NPerXdl;
|
||||
|
||||
auto scale_b_tile_window = make_tile_window(
|
||||
scale_b_tensor_view,
|
||||
make_tuple(number<NPerXdl>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_b_origin,
|
||||
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
|
||||
|
||||
scale_b(nIter) = load_tile(scale_b_tile_window);
|
||||
});
|
||||
// Advance to next KPerBlock
|
||||
// Scale A: [M, K] -> advance in K (second dimension)
|
||||
// Scale B: viewed as [N, K] -> advance in K (second dimension)
|
||||
move_tile_window(scale_a_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
|
||||
move_tile_window(scale_b_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
|
||||
|
||||
// Advance base origins to next KPerBlock
|
||||
// Scale A: [M, K] -> advance in K (second dimension, index 1)
|
||||
// Scale B: [N, K] -> advance in K (second dimension, index 1)
|
||||
scale_a_base_origin[number<1>{}] += KPerBlock / ScaleBlockSize / KXdlPack;
|
||||
scale_b_base_origin[number<1>{}] += KPerBlock / ScaleBlockSize / KXdlPack;
|
||||
};
|
||||
|
||||
// constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
|
||||
|
||||
@@ -221,11 +221,11 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarp>, // repeat over NWarps
|
||||
tuple<sequence<MWarp, MPerXdl>, // M dimension
|
||||
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension
|
||||
sequence<ScaleKDimPerBlock / K_Lane, K_Lane, 1>>, // K dimension
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
|
||||
sequence<0>>{});
|
||||
sequence<2, 2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -251,11 +251,11 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
|
||||
tuple<sequence<NWarp, NPerXdl>, // N dimension (first)
|
||||
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension (second)
|
||||
sequence<ScaleKDimPerBlock / K_Lane, K_Lane, 1>>, // K dimension (second)
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>, // which index
|
||||
sequence<2>, // replicate N
|
||||
sequence<0>>{});
|
||||
sequence<2, 2>, // replicate N
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user