From 7816812ef8e2fe2ef1cf6d11729ac8b79e6f8c32 Mon Sep 17 00:00:00 2001 From: Christopher Millette <63608002+cgmillette@users.noreply.github.com> Date: Tue, 7 Apr 2026 08:36:45 -0600 Subject: [PATCH] [CK_TILE] Flatten nested static_for loops into static_ford (#5939) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Mechanical conversion of 129 nested `static_for`/`static_ford` patterns to flat `static_ford` across 29 ck_tile header files. Each conversion eliminates intermediate lambda closure instantiations by replacing nested compile-time loops with a single flat iteration using index decomposition. ### What `static_ford` eliminates When `static_for` loops are nested, each level creates unique closure types: ```cpp // BEFORE: M + M×N = 20 IR functions (for M=4, N=4) static_for<0, 4, 1>{}([&](auto m) { // 4 closure instantiations static_for<0, 4, 1>{}([&](auto n) { // 4×4 = 16 closure instantiations body(m, n); }); }); // AFTER: M×N = 16 IR functions (with ford_applier, no intermediates) static_ford<sequence<4, 4>>{}([&](auto mn) { constexpr auto m = number<mn[number<0>{}]>{}; constexpr auto n = number<mn[number<1>{}]>{}; body(m, n); }); ``` ### Pattern categories converted | Category | Count | Description | |----------|-------|-------------| | C (2-level `static_for` chains) | 112 | Nested `static_for` → `static_ford` | | C3 (3-level `static_for` chains) | 9 | Three consecutive nests → `static_ford` | | Partial rescue | 3 | Outer 2 levels of blocked 4-level nests | | B (nested `static_ford` merge) | 5 | Two nested `static_ford` → single higher-dim `static_ford` | | **Total** | **129** | Across 29 files | 6 false positives were detected and reverted (in `tensor_adaptor.hpp`, `tile_distribution.hpp`, `tile_distribution_encoding.hpp`) where the inner loop bound depended on the outer variable. 
### Files changed by family | Family | Files | Sites | |--------|-------|-------| | Block GEMM | 12 | ~20 | | FlatMM pipelines | 4 | ~69 (including 5 ford-ford merges) | | GEMM quant | 7 | ~22 | | FlatMM kernel | 1 | 2 | | FMHA | 1 | 2 | | Reduce/norm | 2 | 2 | | Epilogue | 1 | 1 | ### Blocked locations from review comments - **block_gemm_areg_breg_creg_v1.hpp:356** — BLOCKED: runtime scale loads (`scale_a_slice`, `scale_b_slice`, A warp tensor load) between every nesting level - **block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp:228** — BLOCKED: `zero_accumulators()` before inner loop; `sched_barrier` + conditional `block_sync_lds()` after inner loop - **block_universal_gemm_as_aquant_bs_bquant_cr.hpp:298** — BLOCKED: runtime `CWarpTensor` construction before inner loop; quantization scale application code after inner loop - **block_universal_gemm_as_aquant_bs_cr.hpp:277** — BLOCKED: same pattern as above - **block_universal_gemm_as_bs_bquant_cr.hpp:367** — BLOCKED: same pattern as above ## Depends on - #5938 ([CK_TILE] Optimize static_ford and sequence compile-time infrastructure) — provides the `ford_applier` that makes these conversions beneficial. Without it, `static_ford` uses a recursive implementation that provides no IR function savings. ## Results (combined with #5938) ### Build Time (Wilcoxon signed-rank, 7 paired trials, gfx942) | Target | Base (s) | Treat (s) | Delta | % | Significant? | |--------|----------|-----------|-------|---|-------------| | **flatmm** | 161.1 | 149.0 | **-12.1s** | **-7.5%** | **YES** (p<0.01, 7/7 wins) | | **universal_gemm** | 225.4 | 220.3 | **-5.1s** | **-2.3%** | **YES** (p<0.01, 7/7 wins) | ### IR Function Counts (device trace, gfx942) | Target | InstFunc | CodeGen | |--------|----------|---------| | universal_gemm | **-8.5%** | **-9.2%** | | flatmm | **-7.6%** | **-10.5%** | ### ASM Equivalence 5/5 PASS — 650,151 lines verified identical (gfx942). TUs: universal_gemm, flatmm_basic, fmha_bwd, reduce, bscale. 
## Test plan - [x] ASM equivalence verified (650K lines, gfx942) - [x] Wilcoxon timing verified (7 trials, p<0.01) - [x] IR function counts verified (-7.6% to -10.5% CodeGen reduction) - [ ] CI 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> --- .../ops/epilogue/chainer/epilogue_chainer.hpp | 10 +- .../block_flatmm_asmem_bsmem_creg_v1.hpp | 38 +- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 41 +- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 410 +++--- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1175 ++++++++--------- .../moe_flatmm_pipeline_agmem_bgmem_creg.hpp | 495 ++++--- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 161 ++- ...block_fmha_bwd_pipeline_default_policy.hpp | 48 +- .../block/block_gemm_areg_breg_creg_v1.hpp | 166 ++- .../block/block_gemm_areg_breg_creg_v2.hpp | 104 +- ...block_gemm_areg_bsmem_creg_one_warp_v1.hpp | 58 +- .../block/block_gemm_areg_bsmem_creg_v1.hpp | 58 +- .../block/block_gemm_areg_bsmem_creg_v2.hpp | 60 +- .../block/block_gemm_areg_bsmem_creg_v2r1.hpp | 68 +- .../block/block_gemm_asmem_breg_creg_v1.hpp | 58 +- .../block/block_gemm_asmem_bsmem_creg_v1.hpp | 64 +- .../block_gemm_mx_areg_bsmem_creg_v1.hpp | 89 +- .../block/block_universal_gemm_as_bs_cr.hpp | 150 +-- .../gemm/block/block_wp_asmem_breg_creg.hpp | 85 +- .../block/block_wp_asmem_bsmem_creg_v1.hpp | 36 +- ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 68 +- ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 174 ++- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 204 ++- .../block_universal_gemm_as_aquant_bs_cr.hpp | 77 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 276 ++-- .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 88 +- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 88 +- .../norm_reduce/block/block_norm_reduce.hpp | 10 +- .../ops/reduce/block/block_reduce2d.hpp | 20 +- 29 files changed, 2160 insertions(+), 2219 deletions(-) 
diff --git a/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp b/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp index 25ef000cc3..f22919d922 100644 --- a/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp +++ b/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp @@ -187,11 +187,11 @@ struct EpilogueGraph Context& context) const { // For each iteration, process all epilogues in order - static_for<0, Steps, 1>{}([&](auto iAccess) { - static_for<0, sizeof...(EpilogueTypes), 1>{}([&](auto I) { - epilogues.template get()( - out_window, acc_tile, aux_windows, p_smem, context, iAccess); - }); + static_ford>{}([&](auto iI) { + constexpr auto iAccess = number{}]>{}; + constexpr auto I = number{}]>{}; + epilogues.template get()( + out_window, acc_tile, aux_windows, p_smem, context, iAccess); }); } }; diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp index 2b8e9e4b1a..de73e4f1ff 100644 --- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -92,29 +92,29 @@ struct BlockFlatmmASmemBSmemCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor 
c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); }); }); } diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 13d5e65155..81cf76cb07 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1105,15 +1105,14 @@ struct MoeFlatmmKernel statically_indexed_array scale_m_offsets; if constexpr(!BMXFP4_Pipeline) - static_for<0, MRepeat, 1>{}([&](auto mIter) { - static_for<0, kM0, 1>{}([&](auto m0) { - static_for<0, kM2, 1>{}([&](auto m2) { - const auto row_idx = - coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0]; - scale_m_offsets[mIter * number{} + m0 * number{} + m2] = - row_to_token_idx(row_idx); - }); - }); + static_ford>{}([&](auto mmm) { + constexpr auto mIter = number{}]>{}; + constexpr auto m0 = number{}]>{}; + constexpr auto m2 = number{}]>{}; + 
const auto row_idx = + coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0]; + scale_m_offsets[mIter * number{} + m0 * number{} + m2] = + row_to_token_idx(row_idx); }); constexpr int DynamicTileOffsetFlag = 0; @@ -1426,19 +1425,19 @@ struct MoeFlatmmKernel statically_indexed_array, NumMEpiTile> c_scatter_valids; auto c_coord = dram_tile_distribution.calculate_index(); - static_for<0, NumMEpiTile, 1>{}([&](auto mIter) { - static_for<0, MPerThread, 1>{}([&](auto m0) { - auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0; - auto fused_token = - kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0] + static_ford>{}([&](auto mm) { + constexpr auto mIter = number{}]>{}; + constexpr auto m0 = number{}]>{}; + auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0; + auto fused_token = + kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0] - index_t scatter_token_id = fused_token & token_id_mask; - c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); - if constexpr(IsInputGemm) - scatter_token_id = - scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); - c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; - }); + index_t scatter_token_id = fused_token & token_id_mask; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); + if constexpr(IsInputGemm) + scatter_token_id = + scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); + c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; }); //===----------------------------------------------------------------------===// diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index ee8527c458..8f40c9be7a 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ 
b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -606,16 +606,16 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 MIterPerWarp> a_warp_windows_pong; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - move_tile_window(a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - move_tile_window(a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); // Block GEMM @@ -656,15 +656,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_ping(nIter)(kIter) 
= load_tile(b_flat_dram_windows(nIter)(kIter)); }); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -701,15 +701,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 while(iCounter > 0) { // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(2i+1) @@ -722,44 +722,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 
1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); // move B window to next flat K @@ -776,15 +776,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 // Next K // prefetch B(2i+2) - static_for<0, 
KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(2i+2) @@ -797,43 +797,43 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); 
+ // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); // move B window to next flat K @@ -854,15 +854,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 if constexpr(TailNum == TailNumber::Even) { // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + 
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(loopK) @@ -870,44 +870,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 store_tile(a_copy_lds_window_pong, a_block_tile_tmp); // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - 
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); static_for<0, m_preload, 1>{}([&](auto loadIter) { @@ -920,86 +920,86 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 Last2ndHotLoopScheduler(); // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read 
C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); 
LastHotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp 
tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); LastHotLoopScheduler(); } diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 11b978813a..0f7f742fa0 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -537,22 +537,22 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 a_warp_windows_pong; auto A_Lds_Stride = 8; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - auto weight_k_idx = kIter / number{}; - auto weight_k_rank = kIter % number{}; - move_tile_window( - a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, - weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); - move_tile_window( - a_warp_windows_pong(mIter)(kIter), - 
{mIter * MPerBlockPerIter, - weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); - }); + auto weight_k_idx = kIter / number{}; + auto weight_k_rank = kIter % number{}; + move_tile_window( + a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, + weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); + move_tile_window( + a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, + weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); }); // Block GEMM @@ -657,33 +657,32 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) { - if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) - { - auto scale_n_iter = nIter / number{}; - auto scale_k_iter = kIter / number{}; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) + { + auto scale_n_iter = nIter / number{}; + auto scale_k_iter = kIter / number{}; - scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = - scale_b_flat_dram_window; - move_tile_window( - scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), - {scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp}); - scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) = - load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); - } - auto packed_n_idx = nIter / number{}; - auto packed_n_rank = nIter % number{}; + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = scale_b_flat_dram_window; + move_tile_window( + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), + {scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp}); + scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) = + 
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); + } + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + - packed_n_rank, - kIter * KFlatPerBlockPerIter}); + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + packed_n_rank, + kIter * KFlatPerBlockPerIter}); - ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); - b_warp_tensor_ping(nIter)(kIter) = ub.u; - }); + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_ping(nIter)(kIter) = ub.u; }); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter}); @@ -794,38 +793,37 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 while(iCounter > 0) { // prefetch B(2i+1) - static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) - { - auto scale_n_iter = nIter / number{}; - auto scale_k_iter = kIter / number{}; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) + { + auto scale_n_iter = nIter / number{}; + auto scale_k_iter = kIter / number{}; - scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = - scale_b_flat_dram_window; - - move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), - {scale_n_iter * NFlatPerBlockPerIter, - scale_k_iter * ScaleKFlatPerWarp}); - - scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) = - load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); - } - - auto packed_n_idx = nIter / 
number{}; - auto packed_n_rank = nIter % number{}; - - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = + scale_b_flat_dram_window; move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + - packed_n_rank, - kIter * KFlatPerBlockPerIter}); + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), + {scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp}); - ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); - b_warp_tensor_pong(nIter)(kIter) = ub.u; - }); + scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) = + load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); + } + + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; + + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + + packed_n_rank, + kIter * KFlatPerBlockPerIter}); + + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_pong(nIter)(kIter) = ub.u; }); // Prefill A(2i+1) @@ -835,51 +833,50 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 prefill_lds_a_stage1( a_copy_lds_window_ping, a_copy_dram_window, number{}); // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = 
c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - if constexpr(mIter == 0) - dequant_mxfp4( - b_warp_tensor_ping(nIter)(kIter / number{}), - scale_b_warp_tensor_ping(nIter / number{})( - kIter / number{}), - nIter, - kIter); + if constexpr(mIter == 0) + dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number{}), + scale_b_warp_tensor_ping(nIter / number{})( + kIter / number{}), + nIter, + kIter); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - __builtin_amdgcn_s_waitcnt(Bload_total_num); - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + 
m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); prefill_lds_a_stage1( a_copy_lds_window_ping, a_copy_dram_window, number{}); @@ -902,37 +899,36 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // Next K // prefetch B(2i+2) - static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) - { - auto scale_n_iter = nIter / number{}; - auto scale_k_iter = kIter / number{}; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) + { + auto scale_n_iter = nIter / number{}; + auto scale_k_iter = kIter / number{}; - scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = - scale_b_flat_dram_window; + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = + scale_b_flat_dram_window; - move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), - {scale_n_iter * NFlatPerBlockPerIter, - scale_k_iter * ScaleKFlatPerWarp}); - - scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) = - load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); - } - - auto packed_n_idx = nIter / number{}; - auto packed_n_rank = nIter % number{}; - - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + - packed_n_rank, - kIter * KFlatPerBlockPerIter}); + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), + {scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp}); - ub.mxfp4 = 
load_tile(b_flat_dram_windows(nIter)(kIter)); - b_warp_tensor_ping(nIter)(kIter) = ub.u; - }); + scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) = + load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); + } + + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; + + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + + packed_n_rank, + kIter * KFlatPerBlockPerIter}); + + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_ping(nIter)(kIter) = ub.u; }); // Prefill A(2i+2) @@ -943,50 +939,49 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 a_copy_lds_window_pong, a_copy_dram_window, number{}); // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - if constexpr(mIter == 0) - dequant_mxfp4( - b_warp_tensor_pong(nIter)(kIter / number{}), - scale_b_warp_tensor_pong(nIter / number{})( - kIter / number{}), - nIter, - kIter); + if constexpr(mIter == 0) + 
dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number{}), + scale_b_warp_tensor_pong(nIter / number{})( + kIter / number{}), + nIter, + kIter); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - __builtin_amdgcn_s_waitcnt(Bload_total_num); - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); prefill_lds_a_stage1( a_copy_lds_window_pong, a_copy_dram_window, number{}); @@ -1058,51 +1053,50 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 prefill_lds_a_stage2(a_copy_lds_window_pong); 
// GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - if constexpr(mIter == 0) - dequant_mxfp4( - b_warp_tensor_ping(nIter)(kIter / number{}), - scale_b_warp_tensor_ping(nIter / number{})( - kIter / number{}), - nIter, - kIter); + if constexpr(mIter == 0) + dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number{}), + scale_b_warp_tensor_ping(nIter / number{})( + kIter / number{}), + nIter, + kIter); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) 
/ MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - __builtin_amdgcn_s_waitcnt(Bload_total_num); - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); static_for<0, m_preload, 1>{}([&](auto loadIter) { @@ -1170,50 +1164,49 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 else if constexpr(TailNum == TailNumber::Odd) { // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + 
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - if constexpr(mIter == 0) - dequant_mxfp4( - b_warp_tensor_ping(nIter)(kIter / number{}), - scale_b_warp_tensor_ping(nIter / number{})( - kIter / number{}), - nIter, - kIter); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); + if constexpr(mIter == 0) + dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number{}), + scale_b_warp_tensor_ping(nIter / number{})( + kIter / number{}), + nIter, + kIter); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - __builtin_amdgcn_s_waitcnt(Bload_total_num); - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + 
load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); LastHotLoopScheduler(); } @@ -1904,29 +1897,29 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); // prefetch Scale A - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + static_ford>{}([&](auto mk) { + constexpr auto mIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); - }); + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); }); // move Scale A window to next K move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); // prefetch Scale B - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + static_ford>{}([&](auto nk) { + constexpr auto nIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / 
WG::kN)}); - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); - }); + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); }); // move Scale B window to next K move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); @@ -1957,95 +1950,90 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 // MAIN LOOP auto main_body_implx2 = [&]() mutable { // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); - if constexpr(kIter == KIterPerWarp - 1) - move_tile_window(b_flat_dram_windows(nIter), - {0, BlockGemmShape::flatKPerBlock}); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), number{}); + if constexpr(kIter == KIterPerWarp - 1) + move_tile_window(b_flat_dram_windows(nIter), + {0, BlockGemmShape::flatKPerBlock}); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + static_ford>{}([&](auto mk) { + constexpr auto mIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); - }); + 
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + static_ford>{}([&](auto nk) { + constexpr auto nIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); - }); + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); }); // GEMM 2i - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; - constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + 
ikxdl; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}.template - // operator()( - operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); + // warp GEMM + WG{}.template + // operator()( + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack).get_thread_buffer()[0], + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack).get_thread_buffer()[0]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - constexpr auto addr = - m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - { - constexpr auto AmIter = addr % 2 + addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_ping, - tuple, number>{}); - } - }); - }); - }); + // 
write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, number>{}); + } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished s_waitcnt< // vmcnt @@ -2072,96 +2060,94 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 ////////////////////////////// Next K ////////////////////////////// // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); - if constexpr(kIter == KIterPerWarp - 1) - move_tile_window(b_flat_dram_windows(nIter), - {0, BlockGemmShape::flatKPerBlock}); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), number{}); + if constexpr(kIter == KIterPerWarp - 1) + move_tile_window(b_flat_dram_windows(nIter), + {0, BlockGemmShape::flatKPerBlock}); }); // prefetch Scale A and Scale B (2i+2) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + static_ford>{}([&](auto mk) { + constexpr auto mIter_pack 
= number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); - }); + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + static_ford>{}([&](auto nk) { + constexpr auto nIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); - }); + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); }); // GEMM 2i+1 - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - 
c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}.template - // operator()( - operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], // scale A - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); // scale B + // warp GEMM + WG{}.template + // operator()( + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - { - constexpr auto AmIter = addr % 2 
+ addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_pong, - tuple, number>{}); - } - }); - }); - }); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_pong, + tuple, number>{}); + } }); // barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished s_waitcnt< // vmcnt @@ -2199,92 +2185,89 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 if constexpr(TailNum == TailNumber::Even) { // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), - make_tuple(number<0>{}, number{})); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), + make_tuple(number<0>{}, number{})); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + 
static_ford>{}([&](auto mk) { + constexpr auto mIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); - }); + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + static_ford>{}([&](auto nk) { + constexpr auto nIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); - }); + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); }); // GEMM loopK-1 - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - 
c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}.template - operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], // scale A - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); // scale B + // warp GEMM + WG{}.template operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - 
{ - constexpr auto AmIter = addr % 2 + addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_ping, - tuple, number>{}); - } - }); - }); - }); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, number>{}); + } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished s_waitcnt< // vmcnt @@ -2302,123 +2285,115 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 // Last2ndHotLoopScheduler(); // GEMM loopK - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = 
number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}.template - operator()( - // operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], // scale A - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); // scale B + // warp GEMM + WG{}.template operator()( + // operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - { - constexpr auto AmIter = addr % 2 + addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_pong, - tuple, number>{}); - } - }); - }); - }); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + 
merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_pong, + tuple, number>{}); + } }); // LastHotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { // GEMM loopK - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; - constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - 
merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}.template - // operator()( - operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); + // warp GEMM + WG{}.template + // operator()( + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack).get_thread_buffer()[0], + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack).get_thread_buffer()[0]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - constexpr auto addr = - m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - { - constexpr auto AmIter = addr % 2 + addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_ping, - tuple, number>{}); - } - }); - }); - }); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + constexpr 
auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, number>{}); + } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished s_waitcnt< // vmcnt diff --git a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp index 5681726afe..543f4dc92a 100644 --- a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp @@ -529,22 +529,22 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 MIterPerWarp> a_warp_windows_pong; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - move_tile_window(a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - move_tile_window(a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + 
move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); // Block GEMM @@ -592,26 +592,26 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 2; // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - if constexpr(!IsGateUpMode) - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + if constexpr(!IsGateUpMode) + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + else + { + if constexpr(nIter % 2 == 0) + move_tile_window( + b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); else - { - if constexpr(nIter % 2 == 0) - move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - else - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, - kIter * KFlatPerBlockPerIter}); - } - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, + kIter * KFlatPerBlockPerIter}); + } + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -648,28 +648,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 while(iCounter > 0) { // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - 
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - if constexpr(!IsGateUpMode) + if constexpr(!IsGateUpMode) + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + else + { + if constexpr(nIter % 2 == 0) move_tile_window( b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); else - { - if constexpr(nIter % 2 == 0) - move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - else - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, - kIter * KFlatPerBlockPerIter}); - } + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, + kIter * KFlatPerBlockPerIter}); + } - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(2i+1) @@ -682,44 +681,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C 
block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + 
block_sync_lds(); + } }); // move B window to next flat K @@ -736,28 +735,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 // Next K // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - if constexpr(!IsGateUpMode) + if constexpr(!IsGateUpMode) + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + else + { + if constexpr(nIter % 2 == 0) move_tile_window( b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); else - { - if constexpr(nIter % 2 == 0) - move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - else - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, - kIter * KFlatPerBlockPerIter}); - } + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, + kIter * KFlatPerBlockPerIter}); + } - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(2i+2) @@ -770,43 +768,43 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - 
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % 
MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); // move B window to next flat K @@ -827,28 +825,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 if constexpr(TailNum == TailNumber::Even) { // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - if constexpr(!IsGateUpMode) + if constexpr(!IsGateUpMode) + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + else + { + if constexpr(nIter % 2 == 0) move_tile_window( b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); else - { - if constexpr(nIter % 2 == 0) - move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - else - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, - kIter * KFlatPerBlockPerIter}); - } + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, + kIter * KFlatPerBlockPerIter}); + } - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(loopK) @@ -856,44 +853,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 store_tile(a_copy_lds_window_pong, a_block_tile_tmp); // GEMM loopK-1 - 
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + 
merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); static_for<0, m_preload, 1>{}([&](auto loadIter) { @@ -906,86 +903,86 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 Last2ndHotLoopScheduler(); // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter)(kIter)); - // write C warp tensor into C block 
tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); LastHotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor 
c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); 
LastHotLoopScheduler(); } diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 23d7a9fca9..f698541dbf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -486,13 +486,13 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}; auto c_block_tile = BlockFlatmm{}.MakeCBlockTile(); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensors(mIter)(nIter).get_thread_buffer()); - }); + static_ford>{}([&](auto mn) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensors(mIter)(nIter).get_thread_buffer()); }); return c_block_tile; } @@ -643,24 +643,23 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto impack) { - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset( - scale_a_dram_window, + static_ford>{}([&](auto ii) { + constexpr auto impack = number{}]>{}; + constexpr auto ikpack = number{}]>{}; + scale_a_tile_tensor_ping(impack)(ikpack) = + load_tile_with_offset(scale_a_dram_window, - impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); - }); + impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); // move Scale A window to next K move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); // prefetch Scale B - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - 
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset( - scale_b_dram_window, - inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto inpack = number{}]>{}; + constexpr auto ikpack = number{}]>{}; + scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset( + scale_b_dram_window, inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); // move Scale B window to next K move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); @@ -698,34 +697,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); - // move B window to next flat K - if constexpr(kIter == KIterPerWarp - 1) - b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); - }); + // move B window to next flat K + if constexpr(kIter == KIterPerWarp - 1) + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) { - scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset( - scale_a_dram_window, - impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto impack = number{}]>{}; + scale_a_tile_tensor_pong(impack)(ikpack) = 
load_tile_with_offset( + scale_a_dram_window, + impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset( - scale_b_dram_window, - inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto inpack = number{}]>{}; + scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset( + scale_b_dram_window, + inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); // GEMM 2i @@ -788,34 +787,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); - // move B window to next flat K - if constexpr(kIter == KIterPerWarp - 1) - b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); - }); + // move B window to next flat K + if constexpr(kIter == KIterPerWarp - 1) + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); // prefetch Scale A and Scale B (2i+2) - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) { - scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset( - scale_a_dram_window, - impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto 
impack = number{}]>{}; + scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset( + scale_a_dram_window, + impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset( - scale_b_dram_window, - inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto inpack = number{}]>{}; + scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset( + scale_b_dram_window, + inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); // GEMM 2i+1 @@ -888,28 +887,28 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) { - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset( - scale_a_dram_window, - impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto impack = number{}]>{}; + constexpr auto ikpack = number{}]>{}; + scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset( + scale_a_dram_window, + impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - static_for<0, 
KPackIterPerWarp, 1>{}([&](auto ikpack) { - scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset( - scale_b_dram_window, - inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto inpack = number{}]>{}; + constexpr auto ikpack = number{}]>{}; + scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset( + scale_b_dram_window, + inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); // GEMM loopK-1 diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index e67a525ac4..bb3fa8c411 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1706,22 +1706,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); #if defined(__gfx11__) - PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor); + PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor); #else - pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer(); + pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer(); #endif - pt_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, 
a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), - pt_warp_tensor.get_thread_buffer()); - }); + pt_out.set_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), + pt_warp_tensor.get_thread_buffer()); }); } else @@ -1763,22 +1763,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); #if defined(__gfx11__) - PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor); + PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor); #else - dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer(); + dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer(); #endif - dst_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), - dst_warp_tensor.get_thread_buffer()); - }); + dst_out.set_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), + dst_warp_tensor.get_thread_buffer()); }); } else diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 108afd9b1c..0ac8efbc8d 100644 --- 
a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -213,38 +213,38 @@ struct BlockGemmARegBRegCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor - using c_iter_idx = std:: - conditional_t, sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + 
using c_iter_idx = + std::conditional_t, sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } @@ -323,73 +323,69 @@ struct BlockGemmARegBRegCRegV1 // hot loop with MX scaling and pre-packed int32_t scales: // Outer loops iterate over pack groups (scale tile indices) - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) { - // Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t) - auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto impack = number{}]>{}; + // Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t) + auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - // Get pre-packed int32_t B scale - auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 
1>{}); - const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); + static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { + // Get pre-packed int32_t B scale + auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); - // Inner loops: issue MFMAs within the pack group using OpSel - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto kIter = ikpack * KXdlPack + ikxdl; - constexpr auto mIter = impack * MXdlPack + imxdl; + // Inner loops: issue MFMAs within the pack group using OpSel + static_ford>{}([&](auto jj) { + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto kIter = ikpack * KXdlPack + ikxdl; + constexpr auto mIter = impack * MXdlPack + imxdl; - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = - a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - // OpSel for A: selects byte within packed int32_t - constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; + // OpSel for A: selects byte within packed int32_t + constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto nIter = inpack * NXdlPack + inxdl; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto nIter = inpack * NXdlPack + inxdl; - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_block_tensor.get_y_sliced_thread_data( - 
merge_sequences(sequence{}, - b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // OpSel for B: selects byte within packed int32_t - constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; + // OpSel for B: selects byte within packed int32_t + constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; - // read C warp tensor from C block tensor - using c_iter_idx = std::conditional_t, - sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + using c_iter_idx = std::conditional_t, + sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM with MX scaling using pre-packed scale and OpSel - WarpGemm{}.template operator()(c_warp_tensor, - a_warp_tensor, - b_warp_tensor, - a_scale_packed, - b_scale_packed); + // warp GEMM with MX scaling using pre-packed scale and OpSel + WarpGemm{}.template operator()(c_warp_tensor, + a_warp_tensor, + b_warp_tensor, + a_scale_packed, + b_scale_packed); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + 
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); }); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp index 960a685792..a559206b98 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp @@ -250,74 +250,74 @@ struct BlockGemmARegBRegCRegV2 // hot loop: if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = 
c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; + static_ford>{}([&](auto mnk) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - 
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); } } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp index 3302d149ca..a7f1cef519 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp @@ -109,13 +109,13 @@ struct 
BlockGemmARegBSmemCRegOneWarpV1 NIterPerWarp> b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -141,35 +141,35 @@ struct BlockGemmARegBSmemCRegOneWarpV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - 
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index 14d59ff373..0118258668 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -116,13 +116,13 @@ struct BlockGemmARegBSmemCRegV1 NIterPerWarp> b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -148,35 +148,35 @@ struct 
BlockGemmARegBSmemCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - 
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 0aa7509b1e..d292cade24 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -103,13 +103,13 @@ struct BlockGemmARegBSmemCRegV2 NIterPerWarp> b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -135,36 +135,36 @@ struct BlockGemmARegBSmemCRegV2 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + 
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp index 2ba01d91c5..9ffc9f2070 100644 --- 
a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp @@ -90,13 +90,13 @@ struct BlockGemmARegBSmemCRegV2R1 NIterPerWarp> b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); // check C-block-distribution @@ -126,43 +126,43 @@ struct BlockGemmARegBSmemCRegV2R1 NIterPerWarp> b_warp_tensors; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); }); // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = b_warp_tensors(nIter)(kIter); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + // read B warp tensor from B Block window + const auto b_warp_tensor = b_warp_tensors(nIter)(kIter); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = 
a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp index b1223f8755..2b750c75b3 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -116,13 +116,13 @@ struct BlockGemmASmemBRegCRegV1 
MIterPerWarp> a_warp_windows; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -148,34 +148,34 @@ struct BlockGemmASmemBRegCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A Block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = 
c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index 6eedfabaf8..32776b786d 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -85,13 +85,13 @@ struct BlockGemmASmemBSmemCRegV1 MIterPerWarp> a_warp_windows; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -120,13 +120,13 @@ struct BlockGemmASmemBSmemCRegV1 NIterPerWarp> 
b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -138,31 +138,31 @@ struct BlockGemmASmemBSmemCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - 
// warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp index 5dde03912a..9ad8c4cc97 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp @@ -165,61 +165,60 @@ struct BlockGemmMxARegBSmemCRegV1 uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - auto b_warp_window = b_warp_window_tmp; - move_tile_window( - b_warp_window, - {nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)}); - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_window); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + auto b_warp_window = b_warp_window_tmp; + move_tile_window( + b_warp_window, + {nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)}); + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_window); - BScaleWarpTensor b_scale_warp_tensor; + BScaleWarpTensor b_scale_warp_tensor; - b_scale_warp_tensor.get_thread_buffer() = - b_scale_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, - 
b_scale_warp_y_index_zeros), - merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths)); + b_scale_warp_tensor.get_thread_buffer() = b_scale_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + b_scale_warp_y_index_zeros), + merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths)); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - AScaleWarpTensor a_scale_warp_tensor; + AScaleWarpTensor a_scale_warp_tensor; - a_scale_warp_tensor.get_thread_buffer() = - a_scale_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_scale_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths)); + a_scale_warp_tensor.get_thread_buffer() = + a_scale_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_scale_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths)); - // warp GEMM - 
WarpGemm{}.template operator()<0, 0>( - c_warp_tensor, - a_warp_tensor, - b_warp_tensor, - int32_t(a_scale_warp_tensor.get_thread_buffer()[0]), - int32_t(b_scale_warp_tensor.get_thread_buffer()[0])); + // warp GEMM + WarpGemm{}.template operator()<0, 0>( + c_warp_tensor, + a_warp_tensor, + b_warp_tensor, + int32_t(a_scale_warp_tensor.get_thread_buffer()[0]), + int32_t(b_scale_warp_tensor.get_thread_buffer()[0])); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index f7f5cd33db..2b64f6e340 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -239,39 +239,39 @@ struct BlockUniversalGemmAsBsCr "C block tensor data type!"); // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, 
a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } @@ -392,63 +392,59 @@ struct BlockUniversalGemmAsBsCr 0); // Prevents instruction reordering across this boundary } - static_for<0, 
KInnerLoopIter, 1>{}([&](auto kInnerIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kInnerIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, - b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor- - CWarpTensor c_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + // read C warp tensor from C block tensor- + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from - 
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion - // by applying small delays to different wavefronts It is - // performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(kIter.value == KRepeat - 1 && - kInnerIter.value == KInnerLoopIter - 1 && - mIter.value == MIterPerWarp - 1 && - nIter.value == NIterPerWarp - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(kIter.value == KRepeat - 1 && + kInnerIter.value == KInnerLoopIter - 1 && + mIter.value == MIterPerWarp - 1 && + nIter.value == NIterPerWarp - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); - if constexpr(kInnerIter.value == 0 && mIter.value == 0 && - nIter.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } - }); + if constexpr(kInnerIter.value == 0 && mIter.value == 0 && nIter.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + 
__builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp index 4fc180b42b..45602f3064 100644 --- a/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp +++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp @@ -156,55 +156,54 @@ struct BlockWeightPreshuffleASmemBRegCReg uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - BWarpTensor b_warp_tensor; - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + BWarpTensor b_warp_tensor; + CWarpTensor c_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, - typename sequence_split::right_type{}), - merge_sequences( - sequence<1, 1>{}, - typename sequence_split::right_type{})); + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + typename sequence_split::right_type{}), + merge_sequences( + sequence<1, 1>{}, + typename sequence_split::right_type{})); - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + 
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}( - c_warp_tensor, preloaded_a_warp_tensor(number{}), b_warp_tensor); + // warp GEMM + WarpGemm{}( + c_warp_tensor, preloaded_a_warp_tensor(number{}), b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - - load_tile(preloaded_a_warp_tensor(number{}), - a_load_windows[number{}][number{}]); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + __builtin_amdgcn_sched_barrier(0x7F6); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + + load_tile(preloaded_a_warp_tensor(number{}), + a_load_windows[number{}][number{}]); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); } }; diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp index 49c26fab6c..08a7e7a3ea 100644 --- a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp 
+++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp @@ -88,28 +88,28 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git 
a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index a068001482..94fabe6f65 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -210,45 +210,45 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase c_acc; auto zero_accumulators = [&] { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) { - c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; - }); // make sure WG::CWarpTensor exposes a clear/zero + static_ford>{}( + [&](auto mni) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + constexpr auto i = number{}]>{}; + c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; }); - }); }; static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { zero_accumulators(); - static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // warp GEMM - WG{}(c_acc(mIter)(nIter), - a_warp_tensor(number{}), - b_warp_tensor(nIter)(number{})); - }); - __builtin_amdgcn_sched_barrier(0x7F6); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - - load_and_convert_tile( - a_warp_tensor(number{}), - a_warp_windows(number{})(number{})); - } - // barrier - // Could be deleted - if constexpr((mIter 
== MIter_2nd_last)) - { - block_sync_lds(); - } + static_ford>{}([&](auto km) { + constexpr auto kIterInQScale = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // warp GEMM + WG{}(c_acc(mIter)(nIter), + a_warp_tensor(number{}), + b_warp_tensor(nIter)(number{})); }); + __builtin_amdgcn_sched_barrier(0x7F6); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + + load_and_convert_tile( + a_warp_tensor(number{}), + a_warp_windows(number{})(number{})); + } + // barrier + // Could be deleted + if constexpr((mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { AQPickerCommon aq_picker(aq_block_tensor); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index d2cfaca7b7..1ee3b227b7 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -127,105 +127,103 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg c_acc; auto zero_accumulators = [&] { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) { - c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; - }); // make sure WG::CWarpTensor exposes a clear/zero + static_ford>{}( + [&](auto mni) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + constexpr auto i = 
number{}]>{}; + c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; }); - }); }; static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { zero_accumulators(); - static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // warp GEMM - WG{}(c_acc(mIter)(nIter), - a_warp_tensor(number{}), - b_warp_tensor(nIter)(number{})); - }); - __builtin_amdgcn_sched_barrier(0x7F6); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows(number{})(number{})); - } - // barrier - // Could be deleted - if constexpr((mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_ford>{}([&](auto km) { + constexpr auto kIterInQScale = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; + // warp GEMM + WG{}(c_acc(mIter)(nIter), + a_warp_tensor(number{}), + b_warp_tensor(nIter)(number{})); + }); + __builtin_amdgcn_sched_barrier(0x7F6); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + 
load_tile(a_warp_windows(number{})(number{})); + } + // barrier + // Could be deleted + if constexpr((mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + static_ford>{}([&](auto mn) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + constexpr auto tbuf_offset = + number{}, c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; - if constexpr(BPreshuffleQuant) + if constexpr(BPreshuffleQuant) + { + constexpr index_t reg_offset = nIter; + auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) { - constexpr index_t reg_offset = nIter; - auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale; - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - // cross lane ops - uint32_t scale_reg_dword; - - if constexpr(std::is_same_v) - { - scale_reg_dword = ck_tile::bit_cast(scale_reg); - } - else - { - scale_reg_dword = static_cast(scale_reg); - } - - // cross lane ops to get the value of scale_reg. 
- int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( - pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); - - float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg); - - static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { - auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; - const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref = c_ref + acc_val * scale_reg_f; - }); + scale_reg_dword = ck_tile::bit_cast(scale_reg); } else { - index_t reg_offset = [&]() { - if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN)) - { - return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * - KPerBlockBQ + - kQScale; - } - else - { - return nIter * KPerBlockBQ + kQScale; - } - }(); - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float scale_reg_f = cvt_scale_to_fp32(scale_reg); - - static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { - auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; - const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref = c_ref + acc_val * scale_reg_f; - }); + scale_reg_dword = static_cast(scale_reg); } - }); + + // cross lane ops to get the value of scale_reg. 
+ int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * scale_reg_f; + }); + } + else + { + index_t reg_offset = [&]() { + if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float scale_reg_f = cvt_scale_to_fp32(scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * scale_reg_f; + }); + } }); }); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 24d9f9a1e5..cc65d213f1 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -290,121 +290,115 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase constexpr auto warp_size = get_warp_size(); // hot loop: - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto mn) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + CWarpTensor c_warp_tensor; - static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { - 
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = - a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - if constexpr(kIterInQScale == 0) - { - c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); - } - else - { - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - } - }); - - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - // a_scale - AQPickerCommon aq_picker( - aq_block_tensor); - - if constexpr(BPreshuffleQuant) + if constexpr(kIterInQScale == 0) { - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::BQuantGroupSize::kN > - (NWarp * WarpGemm::kN) && - Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN) - { - return kQScale; - } - else - { - return nIter; - } - }(); - - auto pull_from_lane = - (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; - - auto& scale_reg = 
bq_block_tensor.get_thread_buffer()[reg_offset]; - // cross lane ops - uint32_t scale_reg_dword; - - if constexpr(std::is_same_v) - { - scale_reg_dword = ck_tile::bit_cast(scale_reg); - } - else - { - scale_reg_dword = static_cast(scale_reg); - } - - // cross lane ops to get the value of scale_reg. - int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( - pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); - - float b_scale_reg_f = - Base::cvt_scale_to_fp32( - gathered_scale_reg); - - static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - float a_scale_reg_f = aq_picker.template pick(); - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * - b_scale_reg_f); - }); + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); } else { - // Multiply bquant with accumulated C - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::BQuantGroupSize::kN >= - (NWarp * WarpGemm::kN)) - return (nIter * NWarp * WarpGemm::kN) / - GemmTraits::BQuantGroupSize::kN * - Traits::KQPerBlock + - kQScale; - else - { - return nIter * Traits::KQPerBlock + kQScale; - } - }(); - - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float b_scale_reg_f = - Base::cvt_scale_to_fp32(scale_reg); - - static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - float a_scale_reg_f = aq_picker.template pick(); - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * - b_scale_reg_f); - }); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); } }); + + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + // a_scale + AQPickerCommon aq_picker( + aq_block_tensor); + + if constexpr(BPreshuffleQuant) + { + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN > (NWarp * WarpGemm::kN) && + 
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN) + { + return kQScale; + } + else + { + return nIter; + } + }(); + + auto pull_from_lane = + (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + // cross lane ops to get the value of scale_reg. + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float b_scale_reg_f = Base::cvt_scale_to_fp32( + gathered_scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * + b_scale_reg_f); + }); + } + else + { + // Multiply bquant with accumulated C + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN)) + return (nIter * NWarp * WarpGemm::kN) / + GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_reg_f = + Base::cvt_scale_to_fp32(scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * + b_scale_reg_f); + }); + } }); }); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 8b09530af1..64f8bc7df4 
100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -268,54 +268,51 @@ struct AQuantBlockUniversalGemmAsBsCr constexpr auto warp_size = get_warp_size(); // hot loop: - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto mn) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + CWarpTensor c_warp_tensor; - // for every column in AQ - static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { - // for every warp corresponding to a quantization scale - static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + // for every column in AQ + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // for every warp corresponding to a quantization scale + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = - a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + 
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - if constexpr(kIterInQScale == 0) - { - c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); - } - else - { - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - } - }); + if constexpr(kIterInQScale == 0) + { + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + }); - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; - AQPickerCommon aq_picker( - aq_block_tensor); + AQPickerCommon aq_picker( + aq_block_tensor); - static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - float scale_reg_f = aq_picker.template pick(); + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}([&](auto c_row) { + float scale_reg_f = aq_picker.template pick(); - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); - }); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); }); }); }); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index f5900fcdec..9851fc917d 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -290,57 +290,55 @@ struct BQuantBlockUniversalGemmAsBsCr using SrcVectorRawType = ext_vector_t; using DstVectorType = ext_vector_t; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { - // B scale register offset - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * 
WarpGemm::kN)) - return ((nIter * NWarp * WarpGemm::kN) / - GemmTraits::BQuantGroupSize::kN) * - Traits::KQPerBlock + - kQScale; - else - { - return nIter * Traits::KQPerBlock + kQScale; - } - }(); + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kQScale = number{}]>{}; + // B scale register offset + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN)) + return ((nIter * NWarp * WarpGemm::kN) / GemmTraits::BQuantGroupSize::kN) * + Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); - // Get B scale from thread buffer - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float b_scale_f = float(scale_reg); + // Get B scale from thread buffer + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_f = float(scale_reg); - static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - // Thread buffers - using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); - using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + // Thread buffers + using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, 
b_warp_y_lengths))); - BWarpThreadBuffer b_warp_thread_buffer; - BLDSThreadBuffer b_lds_thread_buffer; + BWarpThreadBuffer b_warp_thread_buffer; + BLDSThreadBuffer b_lds_thread_buffer; - // Load thread buffer from tile (LDS type) - b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + // Load thread buffer from tile (LDS type) + b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // Apply scale to B thread buffer and cast - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op( - b_warp_thread_buffer.template get_as()(i), - b_lds_thread_buffer.template get_as()[i], - b_scale_f); - }); - - // Store B thread buffer to tile (MMA type) - b_warp_tile_.set_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths), - b_warp_thread_buffer); + // Apply scale to B thread buffer and cast + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op(b_warp_thread_buffer.template get_as()(i), + b_lds_thread_buffer.template get_as()[i], + b_scale_f); }); + + // Store B thread buffer to tile (MMA type) + b_warp_tile_.set_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths), + b_warp_thread_buffer); }); }); } @@ -361,113 +359,107 @@ struct BQuantBlockUniversalGemmAsBsCr constexpr auto warp_size = get_warp_size(); // hot loop: - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto mn) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + CWarpTensor c_warp_tensor; - static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { - 
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = - a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - if constexpr(kIterInQScale == 0) - { - c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); - } - else - { - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - } - }); - - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - - if constexpr(BPreshuffleQuant) + if constexpr(kIterInQScale == 0) { - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::BQuantGroupSize::kN > - (NWarp * WarpGemm::kN) && - Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN) - { - return kQScale; // prefill: one quant group per block - } - else - { - return nIter; // decode or multiple groups per warp - } - }(); - - auto pull_from_lane = - (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; - - auto& 
scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - // cross lane ops - uint32_t scale_reg_dword; - - if constexpr(std::is_same_v) - { - scale_reg_dword = ck_tile::bit_cast(scale_reg); - } - else - { - scale_reg_dword = static_cast(scale_reg); - } - - // cross lane ops to get the value of scale_reg. - int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( - pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); - - float scale_reg_f = - Base::cvt_scale_to_fp32( - gathered_scale_reg); - - static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); - }); + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); } else { - // Multiply bquant with accumulated C - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::BQuantGroupSize::kN >= - (NWarp * WarpGemm::kN)) - return (nIter * NWarp * WarpGemm::kN) / - GemmTraits::BQuantGroupSize::kN * - Traits::KQPerBlock + - kQScale; - else - { - return nIter * Traits::KQPerBlock + kQScale; - } - }(); - - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float scale_reg_f = - Base::cvt_scale_to_fp32(scale_reg); - static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); - }); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); } }); + + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + if constexpr(BPreshuffleQuant) + { + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN > (NWarp * WarpGemm::kN) && + Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN) + { + return kQScale; // prefill: one quant group per block + } + else + { + return nIter; // decode or multiple groups per warp + } + }(); + + auto 
pull_from_lane = + (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + // cross lane ops to get the value of scale_reg. + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = Base::cvt_scale_to_fp32( + gathered_scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); + }); + } + else + { + // Multiply bquant with accumulated C + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN)) + return (nIter * NWarp * WarpGemm::kN) / + GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float scale_reg_f = + Base::cvt_scale_to_fp32(scale_reg); + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); + }); + } }); }); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index f48e12984c..c87a02efe0 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -288,22 +288,22 @@ struct WPABQuantBPipelineAgBgCrV2 : public 
WeightPreshufflePipelineAGmemBGmemCRe MIterPerWarp> a_warp_windows_pong; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - move_tile_window(a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - move_tile_window(a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); // Block GEMM @@ -366,16 +366,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * flatNPerWarp, kIter * flatKPerWarp}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), 
- b_flat_dram_windows(nIter)(kIter)); - }); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -448,15 +448,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe bq_block_tile, a_warp_windows_ping); // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), - b_flat_dram_windows(nIter)(kIter)); - }); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); aq_block_tile_2 = load_tile(aq_copy_dram_window); @@ -473,15 +473,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe // Next K // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), - b_flat_dram_windows(nIter)(kIter)); - }); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + 
{nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); aq_block_tile = load_tile(aq_copy_dram_window); @@ -520,16 +520,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe if constexpr(TailNum == TailNumber::Even) { // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * flatNPerWarp, kIter * flatKPerWarp}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), - b_flat_dram_windows(nIter)(kIter)); - }); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); aq_block_tile_2 = load_tile(aq_copy_dram_window); bq_block_tile_2 = load_tile(bq_copy_dram_window); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 025ef53dbb..ff98a06662 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -275,22 +275,22 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV MIterPerWarp> a_warp_windows_pong; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + static_ford>{}([&](auto mk) { + 
constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - move_tile_window(a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - move_tile_window(a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); // Block GEMM @@ -337,16 +337,16 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * flatNPerWarp, kIter * flatKPerWarp}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), - b_flat_dram_windows(nIter)(kIter)); - }); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -424,15 +424,15 @@ 
struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV bq_block_tile, a_warp_windows_ping); // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) { + constexpr auto kIter = number<kn[number<0>{}]>{}; + constexpr auto nIter = number<kn[number<1>{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), - b_flat_dram_windows(nIter)(kIter)); - }); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -461,15 +461,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV // Next K // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) { + constexpr auto kIter = number<kn[number<0>{}]>{}; + constexpr auto nIter = number<kn[number<1>{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), - b_flat_dram_windows(nIter)(kIter)); - }); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -518,16 +518,16 @@ struct WPQuantBPipelineAgBgCrV2 : public 
WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(TailNum == TailNumber::Even) { // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) { + constexpr auto kIter = number<kn[number<0>{}]>{}; + constexpr auto nIter = number<kn[number<1>{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * flatNPerWarp, kIter * flatKPerWarp}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), - b_flat_dram_windows(nIter)(kIter)); - }); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); bq_block_tile_2 = load_tile(bq_copy_dram_window); diff --git a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp index 07d97ec4ff..da9c5c4d57 100644 --- a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp @@ -303,11 +303,11 @@ struct BlockNormReduceCrossWarpSync index_t local_warp_id = warp_id / num_reduce_warps; index_t local_smem_os = local_warp_id * num_reduce_warps; smem_dtype all_scratch[thread_buf_size * num_reduce_warps]; - static_for<0, thread_buf_size, 1>{}([&](auto i_0) { - static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { - all_scratch[i_0 * num_reduce_warps + i_1] = - smem_ptr[i_0 * num_warps + local_smem_os + i_1]; - }); + static_ford<sequence<thread_buf_size, num_reduce_warps>>{}([&](auto ii) { + constexpr auto i_0 = number<ii[number<0>{}]>{}; + constexpr auto i_1 = number<ii[number<1>{}]>{}; + all_scratch[i_0 * num_reduce_warps + i_1] = + smem_ptr[i_0 * num_warps + local_smem_os + i_1]; }); block_sync_lds(); // TODO: we don't need sync here diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp 
b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index ccbdb20793..abad5ed031 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -631,17 +631,17 @@ struct BlockReduce2dLinearCrossWarpSync IndexDataType> all_indices; // Load data from shared memory - static_for<0, thread_buf_size, 1>{}([&](auto i_0) { - static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { - all_scratch[i_0 * num_reduce_warps + i_1] = - smem_ptr[i_0 * num_warps + local_smem_os + i_1]; + static_ford<sequence<thread_buf_size, num_reduce_warps>>{}([&](auto ii) { + constexpr auto i_0 = number<ii[number<0>{}]>{}; + constexpr auto i_1 = number<ii[number<1>{}]>{}; + all_scratch[i_0 * num_reduce_warps + i_1] = + smem_ptr[i_0 * num_warps + local_smem_os + i_1]; - if constexpr(kProcessIndex) - { - all_indices[i_0 * num_reduce_warps + i_1] = - smem_indices[i_0 * num_warps + local_smem_os + i_1]; - } - }); + if constexpr(kProcessIndex) + { + all_indices[i_0 * num_reduce_warps + i_1] = + smem_indices[i_0 * num_warps + local_smem_os + i_1]; + } }); block_sync_lds(); // TODO: we don't need sync here