From cfb8ae528f697bc790022f1edb3ee3aa6bb3bff7 Mon Sep 17 00:00:00 2001
From: Yi DING
Date: Tue, 2 Dec 2025 14:21:12 +0800
Subject: [PATCH] [CK_Tile] Flatmm MX Cleanup & Explicit Offset Calculation
 (#3286)

[ROCm/composable_kernel commit: f211156ce6e9a8411c9ab8c3647147b6a9cf78d8]
---
 .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp    |   2 +-
 .../18_flatmm/mxgemm/mx_flatmm_instance.hpp   |   2 +-
 include/ck_tile/core/tensor/tile_window.hpp   |  16 +-
 ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 456 +++++++-----------
 ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp |  24 +-
 5 files changed, 206 insertions(+), 294 deletions(-)

diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
index 8d3fd146bc..0134465347 100644
--- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
+++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
@@ -158,7 +158,7 @@ auto create_args(int argc, char* argv[])
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
         .insert(
-            "mx_prec", "fp4xfp4", "data type for activation and weight, support: fp6xfp6, fp8xfp8")
+            "mx_prec", "fp4xfp4", "data type for activation and weight, support: fp4xfp4, fp8xfp8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp
index 9c12509d59..f177ef04ca 100644
--- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp
+++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp
@@ -75,7 +75,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args,
                                                        HasHotLoop,
                                                        TailNum>;
 
-    using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<...>;
+    using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1<...>;
 
     using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<...>;
diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp
--- a/include/ck_tile/core/tensor/tile_window.hpp
+++ b/include/ck_tile/core/tensor/tile_window.hpp
@@ ... @@
                   <...>{}, bool_constant<...>{});
     }
 
+    template <typename offset_t>
+    CK_TILE_DEVICE constexpr auto get_load_offset(offset_t = {}) const
+    {
+        constexpr auto bottom_tensor_idx_off = to_multi_index(offset_t{});
+        const auto bottom_tensor_coord_off   = make_tensor_coordinate(
+            this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
+        return amd_wave_read_first_lane(bottom_tensor_coord_off.get_offset());
+    }
+
     template <...>
@@ ... @@
             if constexpr(<...>)
                 return offset_t::value;
             else
-            {
-                auto bottom_tensor_idx_off   = to_multi_index(offset_t{});
-                auto bottom_tensor_coord_off = make_tensor_coordinate(
-                    this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
-                return bottom_tensor_coord_off.get_offset();
-            }
+                return get_load_offset(offset_t{});
         }();
 
         // loop over thread tensor space [y0, y1, ...]
         static_for<0, NumCoord, 1>{}([&](auto iCoord) {
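The `get_load_offset` helper added above is the "explicit offset calculation" of the
title: a compile-time window index is folded through the tensor descriptor into a raw
linear offset once, and `amd_wave_read_first_lane` broadcasts lane 0's value so the
offset can live in a scalar register. Reduced to a standalone sketch in plain C++ --
`RowMajorView` and `load_offset` are invented stand-ins for illustration, not CK_Tile
APIs:

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // Hypothetical stand-in for a tensor descriptor: a row-major Rows x Cols view.
    template <std::size_t Rows, std::size_t Cols>
    struct RowMajorView
    {
        const float* data;

        // Analogue of tile_window::get_load_offset(tuple<number<I>, number<J>>{}):
        // collapse a static multi-index into one linear element offset up front.
        template <std::size_t I, std::size_t J>
        static constexpr std::size_t load_offset()
        {
            static_assert(I < Rows && J < Cols, "window index out of range");
            return I * Cols + J;
        }
    };

    int main()
    {
        std::array<float, 4 * 8> buf{};
        buf[2 * 8 + 3] = 42.0f;

        const RowMajorView<4, 8> view{buf.data()};

        // Precompute a base offset and a unit step once, outside the loop...
        constexpr std::size_t base = RowMajorView<4, 8>::load_offset<2, 0>();
        constexpr std::size_t step = RowMajorView<4, 8>::load_offset<0, 1>();

        // ...so the hot loop is plain integer indexing, with no window moves.
        for(std::size_t j = 0; j < 8; ++j)
            if(view.data[base + j * step] != 0.0f)
                std::printf("hit at j=%zu: %g\n", j, view.data[base + j * step]);
        return 0;
    }

The pipeline below applies exactly this pattern: one B window plus a tuple of
precomputed integer offsets replaces an array of windows that had to be moved every
iteration.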
diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
index e5c666de46..ff799cb0fc 100644
--- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
@@ -46,8 +46,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<...>
 
-template <...>
-struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
+template <...>
+struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
 {
     using Underlying = FlatmmPipelineAGmemBGmemCRegV1<...>;
@@ -470,17 +470,39 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
         <...>();
     }
 
+    template <typename... Args>
+    CK_TILE_DEVICE auto operator()(Args&&... args) const
+    {
+        auto c_warp_tensors = Run_(std::forward<Args>(args)...);
+
+        // Block GEMM Acc register tile
+        using CWarpDstr = typename WG::CWarpDstr;
+        constexpr auto c_warp_y_lengths =
+            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<...>{};
+
+        auto c_block_tile = BlockFlatmm{}.MakeCBlockTile();
+        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                c_block_tile.set_y_sliced_thread_data(
+                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                    c_warp_tensors(mIter)(nIter).get_thread_buffer());
+            });
+        });
+        return c_block_tile;
+    }
+
     template <...>
-    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
-                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
-                                   const ScaleADramBlockWindowTmp& scale_a_window,
-                                   const ScaleBDramBlockWindowTmp& scale_b_window,
-                                   index_t num_loop,
-                                   void* __restrict__ p_smem_ping,
-                                   void* __restrict__ p_smem_pong) const
+    CK_TILE_DEVICE auto Run_(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
+                             const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                             const ScaleADramBlockWindowTmp& scale_a_window,
+                             const ScaleBDramBlockWindowTmp& scale_b_window,
+                             index_t num_loop,
+                             void* __restrict__ p_smem_ping,
+                             void* __restrict__ p_smem_pong) const
     {
 #ifndef __gfx950__
         static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
@@ -497,19 +519,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
         <...>{};
-        auto a_dram_window =
-            make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor(
+        auto a_dram_window =
+            make_tile_window(PipelinePolicy::template MakeMX_AAsyncLoadDramDescriptor(
                                  a_copy_dram_window_tmp.get_bottom_tensor_view()),
                              a_copy_dram_window_tmp.get_window_lengths(),
                              a_copy_dram_window_tmp.get_window_origin(),
-                             PipelinePolicy::template MakeMXFP4_ADramTileDistribution());
+                             PipelinePolicy::template MakeMX_ADramTileDistribution());
 
         __builtin_amdgcn_sched_barrier(0);
@@ -518,7 +535,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
         <...>(p_smem_pong);
 
         constexpr auto a_lds_block_desc =
-            PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor();
+            PipelinePolicy::template MakeMX_ALdsBlockDescriptor();
 
         auto a_lds_block_ping = make_tensor_view(p_a_lds_ping, a_lds_block_desc);
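The new `operator()` above is a thin wrapper: `Run_` (the old body) accumulates into
per-warp register tensors, and the wrapper assembles the block tile once at the end,
instead of slicing C out of and back into the block tile inside the inner loop. A
rough sketch of that split in plain C++ -- the toy shapes, `WarpGrid`, `run_`, and
`run_and_pack` are all invented for illustration:

    #include <array>
    #include <cstdio>

    // Hypothetical warp-tile grid, loosely mirroring MIterPerWarp x NIterPerWarp.
    constexpr int MIter = 2, NIter = 2, WarpElems = 4;

    using WarpAcc  = std::array<float, WarpElems>;
    using WarpGrid = std::array<std::array<WarpAcc, NIter>, MIter>;

    // Analogue of Run_: one small accumulator per (m, n) warp tile, updated in
    // place across the K loop; no read-modify-write of a big block tile.
    WarpGrid run_(int num_loop)
    {
        WarpGrid acc{}; // value-initialized to zeros, like the clear_tile() loop
        for(int k = 0; k < num_loop; ++k)
            for(int m = 0; m < MIter; ++m)
                for(int n = 0; n < NIter; ++n)
                    for(int e = 0; e < WarpElems; ++e)
                        acc[m][n][e] += 1.0f; // stands in for the warp GEMM update
        return acc;
    }

    // Analogue of operator(): pack the block-level result exactly once, the way
    // set_y_sliced_thread_data copies each warp tile into c_block_tile.
    std::array<float, MIter * NIter * WarpElems> run_and_pack(int num_loop)
    {
        const WarpGrid acc = run_(num_loop);
        std::array<float, MIter * NIter * WarpElems> block{};
        for(int m = 0; m < MIter; ++m)
            for(int n = 0; n < NIter; ++n)
                for(int e = 0; e < WarpElems; ++e)
                    block[(m * NIter + n) * WarpElems + e] = acc[m][n][e];
        return block;
    }

    int main()
    {
        std::printf("block[0] = %g\n", run_and_pack(3)[0]); // prints 3
        return 0;
    }

Keeping the accumulators as separate named register tensors for the whole loop also
removes the per-iteration get/set slicing visible in the removed lines further down.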
@@ -535,39 +552,34 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
         auto a_warp_window_ping =
             make_tile_window(a_lds_block_ping,
                              make_tuple(number<...>{}, number<...>{}),
                              {0, 0},
-                             PipelinePolicy::template MakeMXF4_ALDS_TileDistribution());
+                             PipelinePolicy::template MakeMX_ALDS_TileDistribution());
         auto a_warp_window_pong =
             make_tile_window(a_lds_block_pong,
                              make_tuple(number<...>{}, number<...>{}),
                              {0, 0},
-                             PipelinePolicy::template MakeMXF4_ALDS_TileDistribution());
-
-        // Block GEMM
-        auto block_flatmm = BlockFlatmm();
-        // Acc register tile
-        auto c_block_tile = block_flatmm.MakeCBlockTile();
+                             PipelinePolicy::template MakeMX_ALDS_TileDistribution());
 
         // B flat DRAM window for load
        // pingpong buffer for B
-        auto b_flat_dram_windows = generate_tuple(
+        auto b_flat_dram_window =
+            make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
+                             make_tuple(number<...>{}, number<...>{}),
+                             b_flat_dram_block_window_tmp.get_window_origin(),
+                             PipelinePolicy::template MakeMX_BFlatDramTileDistribution());
+        auto b_flat_dram_offsets = generate_tuple(
             [&](auto nIter) {
                 constexpr auto packed_n_idx  = nIter / number<NXdlPack>{};
                 constexpr auto packed_n_rank = nIter % number<NXdlPack>{};
-                auto window_i = make_tile_window(
-                    b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
-                    make_tuple(number<...>{}, number<...>{}),
-                    b_flat_dram_block_window_tmp.get_window_origin(),
-                    PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution());
-                move_tile_window(
-                    window_i,
-                    {number<...>{},
-                     number<0>{}});
-                return window_i;
+                return b_flat_dram_window.get_load_offset(
+                           tuple<number<...>, number<0>>{}) +
+                       b_flat_dram_window.get_load_offset(
+                           tuple<number<...>, number<0>>{});
             },
             number<NIterPerWarp>{});
 
         statically_indexed_array<
-            statically_indexed_array<...>,
+            statically_indexed_array<...>,
             NIterPerWarp>
             b_warp_tensor_ping, b_warp_tensor_pong;
@@ -576,41 +588,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
         auto scale_a_dram_window = make_tile_window(
             scale_a_window.get_bottom_tensor_view(),
             make_tuple(number<...>{}, number<64 / WG::kM>{}),
             scale_a_window.get_window_origin(),
-            PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution());
+            PipelinePolicy::template MakeMX_ScaleA_FlatDramTileDistribution());
+        const auto scale_a_dram_step_m = amd_wave_read_first_lane(
+            scale_a_dram_window.get_load_offset(tuple<number<MWarp * WG::kM>, number<0>>{}));
+        const auto scale_a_dram_step_k = amd_wave_read_first_lane(
+            scale_a_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kM>>{}));
 
         auto scale_b_dram_window = make_tile_window(
             scale_b_window.get_bottom_tensor_view(),
             make_tuple(number<...>{}, number<64 / WG::kN>{}),
             scale_b_window.get_window_origin(),
-            PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution());
+            PipelinePolicy::template MakeMX_ScaleB_DramTileDistribution());
+        const auto scale_b_dram_step_n = amd_wave_read_first_lane(
+            scale_b_dram_window.get_load_offset(tuple<number<NWarp * WG::kN>, number<0>>{}));
+        const auto scale_b_dram_step_k = amd_wave_read_first_lane(
+            scale_b_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kN>>{}));
+
+        constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
+        constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
+        constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
 
         // ping pong buffer for scale A
         statically_indexed_array<
-            statically_indexed_array<...>,
-            MIterPerWarp / MXdlPack>
-            scale_a_dram_windows;
-        statically_indexed_array<statically_indexed_array<...>,
-                                 MIterPerWarp / MXdlPack>
-            scale_a_tile_tensor_ping;
-        statically_indexed_array<statically_indexed_array<...>,
-                                 MIterPerWarp / MXdlPack>
-            scale_a_tile_tensor_pong;
+            statically_indexed_array<...>,
+            MPackIterPerWarp>
+            scale_a_tile_tensor_ping, scale_a_tile_tensor_pong;
 
         // ping pong buffer for scale B
         statically_indexed_array<
-            statically_indexed_array<...>,
-            NIterPerWarp / NXdlPack>
-            scale_b_dram_windows;
-        statically_indexed_array<statically_indexed_array<...>,
-                                 NIterPerWarp / NXdlPack>
-            scale_b_tile_tensor_ping;
-        statically_indexed_array<statically_indexed_array<...>,
-                                 NIterPerWarp / NXdlPack>
-            scale_b_tile_tensor_pong;
+            statically_indexed_array<...>,
+            NPackIterPerWarp>
+            scale_b_tile_tensor_ping, scale_b_tile_tensor_pong;
 
         auto async_load_tile_ = [](auto lds, auto dram) {
             async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
         };
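Everything above sets up double buffers: ping/pong LDS views for A, ping/pong
register tensors for B, and ping/pong tile tensors for both scales. The loop below
alternates them, issuing the next prefetch before consuming the current buffer so
the load latency can hide behind the math. Stripped of windows and tiles, the
control flow is just the following (toy `load_tile`/`consume` stand-ins, invented
for illustration):

    #include <array>
    #include <cstdio>

    // Hypothetical "tile": one prefetch granule of B or of a scale.
    using Tile = std::array<int, 4>;

    static Tile load_tile(int k) { return Tile{k, k, k, k}; } // a DRAM load
    static int  consume(const Tile& t) { return t[0]; }       // the warp GEMM

    int main()
    {
        constexpr int num_loop = 6;
        Tile ping = load_tile(0); // prefetch (0) before entering the loop
        Tile pong{};
        int  checksum = 0;

        // Two logical iterations per trip, like the 2i / 2i+1 body: fetch into
        // one buffer while the other still holds the data being consumed.
        for(int i = 0; i + 1 < num_loop; i += 2)
        {
            pong = load_tile(i + 1);   // prefetch (2i+1)
            checksum += consume(ping); // GEMM on (2i)
            if(i + 2 < num_loop)
                ping = load_tile(i + 2); // prefetch (2i+2)
            checksum += consume(pong);   // GEMM on (2i+1)
        }
        std::printf("checksum = %d\n", checksum); // 0+1+...+5 = 15
        return 0;
    }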
@@ -625,35 +633,31 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
         static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
             static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                 b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
-                    b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
+                    b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
             });
             // move B window to next flat K
-            move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
+            b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
+                tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
         });
 
         // prefetch Scale A
-        static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
-            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
-                move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
-                                 {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
+        static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
+                    scale_a_dram_window,
-                scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
-                    load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
+                    mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
             });
         });
         // move Scale A window to next K
         move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
 
         // prefetch Scale B
-        static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
-            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
-                move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
-                                 {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
-
-                scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
-                    load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
+        static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
+                    scale_b_dram_window,
+                    nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
             });
         });
         // move Scale B window to next K
@@ -667,7 +671,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
 
-        clear_tile(c_block_tile);
+        statically_indexed_array<statically_indexed_array<...>, MIterPerWarp>
+            c_warp_tensors;
+        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+            static_for<0, NIterPerWarp, 1>{}(
+                [&](auto nIter) { clear_tile(c_warp_tensors(mIter)(nIter)); });
+        });
 
         statically_indexed_array<...> a_warp_tensor;
@@ -688,40 +697,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
             static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                     b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
-                        b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
+                        b_flat_dram_window,
+                        b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
+
+                    // move B window to next flat K
                     if constexpr(kIter == KIterPerWarp - 1)
-                        move_tile_window(b_flat_dram_windows(nIter),
-                                         {0, BlockGemmShape::flatKPerBlock});
+                        b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
+                            tuple<number<0>, number<BlockGemmShape::flatKPerBlock>>{});
                 });
             });
 
             // prefetch Scale A and Scale B (2i+1)
-            static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
-                static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                    scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
-                    move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
-                                     {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
-
-                    scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
-                        load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+                    scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
+                        scale_a_dram_window,
+                        mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
                 });
             });
-            static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
-                static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                    scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
-                    move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
-                                     {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
-
-                    scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
-                        load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
+                    scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
+                        scale_b_dram_window,
+                        nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
                });
            });
 
             // GEMM 2i
-            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
-                    static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+                    static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
                         static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
                             static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                                 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
@@ -729,39 +735,22 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
                                 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
                                     constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
-
-                                    // read C warp tensor from C block tensor
-                                    CWarpTensor c_warp_tensor;
-                                    c_warp_tensor.get_thread_buffer() =
-                                        c_block_tile.get_y_sliced_thread_data(
-                                            merge_sequences(sequence<...>{},
-                                                            c_warp_y_index_zeros),
-                                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
                                     // warp GEMM
                                     WG{}.template operator()(
-                                        c_warp_tensor,
+                                        c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
                                         a_warp_tensor(number<AwarpIter>{}),
-                                        b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
-                                            kIter_pack * number<KXdlPack>{} + ikxdl),
+                                        b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
                                         scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0],
                                         scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0]);
-
-                                    // write C warp tensor into C block tensor
-                                    c_block_tile.set_y_sliced_thread_data(
-                                        merge_sequences(sequence<...>{},
-                                                        c_warp_y_index_zeros),
-                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                                        c_warp_tensor.get_thread_buffer());
                                 });
 
                                 // preload next A from lds
                                 constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
                                 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
-                                             (nIter_pack == NIterPerWarp / NXdlPack - 1))
+                                             (nIter_pack == NPackIterPerWarp - 1))
                                 {
                                     constexpr auto AmIter = addr % 2 + addr / 4 * 2;
                                     constexpr auto AkIter = addr / 2 % 2;
@@ -802,81 +791,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
             static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                     b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
-                        b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
+                        b_flat_dram_window,
+                        b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
+
+                    // move B window to next flat K
                     if constexpr(kIter == KIterPerWarp - 1)
-                        move_tile_window(b_flat_dram_windows(nIter),
-                                         {0, BlockGemmShape::flatKPerBlock});
+                        b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
+                            tuple<number<0>, number<BlockGemmShape::flatKPerBlock>>{});
                 });
             });
 
             // prefetch Scale A and Scale B (2i+2)
-            static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
-                static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                    scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
-                    move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
-                                     {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
-
-                    scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
-                        load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+                    scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
+                        scale_a_dram_window,
+                        mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
                 });
             });
-            static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
-                static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                    scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
-                    move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
-                                     {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
-
-                    scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
-                        load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
+                    scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
+                        scale_b_dram_window,
+                        nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
                });
            });
 
             // GEMM 2i+1
-            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
-                    static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+                    static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
                         static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
                             static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                                 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
+                                constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
+                                constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
                                 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
-                                    // read C warp tensor from C block tensor
-                                    CWarpTensor c_warp_tensor;
-                                    c_warp_tensor.get_thread_buffer() =
-                                        c_block_tile.get_y_sliced_thread_data(
-                                            merge_sequences(
-                                                sequence<...>{},
-                                                c_warp_y_index_zeros),
-                                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
+                                    constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
                                     // warp GEMM
                                     WG{}.template operator()(
-                                        c_warp_tensor,
+                                        c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
                                         a_warp_tensor(number<AwarpIter>{}),
-                                        b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
-                                            kIter_pack * number<KXdlPack>{} + ikxdl),
+                                        b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
                                         scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0], // scale A
                                         scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0]); // scale B
-
-                                    // write C warp tensor into C block tensor
-                                    c_block_tile.set_y_sliced_thread_data(
-                                        merge_sequences(sequence<...>{},
-                                                        c_warp_y_index_zeros),
-                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                                        c_warp_tensor.get_thread_buffer());
                                 });
 
                                 // preload next A from lds
-                                constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
-                                                      (kIter_pack * KXdlPack + ikxdl) * 2 +
-                                                      (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
-                                                      m_preload;
+                                constexpr auto addr =
+                                    m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
                                 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
-                                             (nIter_pack == NIterPerWarp / NXdlPack - 1))
+                                             (nIter_pack == NPackIterPerWarp - 1))
                                 {
                                     constexpr auto AmIter = addr % 2 + addr / 4 * 2;
                                     constexpr auto AkIter = addr / 2 % 2;
@@ -928,78 +896,54 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
             static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                     b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
-                        b_flat_dram_windows(nIter),
-                        make_tuple(number<0>{}, number<kIter * KFlatPerBlockPerIter>{}));
+                        b_flat_dram_window,
+                        b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
                 });
             });
 
             // prefetch Scale A and Scale B (2i+1)
-            static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
-                static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                    scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
-                    move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
-                                     {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
-
-                    scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
-                        load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
+            static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                    scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
+                        scale_a_dram_window,
+                        mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
                 });
             });
-            static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
-                static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                    scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
-                    move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
-                                     {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
-
-                    scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
-                        load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
+            static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
+                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                    scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
+                        scale_b_dram_window,
+                        nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
                });
            });
 
             // GEMM loopK-1
-            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
-                    static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+                    static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
                         static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
                             static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                                 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
+                                constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
+                                constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
                                 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
-                                    // read C warp tensor from C block tensor
-                                    CWarpTensor c_warp_tensor;
-                                    c_warp_tensor.get_thread_buffer() =
-                                        c_block_tile.get_y_sliced_thread_data(
-                                            merge_sequences(
-                                                sequence<...>{},
-                                                c_warp_y_index_zeros),
-                                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
+                                    constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
                                     // warp GEMM
                                     WG{}.template operator()(
-                                        c_warp_tensor,
+                                        c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
                                         a_warp_tensor(number<AwarpIter>{}),
-                                        b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
-                                            kIter_pack * number<KXdlPack>{} + ikxdl),
+                                        b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
                                         scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0], // scale A
                                         scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0]); // scale B
-
-                                    // write C warp tensor into C block tensor
-                                    c_block_tile.set_y_sliced_thread_data(
-                                        merge_sequences(sequence<...>{},
-                                                        c_warp_y_index_zeros),
-                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                                        c_warp_tensor.get_thread_buffer());
                                 });
 
                                 // preload next A from lds
-                                constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
-                                                      (kIter_pack * KXdlPack + ikxdl) * 2 +
-                                                      (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
-                                                      m_preload;
+                                constexpr auto addr =
+                                    m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
                                 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
-                                             (nIter_pack == NIterPerWarp / NXdlPack - 1))
+                                             (nIter_pack == NPackIterPerWarp - 1))
                                 {
                                     constexpr auto AmIter = addr % 2 + addr / 4 * 2;
                                     constexpr auto AkIter = addr / 2 % 2;
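All four GEMM bodies share the `addr` arithmetic that schedules the next A preload
from LDS; the cleanup only rewrites it in terms of the already-computed
`m_iter`/`k_iter` instead of re-deriving them inline. The encode (`addr`) and decode
(`AmIter`/`AkIter`) are exact inverses, which a few lines of plain C++ can check --
the tile counts and `m_preload` below are made-up example values, not taken from the
patch:

    #include <cstdio>

    int main()
    {
        constexpr int MIterPerWarp = 4, KIterPerWarp = 2; // hypothetical sizes
        constexpr int m_preload    = 2; // hypothetical: A tiles fetched up front

        for(int m_iter = 0; m_iter < MIterPerWarp; ++m_iter)
            for(int k_iter = 0; k_iter < KIterPerWarp; ++k_iter)
            {
                // Linearized position of the next A tile to fetch (diff's formula).
                const int addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
                if(addr < KIterPerWarp * MIterPerWarp)
                {
                    // Decode back into the (m, k) coordinates of the preloaded tile.
                    const int AmIter = addr % 2 + addr / 4 * 2;
                    const int AkIter = addr / 2 % 2;
                    std::printf("gemm(m=%d,k=%d) preloads A(m=%d,k=%d)\n",
                                m_iter, k_iter, AmIter, AkIter);
                }
            }
        return 0;
    }

Running it shows the schedule staying a fixed `m_preload` tiles ahead of the warp
GEMM, walking the same pairs-of-m-within-k order in which the tiles are consumed.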
@@ -1028,50 +972,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
-            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
-                    static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+                    static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
                         static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
                             static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                                 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
+                                constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
+                                constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
                                 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
-                                    // read C warp tensor from C block tensor
-                                    CWarpTensor c_warp_tensor;
-                                    c_warp_tensor.get_thread_buffer() =
-                                        c_block_tile.get_y_sliced_thread_data(
-                                            merge_sequences(
-                                                sequence<...>{},
-                                                c_warp_y_index_zeros),
-                                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
+                                    constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
                                     // warp GEMM
                                     WG{}.template operator()(
-                                        c_warp_tensor,
+                                        c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
                                         a_warp_tensor(number<AwarpIter>{}),
-                                        b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
-                                            kIter_pack * number<KXdlPack>{} + ikxdl),
+                                        b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
                                         scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0], // scale A
                                         scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0]); // scale B
-
-                                    // write C warp tensor into C block tensor
-                                    c_block_tile.set_y_sliced_thread_data(
-                                        merge_sequences(sequence<...>{},
-                                                        c_warp_y_index_zeros),
-                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                                        c_warp_tensor.get_thread_buffer());
                                 });
 
                                 // preload next A from lds
-                                constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
-                                                      (kIter_pack * KXdlPack + ikxdl) * 2 +
-                                                      (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
-                                                      m_preload;
+                                constexpr auto addr =
+                                    m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
                                 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
-                                             (nIter_pack == NIterPerWarp / NXdlPack - 1))
+                                             (nIter_pack == NPackIterPerWarp - 1))
                                 {
                                     constexpr auto AmIter = addr % 2 + addr / 4 * 2;
                                     constexpr auto AkIter = addr / 2 % 2;
@@ -1089,50 +1015,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
-            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
-                static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
-                    static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+                static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+                    static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
                         static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
                             static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                                 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
+                                constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
+                                constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
                                 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
-                                    // read C warp tensor from C block tensor
-                                    CWarpTensor c_warp_tensor;
-                                    c_warp_tensor.get_thread_buffer() =
-                                        c_block_tile.get_y_sliced_thread_data(
-                                            merge_sequences(
-                                                sequence<...>{},
-                                                c_warp_y_index_zeros),
-                                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
+                                    constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
                                     // warp GEMM
                                     WG{}.template operator()(
-                                        c_warp_tensor,
+                                        c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
                                         a_warp_tensor(number<AwarpIter>{}),
-                                        b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
-                                            kIter_pack * number<KXdlPack>{} + ikxdl),
+                                        b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
                                         scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0], // scale A
                                         scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
                                             .get_thread_buffer()[0]); // scale B
-
-                                    // write C warp tensor into C block tensor
-                                    c_block_tile.set_y_sliced_thread_data(
-                                        merge_sequences(sequence<...>{},
-                                                        c_warp_y_index_zeros),
-                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                                        c_warp_tensor.get_thread_buffer());
                                 });
 
                                 // preload next A from lds
-                                constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
-                                                      (kIter_pack * KXdlPack + ikxdl) * 2 +
-                                                      (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
-                                                      m_preload;
+                                constexpr auto addr =
+                                    m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
                                 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
-                                             (nIter_pack == NIterPerWarp / NXdlPack - 1))
+                                             (nIter_pack == NPackIterPerWarp - 1))
                                 {
                                     constexpr auto AmIter = addr % 2 + addr / 4 * 2;
                                     constexpr auto AkIter = addr / 2 % 2;
@@ -1151,7 +1059,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<...>
 
-        return c_block_tile;
+        return c_warp_tensors;
     }
diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
--- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
@@ ... @@
     static constexpr auto I0 = number<0>{};
     static constexpr auto I1 = number<1>{};
@@ -58,7 +58,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     template <typename Problem, typename TensorView>
     CK_TILE_DEVICE static constexpr auto
-    MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
+    MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
     {
         using ADataType = remove_cvref_t<typename Problem::ADataType>;
         using ALayout   = remove_cvref_t<typename Problem::ALayout>;
@@ -107,7 +107,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     template <typename Problem>
-    CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
+    CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
     {
         using ADataType = remove_cvref_t<typename Problem::ADataType>;
@@ -140,7 +140,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     }
 
     template <typename Problem>
-    CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
+    CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
     {
         using ADataType = remove_cvref_t<typename Problem::ADataType>;
         using ALayout   = remove_cvref_t<typename Problem::ALayout>;
@@ -218,7 +218,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
     {
         using TileShape = typename Problem::BlockGemmShape;
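The policy edits from here on are mechanical renames -- `MakeMXFP4_*`/`MakeMXF4_*`
become `MakeMX_*` -- matching the example's help text, which now advertises fp4xfp4
and fp8xfp8 behind the same pipeline. The overall shape of such a policy, a struct of
static constexpr builders keyed on a Problem type, looks roughly like this sketch
(`MiniPolicy`, `ProblemFp8`, and the size formula are invented for illustration):

    #include <cstdio>

    // Hypothetical Problem description, loosely shaped like the CK_Tile ones.
    struct ProblemFp8
    {
        using ADataType = unsigned char; // stand-in for an 8-bit MX element
        static constexpr int kMPerBlock = 128;
        static constexpr int kKPerBlock = 64;
        static constexpr int PackedSize = 1; // elements packed per stored byte
    };

    // A policy in this style is a bag of static constexpr builders templated on
    // the Problem; renaming MakeMXFP4_* to MakeMX_* widens the advertised scope
    // while the signatures stay the same.
    struct MiniPolicy
    {
        template <typename Problem>
        static constexpr int MakeMX_ALdsSize()
        {
            return Problem::kMPerBlock * Problem::kKPerBlock *
                   static_cast<int>(sizeof(typename Problem::ADataType)) /
                   Problem::PackedSize;
        }
    };

    int main()
    {
        // Resolved entirely at compile time, like GetSmemSize() in the policy.
        constexpr int bytes = MiniPolicy::MakeMX_ALdsSize<ProblemFp8>();
        static_assert(bytes == 128 * 64, "unexpected LDS size");
        std::printf("LDS bytes for A: %d\n", bytes);
        return 0;
    }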
@@ -255,7 +255,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
     {
         using TileShape = typename Problem::BlockGemmShape;
@@ -298,7 +298,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
     {
         using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -335,7 +335,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
     {
         using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -372,7 +372,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution()
     {
         using TileShape = typename Problem::BlockGemmShape;
@@ -394,7 +394,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution()
     {
         using TileShape = typename Problem::BlockGemmShape;
@@ -420,8 +420,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     {
         using ADataType = remove_cvref_t<typename Problem::ADataType>;
         constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
-        return sizeof(ADataType) *
-               MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
+        return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor<Problem>().get_element_space_size() /
+               APackedSize;
     }
 
     template