diff --git a/CHANGELOG.md b/CHANGELOG.md index 370e9e4243..f6812a8520 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added FP8 block scale quantization for FMHA forward kernel. * Added gfx11 support for FMHA. * Added microscaling (MX) FP8/FP4 support on gfx950 for FMHA forward kernel ("qr" pipeline only). +* Added FP8 per-tensor quantization support for FMHA forward V3 pipeline on gfx950. ### Changed diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index a5fffb5159..42e2d1f487 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -206,22 +206,14 @@ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fw """ FMHA_FWD_API_FOOTER_TEMPLATE = """ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ - const std::string device_name = ck_tile::get_device_name(); - - const bool is_swa = (traits.mask_type != mask_enum::no_mask) and - ((0 < args.window_size_left) or (0 < args.window_size_right)); - const bool can_dispatch_v3 = - (device_name.compare(0, 6, "gfx950") == 0) and - (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and - traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and - (not traits.has_lse) and (not traits.has_dropout) and - (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and - (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128); - if ({F_is_v3_enabled} and can_dispatch_v3) {{ - return fmha_fwd_v3(traits, args, config); - }} else {{ - return fmha_fwd_v2(traits, args, config); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunreachable-code" + if ({F_is_v3_enabled}) {{ + float r = fmha_fwd_v3(traits, args, config); + if (r >= 0) return r; }} +#pragma clang diagnostic pop + return fmha_fwd_v2(traits, args, config); }} """ @@ -1059,10 +1051,11 @@ class KernelComponentFactoryGfx950( def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) if dtype in cls._DT_FP16_BF16: - # add tile for qr_async_trload_v3 - if (128, 128) in result.keys(): - result[(128, 128)].append( - FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + # # add tile for qr_async_trload_v3 (bf16/fp16 V3 not ready) + # if (128, 128) in result.keys(): + # result[(128, 128)].append( + # FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + pass elif dtype in cls._DT_MXFP8: return { # bm0, bn0, bk0, bn1, bk1, @@ -1075,6 +1068,10 @@ class KernelComponentFactoryGfx950( (128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)], } # fmt: skip + elif dtype in cls._DT_FP8BF16: + if (128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip return result @classmethod @@ -1105,12 +1102,19 @@ class KernelComponentFactoryGfx950( pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip 
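For reference, the rewritten `fmha_fwd` footer above replaces the host-side `can_dispatch_v3` predicate with a try-then-fallback contract: `fmha_fwd_v3` now signals "no matching v3 instance" by returning a negative time, and the wrapper falls back to v2. A minimal self-contained sketch of that contract (the stub types and bodies are illustrative, not the generated code):

```cpp
// Sketch only: stand-in types; just the control flow mirrors the generated footer.
struct fmha_fwd_traits {};
struct fmha_fwd_args {};
namespace ck_tile { struct stream_config {}; }

// Each dispatcher returns the measured time in ms, or a negative value when it
// has no kernel instance matching (traits, args).
float fmha_fwd_v3(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&) { return -1.f; }
float fmha_fwd_v2(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&) { return 0.42f; }

float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config)
{
    if(float r = fmha_fwd_v3(traits, args, config); r >= 0)
        return r;                             // a v3 instance accepted the problem
    return fmha_fwd_v2(traits, args, config); // otherwise fall back to v2
}
```

This moves the capability checks (device name, dtype, hdim, mask shape) out of the wrapper and into the v3 dispatcher itself; the `-Wunreachable-code` pragma pair is presumably needed because `{F_is_v3_enabled}` expands to a compile-time constant, making one of the two branches statically dead.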
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip - # qr_async_trload_v3 only supports hdim=hdim_v=128 for now - if (hdim, hdim_v) == (128, 128): - # qr_async_trload_v3 only supports (generic) causal mask - for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): - pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", - F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + # # qr_async_trload_v3 bf16/fp16 not ready + # if (hdim, hdim_v) == (128, 128): + # for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): + # pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", + # F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + elif dtype in cls._DT_FP8BF16: + # qr_async_trload_v3 only supports (generic) causal mask + for logits, qscale, mask in itertools.product( + ["t", "f"], + ["no", "pertensor"], + ["no", "causal"], + ): + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4: # no need dropout kernels @@ -1494,8 +1498,8 @@ def write_fwd_api( FMHA_FWD_API_FOOTER_TEMPLATE.format( F_is_v3_enabled=BOOL_MAP[ # NOTE: enable v3 pipelines when ready - # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) - False + 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + # False ] ), ] diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 521f1e4738..7d7d01bd05 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -844,6 +844,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqstart_q_ptr, @@ -877,6 +880,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqlen_q, diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 0775b34eef..417ec12c8c 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1209,7 +1209,8 @@ enum LLVMSchedGroupMask : int32_t DS = 1 << 7, DS_READ = 1 << 8, DS_WRITE = 1 << 9, - ALL = (DS_WRITE << 1) - 1, + TRANS = 1 << 10, + ALL = (TRANS << 1) - 1, }; CK_TILE_HOST_DEVICE static constexpr auto get_max_mem_vec_inst_width() diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index ca7a5c765c..c96a427db1 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -27,6 +27,8 @@ inline constexpr bool kattr_no_packed_fp32_ops_v> = T::kattr_no_packed_fp32_ops; +// TODO: rename to something more specific (e.g. kernel_attr_no_packed_fp32) since +// kernel_attr only controls the no-packed-fp32-ops flag, not a general attribute bag. 
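The `TRANS` group added to `LLVMSchedGroupMask` above lines up with LLVM's `__builtin_amdgcn_sched_group_barrier` mask bit for transcendental instructions, and redefining `ALL` as `(TRANS << 1) - 1` keeps it covering every group. A rough sketch of how such masks are consumed, assuming the enum lives in namespace `ck_tile` like the rest of `arch.hpp` and that this compiles as device code (group sizes here are arbitrary):

```cpp
// Illustrative device-side scheduling hint only; real pipelines interleave many
// groups. Each call asks the backend scheduler to place `size` instructions
// matching `mask` next, within sync group `sync_id`:
//   __builtin_amdgcn_sched_group_barrier(mask, size, sync_id)
__device__ void toy_sched_hint()
{
    using namespace ck_tile;
    __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::DS_READ, 1, 0); // one LDS read
    __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::TRANS, 2, 0);   // two transcendental ops
}
```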
 template <bool no_packed_fp32_ops>
 struct kernel_attr
 {
@@ -35,6 +37,32 @@ struct kernel_attr
     static constexpr bool kattr_no_packed_fp32_ops = no_packed_fp32_ops;
 };
 
+// Compose an architecture tag with kernel attributes.
+// Inherits ArchTag for symbol mangling and adds attribute flags.
+// kernel_attr_for<gfx950_t>                   -> gfx950_t (identity)
+// kernel_attr_for<gfx950_t, kernel_attr<...>> -> unique type with attribute
+namespace detail {
+template <typename ArchTag, typename... Attrs>
+struct kernel_attr_for_impl : ArchTag, Attrs...
+{
+};
+
+template <typename ArchTag, typename... Attrs>
+struct kernel_attr_for_helper
+{
+    using type = kernel_attr_for_impl<ArchTag, Attrs...>;
+};
+
+template <typename ArchTag>
+struct kernel_attr_for_helper<ArchTag>
+{
+    using type = ArchTag;
+};
+} // namespace detail
+
+template <typename ArchTag, typename... Attrs>
+using kernel_attr_for = typename detail::kernel_attr_for_helper<ArchTag, Attrs...>::type;
+
 #if CK_TILE_USE_LAUNCH_BOUNDS
 #define KENTRY_LAUNCH_BOUNDS __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
 #else
diff --git a/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp b/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp
index 25ef000cc3..f22919d922 100644
--- a/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp
+++ b/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp
@@ -187,11 +187,11 @@ struct EpilogueGraph
                     Context& context) const
     {
         // For each iteration, process all epilogues in order
-        static_for<0, Steps, 1>{}([&](auto iAccess) {
-            static_for<0, sizeof...(EpilogueTypes), 1>{}([&](auto I) {
-                epilogues.template get<I>()(
-                    out_window, acc_tile, aux_windows, p_smem, context, iAccess);
-            });
+        static_ford<sequence<Steps, sizeof...(EpilogueTypes)>>{}([&](auto iI) {
+            constexpr auto iAccess = number<iI[number<0>{}]>{};
+            constexpr auto I       = number<iI[number<1>{}]>{};
+            epilogues.template get<I>()(
+                out_window, acc_tile, aux_windows, p_smem, context, iAccess);
         });
     }
 };
diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp
index 2b8e9e4b1a..de73e4f1ff 100644
--- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp
+++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp
@@ -92,29 +92,29 @@ struct BlockFlatmmASmemBSmemCRegV1
         constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
 
         // hot loop:
-        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                // read A warp tensor from A block window
-                const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
+        static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+            constexpr auto kIter = number<km[number<0>{}]>{};
+            constexpr auto mIter = number<km[number<1>{}]>{};
+            // read A warp tensor from A block window
+            const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
 
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    // read C warp tensor from C block tensor
-                    CWarpTensor c_warp_tensor;
+            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                // read C warp tensor from C block tensor
+                CWarpTensor c_warp_tensor;
 
-                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
-                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+                c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
+                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
 
-                    // warp GEMM
-                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
+                // warp GEMM
+                WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
 
-                    // write C warp tensor into C block tensor
-                    c_block_tensor.set_y_sliced_thread_data(
-                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                        c_warp_tensor.get_thread_buffer());
-                    __builtin_amdgcn_sched_barrier(0x7F6);
-                });
+                // write C warp tensor into C block tensor
+                c_block_tensor.set_y_sliced_thread_data(
+                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                    c_warp_tensor.get_thread_buffer());
+                __builtin_amdgcn_sched_barrier(0x7F6);
             });
         });
     }
diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
index 13d5e65155..81cf76cb07 100644
--- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
+++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
@@ -1105,15 +1105,14 @@ struct MoeFlatmmKernel
         statically_indexed_array<index_t, MRepeat * kM0 * kM2> scale_m_offsets;
         if constexpr(!BMXFP4_Pipeline)
-            static_for<0, MRepeat, 1>{}([&](auto mIter) {
-                static_for<0, kM0, 1>{}([&](auto m0) {
-                    static_for<0, kM2, 1>{}([&](auto m2) {
-                        const auto row_idx =
-                            coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
-                        scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
-                            row_to_token_idx(row_idx);
-                    });
-                });
+            static_ford<sequence<MRepeat, kM0, kM2>>{}([&](auto mmm) {
+                constexpr auto mIter = number<mmm[number<0>{}]>{};
+                constexpr auto m0    = number<mmm[number<1>{}]>{};
+                constexpr auto m2    = number<mmm[number<2>{}]>{};
+                const auto row_idx =
+                    coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
+                scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
+                    row_to_token_idx(row_idx);
             });
 
         constexpr int DynamicTileOffsetFlag = 0;
@@ -1426,19 +1425,19 @@ struct MoeFlatmmKernel
         statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile> c_scatter_valids;
         auto c_coord = dram_tile_distribution.calculate_index();
-        static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
-            static_for<0, MPerThread, 1>{}([&](auto m0) {
-                auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
-                auto fused_token =
-                    kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
+        static_ford<sequence<NumMEpiTile, MPerThread>>{}([&](auto mm) {
+            constexpr auto mIter = number<mm[number<0>{}]>{};
+            constexpr auto m0    = number<mm[number<1>{}]>{};
+            auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
+            auto fused_token =
+                kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
 
-                index_t scatter_token_id = fused_token & token_id_mask;
-                c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
-                if constexpr(IsInputGemm)
-                    scatter_token_id =
-                        scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
-                c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
-            });
+            index_t scatter_token_id    = fused_token & token_id_mask;
+            c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
+            if constexpr(IsInputGemm)
+                scatter_token_id =
+                    scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
+            c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
         });
 
 //===----------------------------------------------------------------------===//
diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
index ee8527c458..8f40c9be7a 100644
--- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
@@ -606,16 +606,16 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
                   MIterPerWarp>
         a_warp_windows_pong;
 
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - move_tile_window(a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - move_tile_window(a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); // Block GEMM @@ -656,15 +656,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -701,15 +701,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 while(iCounter > 0) { // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(2i+1) @@ -722,44 +722,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - 
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); // move B window to next flat K @@ -776,15 +776,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 // Next K // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(2i+2) @@ -797,43 +797,43 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + 
static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); // move B window to next flat K @@ -854,15 +854,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 if constexpr(TailNum == TailNumber::Even) { // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(loopK) @@ -870,44 +870,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 store_tile(a_copy_lds_window_pong, a_block_tile_tmp); // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto 
km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); static_for<0, m_preload, 1>{}([&](auto loadIter) { @@ -920,86 +920,86 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 Last2ndHotLoopScheduler(); // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); + // warp GEMM + 
WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); LastHotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + 
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); LastHotLoopScheduler(); } diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 11b978813a..0f7f742fa0 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -537,22 +537,22 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 a_warp_windows_pong; auto A_Lds_Stride = 8; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - auto weight_k_idx = kIter / number{}; - auto weight_k_rank = kIter % number{}; - move_tile_window( - a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, - weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); - move_tile_window( - a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, - weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); - }); + auto weight_k_idx = kIter / number{}; + auto weight_k_rank = kIter % number{}; + move_tile_window( + a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, + weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); + move_tile_window( + a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, + weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); }); // Block GEMM @@ -657,33 +657,32 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) { - if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) - { - auto scale_n_iter = nIter / number{}; - auto scale_k_iter = kIter / number{}; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) + { + auto scale_n_iter = nIter / number{}; + auto scale_k_iter = kIter / number{}; - scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = - scale_b_flat_dram_window; - move_tile_window( - scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), - {scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp}); - scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) = - load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); - } - auto packed_n_idx = nIter / number{}; - auto packed_n_rank = nIter % number{}; + 
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = scale_b_flat_dram_window; + move_tile_window( + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), + {scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp}); + scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) = + load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); + } + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + - packed_n_rank, - kIter * KFlatPerBlockPerIter}); + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + packed_n_rank, + kIter * KFlatPerBlockPerIter}); - ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); - b_warp_tensor_ping(nIter)(kIter) = ub.u; - }); + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_ping(nIter)(kIter) = ub.u; }); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter}); @@ -794,38 +793,37 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 while(iCounter > 0) { // prefetch B(2i+1) - static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) - { - auto scale_n_iter = nIter / number{}; - auto scale_k_iter = kIter / number{}; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) + { + auto scale_n_iter = nIter / number{}; + auto scale_k_iter = kIter / number{}; - scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = - scale_b_flat_dram_window; - - move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), - {scale_n_iter * NFlatPerBlockPerIter, - scale_k_iter * ScaleKFlatPerWarp}); - - scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) = - load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); - } - - auto packed_n_idx = nIter / number{}; - auto packed_n_rank = nIter % number{}; - - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = + scale_b_flat_dram_window; move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + - packed_n_rank, - kIter * KFlatPerBlockPerIter}); + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), + {scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp}); - ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); - b_warp_tensor_pong(nIter)(kIter) = ub.u; - }); + scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) = + load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); + } + + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; + + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + + packed_n_rank, + kIter * KFlatPerBlockPerIter}); + + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_pong(nIter)(kIter) = ub.u; }); // Prefill A(2i+1) @@ -835,51 +833,50 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 
prefill_lds_a_stage1( a_copy_lds_window_ping, a_copy_dram_window, number{}); // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - if constexpr(mIter == 0) - dequant_mxfp4( - b_warp_tensor_ping(nIter)(kIter / number{}), - scale_b_warp_tensor_ping(nIter / number{})( - kIter / number{}), - nIter, - kIter); + if constexpr(mIter == 0) + dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number{}), + scale_b_warp_tensor_ping(nIter / number{})( + kIter / number{}), + nIter, + kIter); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - __builtin_amdgcn_s_waitcnt(Bload_total_num); - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); prefill_lds_a_stage1( a_copy_lds_window_ping, a_copy_dram_window, number{}); @@ -902,37 +899,36 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // Next K // prefetch B(2i+2) - static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) - { - auto scale_n_iter = nIter / number{}; - auto scale_k_iter = kIter / number{}; + static_ford>{}([&](auto 
kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0) + { + auto scale_n_iter = nIter / number{}; + auto scale_k_iter = kIter / number{}; - scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = - scale_b_flat_dram_window; + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) = + scale_b_flat_dram_window; - move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), - {scale_n_iter * NFlatPerBlockPerIter, - scale_k_iter * ScaleKFlatPerWarp}); - - scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) = - load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); - } - - auto packed_n_idx = nIter / number{}; - auto packed_n_rank = nIter % number{}; - - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + - packed_n_rank, - kIter * KFlatPerBlockPerIter}); + scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter), + {scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp}); - ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); - b_warp_tensor_ping(nIter)(kIter) = ub.u; - }); + scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) = + load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter)); + } + + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; + + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter + + packed_n_rank, + kIter * KFlatPerBlockPerIter}); + + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_ping(nIter)(kIter) = ub.u; }); // Prefill A(2i+2) @@ -943,50 +939,49 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 a_copy_lds_window_pong, a_copy_dram_window, number{}); // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - if constexpr(mIter == 0) - dequant_mxfp4( - b_warp_tensor_pong(nIter)(kIter / number{}), - scale_b_warp_tensor_pong(nIter / number{})( - kIter / number{}), - nIter, - kIter); + if constexpr(mIter == 0) + dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number{}), + scale_b_warp_tensor_pong(nIter / number{})( + kIter / number{}), + nIter, + kIter); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); - // write C warp tensor into C block tensor - 
c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - __builtin_amdgcn_s_waitcnt(Bload_total_num); - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); prefill_lds_a_stage1( a_copy_lds_window_pong, a_copy_dram_window, number{}); @@ -1058,51 +1053,50 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 prefill_lds_a_stage2(a_copy_lds_window_pong); // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - if constexpr(mIter == 0) - dequant_mxfp4( - b_warp_tensor_ping(nIter)(kIter / number{}), - scale_b_warp_tensor_ping(nIter / number{})( - kIter / number{}), - nIter, - kIter); + if constexpr(mIter == 0) + dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number{}), + scale_b_warp_tensor_ping(nIter / number{})( + kIter / number{}), + nIter, + kIter); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = 
(mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - __builtin_amdgcn_s_waitcnt(Bload_total_num); - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); static_for<0, m_preload, 1>{}([&](auto loadIter) { @@ -1170,50 +1164,49 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 else if constexpr(TailNum == TailNumber::Odd) { // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - if constexpr(mIter == 0) - dequant_mxfp4( - b_warp_tensor_ping(nIter)(kIter / number{}), - scale_b_warp_tensor_ping(nIter / number{})( - kIter / number{}), - nIter, - kIter); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); + if constexpr(mIter == 0) + dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number{}), + scale_b_warp_tensor_ping(nIter / number{})( + kIter / number{}), + nIter, + kIter); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor(number{}), dequant_B_n[nIter]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - __builtin_amdgcn_s_waitcnt(Bload_total_num); - block_sync_lds(); - } + // write C 
warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); LastHotLoopScheduler(); } @@ -1904,29 +1897,29 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); // prefetch Scale A - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + static_ford>{}([&](auto mk) { + constexpr auto mIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); - }); + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); }); // move Scale A window to next K move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); // prefetch Scale B - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + static_ford>{}([&](auto nk) { + constexpr auto nIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); - }); + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); }); // move Scale B window to next K move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); @@ -1957,95 +1950,90 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 // MAIN LOOP auto main_body_implx2 = [&]() mutable { // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); - if constexpr(kIter == KIterPerWarp - 1) - move_tile_window(b_flat_dram_windows(nIter), - {0, BlockGemmShape::flatKPerBlock}); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), number{}); 
+ if constexpr(kIter == KIterPerWarp - 1) + move_tile_window(b_flat_dram_windows(nIter), + {0, BlockGemmShape::flatKPerBlock}); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + static_ford>{}([&](auto mk) { + constexpr auto mIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); - }); + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + static_ford>{}([&](auto nk) { + constexpr auto nIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); - }); + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); }); // GEMM 2i - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; - constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, 
c_warp_y_lengths)); - // warp GEMM - WG{}.template - // operator()( - operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); + // warp GEMM + WG{}.template + // operator()( + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack).get_thread_buffer()[0], + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack).get_thread_buffer()[0]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - constexpr auto addr = - m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - { - constexpr auto AmIter = addr % 2 + addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_ping, - tuple, number>{}); - } - }); - }); - }); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, number>{}); + } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished s_waitcnt< // vmcnt @@ -2072,96 +2060,94 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 ////////////////////////////// Next K ////////////////////////////// // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); - if constexpr(kIter == KIterPerWarp - 1) - move_tile_window(b_flat_dram_windows(nIter), - {0, BlockGemmShape::flatKPerBlock}); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), number{}); + if constexpr(kIter == KIterPerWarp - 1) + move_tile_window(b_flat_dram_windows(nIter), + {0, BlockGemmShape::flatKPerBlock}); }); // prefetch Scale A and Scale B (2i+2) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + static_ford>{}([&](auto mk) { + constexpr auto mIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + 
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); - }); + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + static_ford>{}([&](auto nk) { + constexpr auto nIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); - }); + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); }); // GEMM 2i+1 - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}.template - // operator()( - operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], // scale A - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); // scale B + // warp GEMM + WG{}.template + // operator()( + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - 
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - { - constexpr auto AmIter = addr % 2 + addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_pong, - tuple, number>{}); - } - }); - }); - }); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_pong, + tuple, number>{}); + } }); // barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished s_waitcnt< // vmcnt @@ -2199,92 +2185,89 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 if constexpr(TailNum == TailNumber::Even) { // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), - make_tuple(number<0>{}, number{})); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), + make_tuple(number<0>{}, number{})); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + static_ford>{}([&](auto mk) { + constexpr auto mIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); - }); + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + static_ford>{}([&](auto nk) { + constexpr auto nIter_pack = number{}]>{}; + constexpr auto kIter_pack = number{}]>{}; + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = - 
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); - }); + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); }); // GEMM loopK-1 - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}.template - operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], // scale A - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); // scale B + // warp GEMM + WG{}.template operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - { - constexpr auto AmIter = addr % 2 + addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_ping, - tuple, number>{}); - } - }); - }); - }); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + 
a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, number>{}); + } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished s_waitcnt< // vmcnt @@ -2302,123 +2285,115 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 // Last2ndHotLoopScheduler(); // GEMM loopK - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}.template - operator()( - // operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], // scale A - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); // scale B + // warp GEMM + WG{}.template operator()( + // operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - { - constexpr auto AmIter = addr % 2 + addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_pong, - tuple, number>{}); - } - }); - }); - }); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + 
(nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_pong, + tuple, number>{}); + } }); // LastHotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { // GEMM loopK - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; - constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; - constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; + static_ford>{}([&](auto idx) { + constexpr auto kIter_pack = number{}]>{}; + constexpr auto mIter_pack = number{}]>{}; + constexpr auto nIter_pack = number{}]>{}; + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}.template - // operator()( - operator()( - c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) - .get_thread_buffer()[0], - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) - .get_thread_buffer()[0]); + // warp GEMM + WG{}.template + // operator()( + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + + inxdl)(kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack).get_thread_buffer()[0], + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack).get_thread_buffer()[0]); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - constexpr auto addr = - m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; - if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) - { - constexpr auto AmIter = addr % 2 + addr / 4 * 2; - constexpr auto AkIter = addr / 2 % 2; - a_warp_tensor(number{}) = load_tile_with_offset( - a_warp_window_ping, - tuple, number>{}); - } - }); - }); - }); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // 
preload next A from lds + constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, number>{}); + } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished s_waitcnt< // vmcnt diff --git a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp index 5681726afe..543f4dc92a 100644 --- a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp @@ -529,22 +529,22 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 MIterPerWarp> a_warp_windows_pong; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - move_tile_window(a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - move_tile_window(a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); // Block GEMM @@ -592,26 +592,26 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 2; // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - if constexpr(!IsGateUpMode) - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + if constexpr(!IsGateUpMode) + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + else + { + if constexpr(nIter % 2 == 0) + move_tile_window( + b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); else - { - if constexpr(nIter % 2 == 0) - move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - else - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, - kIter * KFlatPerBlockPerIter}); - } - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, + kIter * KFlatPerBlockPerIter}); + } + 
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -648,28 +648,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 while(iCounter > 0) { // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - if constexpr(!IsGateUpMode) + if constexpr(!IsGateUpMode) + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + else + { + if constexpr(nIter % 2 == 0) move_tile_window( b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); else - { - if constexpr(nIter % 2 == 0) - move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - else - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, - kIter * KFlatPerBlockPerIter}); - } + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, + kIter * KFlatPerBlockPerIter}); + } - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(2i+1) @@ -682,44 +681,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - 
load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); // move B window to next flat K @@ -736,28 +735,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 // Next K // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - if constexpr(!IsGateUpMode) + if constexpr(!IsGateUpMode) + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + else + { + if constexpr(nIter % 2 == 0) move_tile_window( b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); else - { - if constexpr(nIter % 2 == 0) - move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - else - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, - kIter * KFlatPerBlockPerIter}); - } + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, + kIter * KFlatPerBlockPerIter}); + } - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(2i+2) @@ -770,43 +768,43 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - 
WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); // move B window to next flat K @@ -827,28 +825,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 if constexpr(TailNum == TailNumber::Even) { // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - if constexpr(!IsGateUpMode) + if constexpr(!IsGateUpMode) + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + else + { + if constexpr(nIter % 2 == 0) move_tile_window( b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); else - { - if constexpr(nIter % 2 == 0) - move_tile_window( - b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - else - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, - kIter * KFlatPerBlockPerIter}); - } + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride, + kIter * KFlatPerBlockPerIter}); + } - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); // Prefill A(loopK) @@ -856,44 +853,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 store_tile(a_copy_lds_window_pong, a_block_tile_tmp); // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block 
tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); static_for<0, m_preload, 1>{}([&](auto loadIter) { @@ -906,86 +903,86 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 Last2ndHotLoopScheduler(); // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - 
b_warp_tensor_pong(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); LastHotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + 
merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); LastHotLoopScheduler(); } diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 23d7a9fca9..f698541dbf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -486,13 +486,13 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}; auto c_block_tile = BlockFlatmm{}.MakeCBlockTile(); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensors(mIter)(nIter).get_thread_buffer()); - }); + static_ford>{}([&](auto mn) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensors(mIter)(nIter).get_thread_buffer()); }); return c_block_tile; } @@ -643,24 +643,23 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto impack) { - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset( - scale_a_dram_window, + static_ford>{}([&](auto ii) { + constexpr auto impack = number{}]>{}; + constexpr auto ikpack = number{}]>{}; + scale_a_tile_tensor_ping(impack)(ikpack) = + load_tile_with_offset(scale_a_dram_window, - impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); - }); + impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); // move Scale A window to next K move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); // prefetch Scale B - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset( - scale_b_dram_window, - inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto inpack = number{}]>{}; + constexpr auto ikpack = number{}]>{}; + scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset( + scale_b_dram_window, inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); // move Scale B window to next K move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); @@ -698,34 +697,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); + static_ford>{}([&](auto kn) 
{ + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); - // move B window to next flat K - if constexpr(kIter == KIterPerWarp - 1) - b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); - }); + // move B window to next flat K + if constexpr(kIter == KIterPerWarp - 1) + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) { - scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset( - scale_a_dram_window, - impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto impack = number{}]>{}; + scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset( + scale_a_dram_window, + impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset( - scale_b_dram_window, - inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto inpack = number{}]>{}; + scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset( + scale_b_dram_window, + inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); // GEMM 2i @@ -788,34 +787,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); - // move B window to next flat K - if constexpr(kIter == KIterPerWarp - 1) - b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); - }); + // move B window to next flat K + if constexpr(kIter == KIterPerWarp - 1) + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); // prefetch Scale A and Scale B (2i+2) - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) { - scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset( - scale_a_dram_window, - impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto impack = number{}]>{}; + scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset( + scale_a_dram_window, + impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset( - scale_b_dram_window, - inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = 
number{}]>{}; + constexpr auto inpack = number{}]>{}; + scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset( + scale_b_dram_window, + inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); // GEMM 2i+1 @@ -888,28 +887,28 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) { - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset( - scale_a_dram_window, - impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto impack = number{}]>{}; + constexpr auto ikpack = number{}]>{}; + scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset( + scale_a_dram_window, + impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset( - scale_b_dram_window, - inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); - }); + static_ford>{}([&](auto ii) { + constexpr auto inpack = number{}]>{}; + constexpr auto ikpack = number{}]>{}; + scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset( + scale_b_dram_window, + inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); // GEMM loopK-1 diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 53934ebcd3..c6628f66be 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -484,20 +484,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.init_logits_soft_cap(logits_soft_cap); } - // Check that the maximum offset won't overflow. - if constexpr(kPageBlockSize < FmhaPipeline::kN0) - { - if(num_total_pages > 1) - { - assert(static_cast(num_total_pages - 1) * batch_stride_k <= - static_cast(std::numeric_limits::max()) && - "KV cache K offset overflow: exceed int32 max"); - assert(static_cast(num_total_pages - 1) * batch_stride_v <= - static_cast(std::numeric_limits::max()) && - "KV cache V offset overflow: exceed int32 max"); - } - } - return kargs; } @@ -651,20 +637,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.init_logits_soft_cap(logits_soft_cap); } - // Check that the maximum offset won't overflow. 
- if constexpr(kPageBlockSize < FmhaPipeline::kN0) - { - if(num_total_pages > 1) - { - assert(static_cast(num_total_pages - 1) * batch_stride_k <= - static_cast(std::numeric_limits::max()) && - "KV cache K offset overflow: exceed int32 max"); - assert(static_cast(num_total_pages - 1) * batch_stride_v <= - static_cast(std::numeric_limits::max()) && - "KV cache V offset overflow: exceed int32 max"); - } - } - return kargs; } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 6fe1de634d..8ee9b9d9b7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; @@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel float logits_soft_cap_rcp; }; + struct FmhaFwdCommonQScaleKargs + { + const void* q_descale_ptr = nullptr; + const void* k_descale_ptr = nullptr; + const void* v_descale_ptr = nullptr; + }; + struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, @@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, @@ -237,6 +256,12 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasLogitsSoftCap) { kargs.init_logits_soft_cap(logits_soft_cap); @@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, @@ -301,6 +329,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for logits_soft_cap 
reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), @@ -319,6 +348,12 @@ struct FmhaFwdV3Kernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasLogitsSoftCap) { kargs.init_logits_soft_cap(logits_soft_cap); @@ -437,8 +472,19 @@ struct FmhaFwdV3Kernel { using namespace ck_tile; - // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; + // Notice: When using double buffering, make sure both buffers are in the same array. + // This prevents the compiler from using separate VGPRs to store the base address + // and enables the use of immediate offsets in load/store instructions. + constexpr auto smem_size_kv = + FmhaPipeline::Policy::template GetSmemSizeKV(); + __shared__ char smem_k[2][smem_size_kv]; + __shared__ char smem_v[2][smem_size_kv]; + + auto* smem_k0 = reinterpret_cast(smem_k[0]); + auto* smem_k1 = reinterpret_cast(smem_k[1]); + auto* smem_v0 = reinterpret_cast(smem_v[0]); + auto* smem_v1 = reinterpret_cast(smem_v[1]); + ; // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); @@ -640,32 +686,88 @@ struct FmhaFwdV3Kernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); + const float scale_s = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); + return kargs.scale_s * q_descale * k_descale; + } + else + { + return kargs.scale_s; + } + }(); + AttentionVariant variant; const auto variant_params = [&] { if constexpr(kHasLogitsSoftCap) { return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; } else { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + return ck_tile::StandardAttentionParams{mask, scale_s}; } }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; auto o_acc_tile = [&]() { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - lse_dram_window, - mask, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); + float scale_p = ck_tile::type_convert(ck_tile::numeric::max()); + float scale_o = v_descale / scale_p; + + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return make_composes( + ck_tile::saturates{}, + ck_tile::scales>{scale_o}); + else + return ck_tile::scales>{scale_o}; + }(); + + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales>{scale_p}, // p_compute_element_func + o_acc_element_func, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_k0, + smem_k1, + smem_v0, + smem_v1); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lse_dram_window, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_k0, + smem_k1, + smem_v0, + smem_v1); + } }(); // O DRAM and O DRAM window diff --git 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index e67a525ac4..bb3fa8c411 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1706,22 +1706,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); #if defined(__gfx11__) - PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor); + PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor); #else - pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer(); + pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer(); #endif - pt_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), - pt_warp_tensor.get_thread_buffer()); - }); + pt_out.set_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), + pt_warp_tensor.get_thread_buffer()); }); } else @@ -1763,22 +1763,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); #if defined(__gfx11__) - PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor); + PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor); #else - dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer(); + dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer(); #endif - dst_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), - dst_warp_tensor.get_thread_buffer()); - }); + dst_out.set_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), + dst_warp_tensor.get_thread_buffer()); }); } else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 463f149a65..ac868ce4b8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -24,183 +24,201 @@ #define CK_TILE_DISABLE_PACKED_FP32 0 #endif -#define WARP_ID 0 -#define LANE_ID 0 - -#define ENABLE_DEBUG_STMTS 1 -#if ENABLE_DEBUG_STMTS -#define DEBUG_STMTS \ - if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID) -#else -#define DEBUG_STMTS if constexpr(false) -#endif - namespace ck_tile { -template -struct CoreLoopScheduler; +// --------------------------------------------------------------------------- +// block_gemm_mfma_count_v: number of hardware MFMA instructions issued per +// warp in one full BlockGemm call. +// +// warp gemm calls = MIterPerWarp * NIterPerWarp * KIterPerWarp +// MFMAs per call = WarpGemm::kK / WarpGemm::WarpGemmAttribute::Impl::kK (kKIter) +// +// For bf16/fp16 kKIter=1; for fp8 kKIter=2 (K=32 warp gemm wraps 2× K=16 MFMA). +// --------------------------------------------------------------------------- +template +static constexpr ck_tile::index_t block_gemm_mfma_count_v = + BlockGemm::MIterPerWarp * BlockGemm::NIterPerWarp * BlockGemm::KIterPerWarp * + (BlockGemm::WarpGemm::kK / BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK); -template -struct CoreLoopScheduler +// --------------------------------------------------------------------------- +// CoreLoopSchedulingParams: auto-derived instruction counts from tile/gemm config +// --------------------------------------------------------------------------- +template +struct CoreLoopSchedulingParams { + using QKBlockGemm = + ck_tile::remove_cvref_t())>; + using PVBlockGemm = + ck_tile::remove_cvref_t())>; + + static constexpr ck_tile::index_t kMfmaPerWarpGemm0 = block_gemm_mfma_count_v; + static constexpr ck_tile::index_t kMfmaPerWarpGemm1 = block_gemm_mfma_count_v; + + static constexpr bool kIsMasking = PipelineProblem::FmhaMask::IsMasking; +}; + +// --------------------------------------------------------------------------- +// CoreLoopSchedulerDefaultBase: reusable phase helpers (bf16/fp16 pattern) +// --------------------------------------------------------------------------- +template +struct CoreLoopSchedulerDefaultBase +{ + using Params = CoreLoopSchedulingParams; + + // Phase helper: GEMM0 compute (QK matmul) — MFMA interleaved with TRANS + VALU + CK_TILE_DEVICE static constexpr void schedule_gemm0_compute() + { + static_for<0, Params::kMfmaPerWarpGemm0, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::TRANS, 2, 0); + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); + }); + } + + // Phase helper: GEMM1 compute (PV matmul) — optional packed-FP32 preamble + MFMA/VALU + CK_TILE_DEVICE static constexpr void schedule_gemm1_compute() + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); +#endif + static_for<0, Params::kMfmaPerWarpGemm1, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); + }); + } + + // Phase helper: load phase (memory/LDS loads) — VALU + SALU + CK_TILE_DEVICE static constexpr void schedule_load_phase() + { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::SALU, 4, 0); + } + + // Compose phases via WG0/WG1 phase-shift pattern: + // WG0: compute0(P0), load(P1), compute1(P2), load(P3) + // WG1: load(P0), compute0(P1), 
load(P2), compute1(P3) template CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, ck_tile::number) { - using namespace ck_tile; + // WG1 is shifted by 3 phases (equivalently, -1 mod 4) relative to WG0 + constexpr ck_tile::index_t effective = (WaveGroup == 0) ? Phase : (Phase + 3) % 4; - if constexpr(WaveGroup == 0) - { - if constexpr(Phase == 0) - { - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - }); - } - else if constexpr(Phase == 1) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 2) - { -#if !CK_TILE_DISABLE_PACKED_FP32 - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU -#endif - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU - }); - } - else if constexpr(Phase == 3) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - } + if constexpr(effective == 0) + schedule_gemm0_compute(); + else if constexpr(effective == 2) + schedule_gemm1_compute(); else - { - if constexpr(Phase == 0) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 1) - { - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - }); - } - else if constexpr(Phase == 2) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 3) - { -#if !CK_TILE_DISABLE_PACKED_FP32 - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU -#endif - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU - }); - } - } + schedule_load_phase(); } }; +// --------------------------------------------------------------------------- +// CoreLoopSchedulerImpl: dtype-specialized dispatch +// --------------------------------------------------------------------------- +template +struct CoreLoopSchedulerImpl; + +// bf16 — uses default base template -struct CoreLoopScheduler +struct CoreLoopSchedulerImpl + : CoreLoopSchedulerDefaultBase { +}; + +// fp16 — uses default base +template +struct CoreLoopSchedulerImpl + : CoreLoopSchedulerDefaultBase +{ +}; + +// fp8 — asymmetric GEMM0 scheduling for 2× K iterations +// +// FP8 GEMM0 has 16 MFMAs (kKIter=2) but the same TRANS work as bf16/fp16 (softmax +// exp count is dtype-independent). The uniform (MFMA:1, TRANS:2, VALU:2) pattern +// causes the compiler to front-load all 32 TRANS into MFMA #1, leaving MFMAs #2-8 +// with nothing to interleave (7 back-to-back MFMAs). 
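+// For reference, __builtin_amdgcn_sched_group_barrier(Mask, Size, SyncID) asks the
+// scheduler to place a group of Size instructions matching Mask at this point,
+// ordered against the other groups carrying the same SyncID; it constrains
+// placement only, it does not create or remove instructions. The split below keeps
+// the same totals (16 MFMA, 32 TRANS) but requests (1 MFMA, 4 TRANS, 4 VALU) per
+// MFMA for the first eight MFMAs and (1 MFMA, 6 VALU) for the last eight.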
+// +// Fix: split into two halves matching the natural K iteration boundary: +// K iter 0 (MFMAs 1-8): TRANS-heavy — softmax exp + add reduction chain +// K iter 1 (MFMAs 9-16): VALU-heavy — P scale + cvt_pk_fp8 + o_acc rescale +template +struct CoreLoopSchedulerImpl + : CoreLoopSchedulerDefaultBase +{ + using Base = CoreLoopSchedulerDefaultBase; + using Params = typename Base::Params; + + CK_TILE_DEVICE static constexpr void schedule_gemm0_compute() + { + // K iter 0: 32 TRANS (v_exp_f32) + ~33 VALU (v_add reduction + permlane) + static_for<0, Params::kMfmaPerWarpGemm0 / 2, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::TRANS, 4, 0); + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); + }); + // K iter 1: ~58 VALU (v_mul scale + v_cvt_pk_fp8 + o_acc rescale) + static_for<0, Params::kMfmaPerWarpGemm0 / 2, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 6, 0); + }); + } + + // Phase helper: GEMM1 compute (PV matmul) — asymmetric for fmha_alu0 data dependency + // + // fmha_alu0 runs during PV GEMM on the OTHER sp buffer: + // v_perm (byte packing) + v_max3 (row max) + permlane + v_fma (sp_delta) + // + // The v_fma chain depends on the serial max3→permlane→max→mul chain, creating + // a data dependency gap around MFMAs 8-11. Use a looser VALU constraint for the + // second half to give the scheduler freedom to place v_fma where available. + CK_TILE_DEVICE static constexpr void schedule_gemm1_compute() + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); +#endif + // First half: v_perm + v_max3 + permlane chain (~29 VALU) + static_for<0, Params::kMfmaPerWarpGemm1 / 2, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); + }); + // Second half: v_fma chain (~33 VALU, data-dep limited at start) + static_for<0, Params::kMfmaPerWarpGemm1 / 2, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 3, 0); + }); + } + + // Must override schedule() — static methods have no virtual dispatch template CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, ck_tile::number) { - using namespace ck_tile; + constexpr ck_tile::index_t effective = (WaveGroup == 0) ? 
Phase : (Phase + 3) % 4; - if constexpr(WaveGroup == 0) - { - if constexpr(Phase == 0) - { - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - }); - } - else if constexpr(Phase == 1) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 2) - { -#if !CK_TILE_DISABLE_PACKED_FP32 - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU -#endif - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU - }); - } - else if constexpr(Phase == 3) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - } + if constexpr(effective == 0) + schedule_gemm0_compute(); + else if constexpr(effective == 2) + schedule_gemm1_compute(); else - { - if constexpr(Phase == 0) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 1) - { - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - }); - } - else if constexpr(Phase == 2) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 3) - { -#if !CK_TILE_DISABLE_PACKED_FP32 - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU -#endif - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU - }); - } - } + Base::schedule_load_phase(); } }; +// --------------------------------------------------------------------------- +// CoreLoopScheduler: user-facing template, delegates to dtype-specialized impl +// --------------------------------------------------------------------------- +template +struct CoreLoopScheduler : CoreLoopSchedulerImpl +{ +}; + namespace detail { -CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c) -{ -#if CK_TILE_DISABLE_PACKED_FP32 - return a * b + c; -#else - float result; - asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]" - : [result] "=v"(result) - : [a] "v"(a), [b] "s"(b), [c] "v"(c)); - return result; -#endif -} +CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c) { return a * b + c; } CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs) { @@ -237,6 +255,19 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) : [lhs] "v"(lhs), [rhs] "v"(rhs)); return result; } + +/// FP8 packed conversion with asm volatile to prevent code sinking. +/// This anchors the conversion instruction in Phase 0, and all predecessor +/// instructions (scale, saturate, NaN check) will automatically stay in Phase 0. +/// v_cvt_pk_fp8_f32 packs two FP8 values into lower 16 bits of a 32-bit VGPR. 
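+/// Usage sketch: per the ISA's packed-convert layout, byte 0 of the result holds
+/// fp8(a) and byte 1 holds fp8(b); the caller in this pipeline unpacks them with
+/// (packed & 0xFF) and ((packed >> 8) & 0xFF) before storing into the P buffer.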
+CK_TILE_DEVICE uint32_t cvt_pk_fp8_f32(float a, float b) +{ + uint32_t result; + asm volatile("v_cvt_pk_fp8_f32 %[result], %[a], %[b]" + : [result] "=v"(result) + : [a] "v"(a), [b] "v"(b)); + return result; +} } // namespace detail /// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and @@ -290,10 +321,9 @@ struct BlockFmhaFwdV3Pipeline static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; - static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout && - (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && - !kSkipMinSeqlenQ), + static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kHasDropout && !kSkipMinSeqlenQ), "enable unsupported features"); + // HACK: Removed !kStoreLSE check to allow BF16 V3 compilation for assembly analysis // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -318,35 +348,7 @@ struct BlockFmhaFwdV3Pipeline CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - // create another LDS buffer for p - return ck_tile::max(kM0 * kN1 * sizeof(PDataType), - Policy::template GetSmemSize() + - kM0 * kN0 * sizeof(PDataType)); - } - - // for debug only - template - CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc() - { - using namespace ck_tile; - constexpr auto lds_block_desc = - make_naive_tensor_descriptor(make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number<1>{}, - number<1>{}); - - return lds_block_desc; - } - - // for debug only - template - CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D() - { - using namespace ck_tile; - constexpr auto lds_block_desc = make_naive_tensor_descriptor( - make_tuple(number{}), make_tuple(number<1>{}), number<1>{}, number<1>{}); - - return lds_block_desc; + return Policy::template GetSmemSize(); } template @@ -359,29 +361,6 @@ struct BlockFmhaFwdV3Pipeline return make_tile_window(tensor_view, desc.get_lengths(), {0, 0}); } - // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7 - template - CK_TILE_DEVICE static constexpr void s_waitcnt() - { - // vmcnt use bits {[15:14],[3:0]} - // expcnt use bits [6:4] - // lgkmcnt use bits [11:8] - __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) | - ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8)); - } - - template - CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt() - { - s_waitcnt(); - } - - template - CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt() - { - s_waitcnt<63, Lgkmcnt>(); - } - template - CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - [[maybe_unused]] const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - [[maybe_unused]] const VElementFunction& v_element_func, - LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile - const LSEElementFunction& lse_element_func, - [[maybe_unused]] const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, - FmhaMask mask, - float scale_s, - const AttentionVariant& variant, - const AttentionVariantParams& variant_params, - 
const BlockIndices& block_indices, - void* smem_ptr) const + CK_TILE_DEVICE auto + operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile + [[maybe_unused]] const KElementFunction& k_element_func, + const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile + [[maybe_unused]] const VElementFunction& v_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + [[maybe_unused]] const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + KDataType* __restrict__ smem_k0, + KDataType* __restrict__ smem_k1, + VDataType* __restrict__ smem_v0, + VDataType* __restrict__ smem_v1) const { using namespace ck_tile; @@ -428,33 +411,6 @@ struct BlockFmhaFwdV3Pipeline kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); - auto s_lds = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); - [[maybe_unused]] auto s_lds_window = - make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); - - auto p_lds = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetSmemSize()), - MakeSimpleLdsDesc()); - [[maybe_unused]] auto p_lds_window = - make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); - - auto o_lds = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); - [[maybe_unused]] auto o_lds_window = - make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); - - auto m_lds = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetSmemSize()), - MakeSimpleLdsDesc1D()); - [[maybe_unused]] auto m_lds_window = - make_tile_window(m_lds, make_tuple(number{}), {0}); - const index_t warp_group_id = get_warp_id() / 4; // Block GEMM @@ -469,16 +425,18 @@ struct BlockFmhaFwdV3Pipeline const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; auto k_lds_window_store = generate_tuple( - [&](auto i_buf) { + [&](auto write_idx) { + auto k_buf = (write_idx == 0 ? smem_k0 : smem_k1); return make_lds_tile_window( - smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)); + k_buf, Policy::template MakeKLdsStoreBlockDescriptor()); }, number<2>{}); auto v_lds_window_store = generate_tuple( - [&](auto i_buf) { - return make_lds_tile_window( - smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); + [&](auto write_idx) { + auto v_buf = (write_idx == 0 ? smem_v0 : smem_v1); + return make_lds_tile_window( + v_buf, Policy::template MakeVLdsStoreBlockDescriptor()); }, number<2>{}); @@ -521,9 +479,11 @@ struct BlockFmhaFwdV3Pipeline statically_indexed_array sp; decltype(gemm_1.MakeCBlockTile()) o_acc; - constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd() - // instructions should we move to fmha_alu1() - static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); + constexpr index_t fmha_alu_D_reg_cnt = + 6; // Threshold for determining how many fmha_alu_D_upd() unpacked + // instructions to relocate to fmha_alu1(). 
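+        // The evenness requirement below exists because every register past the
+        // scalar prefix is rescaled two at a time with packed v_pk_mul_f32 in
+        // fmha_alu_D_upd_pack(). E.g. assuming a 16-register o_acc (illustrative
+        // only) and num_unpack_insts = 0: registers 0-5 get scalar v_mul_f32 and
+        // registers 6-15 form five packed multiplies.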
+ static_assert(fmha_alu_D_reg_cnt % 2 == 0 && + fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); decltype(block_tile_reduce( sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m; @@ -531,18 +491,27 @@ struct BlockFmhaFwdV3Pipeline // initialize k_lds_window and v_lds_window static_for<0, 2, 1>{}([&](auto idx) { - k_lds_window_load(idx) = make_tile_window( - make_lds_tile_window( - static_cast(smem_ptr) + (idx)*Policy::template GetSmemSizeKV(), - Policy::template MakeKLdsLoadBlockDescriptor()), - Policy::template MakeKRegTileDistribution()); + k_lds_window_load(idx) = + make_tile_window(make_lds_tile_window( + [&] { + if constexpr(idx == 0) + return smem_k0; + else + return smem_k1; + }(), + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution()); }); static_for<0, 2, 1>{}([&](auto idx) { v_lds_window_load(idx) = make_tile_window(make_lds_tile_window( - static_cast(smem_ptr) + - (idx + 2) * Policy::template GetSmemSizeKV(), + [&] { + if constexpr(idx == 0) + return smem_v0; + else + return smem_v1; + }(), Policy::template MakeVLdsLoadBlockDescriptor()), Policy::template MakeVRegTileDistribution()); }); @@ -591,14 +560,12 @@ struct BlockFmhaFwdV3Pipeline k_dram_block_window_tmp.get_window_lengths(), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); - k_dram_window.init_raw(); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), {seqlen_k_start, 0}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); - v_dram_window.init_raw(); // prefetch K tile index_t i_total_loops = 0; @@ -611,86 +578,13 @@ struct BlockFmhaFwdV3Pipeline constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup; static_assert(NumWarpGroups == 2); - [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) { - printf("[POYENC] %s (size=%d): %5.2f", - name, - decltype(dist_tensor.thread_buf_)::size(), - ck_tile::type_convert(dist_tensor.thread_buf_[0])); - static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) { - printf(", %5.2f", ck_tile::type_convert(dist_tensor.thread_buf_[i])); - }); - printf("\n"); - }; - - [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) { - const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{}); - const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{}); - - auto desc = lds_tile_window.get_bottom_tensor_view().desc_; - auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; - - if constexpr(true || num_rows < num_cols) - { - for(int row = 0; row < num_rows; ++row) - { - int offset = desc.calculate_offset(make_tuple(row, 0)); - printf("[DEVICE] %s[%3d] = %5.2f", - name, - row, - ck_tile::type_convert(data[offset])); - for(int col = 1; col < num_cols; ++col) - { - printf(", "); - offset = desc.calculate_offset(make_tuple(row, col)); - printf("%5.2f", ck_tile::type_convert(data[offset])); - } - printf("\n"); - } - } - else - { - for(int col = 0; col < num_cols; ++col) - { - int offset = desc.calculate_offset(make_tuple(0, col)); - printf("[DEVICE] %s[%3d] = %5.2f", - name, - col, - ck_tile::type_convert(data[offset])); - for(int row = 1; row < num_rows; ++row) - { - printf(", "); - offset = desc.calculate_offset(make_tuple(row, col)); - printf("%5.2f", ck_tile::type_convert(data[offset])); - } - printf("\n"); - } - } - }; - - [[maybe_unused]] auto print_lds_1d = [&](auto 
lds_tile_window, const char* name) { - const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{}); - - auto desc = lds_tile_window.get_bottom_tensor_view().desc_; - auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; - - int offset = desc.calculate_offset(make_tuple(0)); - printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert(data[offset])); - for(int e = 1; e < num_elems; ++e) - { - printf(", "); - offset = desc.calculate_offset(make_tuple(e)); - printf("%5.2f", ck_tile::type_convert(data[offset])); - } - printf("\n"); - }; - // K_mem_su_ld_insts = 1 for 32 x 128 // V_mem_su_ld_insts = 1 for 128 x 32 constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access(); constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); auto K_mem_load = [&](auto k_lds_write_idx) { - async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); + async_load_tile(k_lds_window_store(k_lds_write_idx), k_dram_window); /// FIXME: use the future-predicting method to move the window // move K tile windows @@ -702,7 +596,7 @@ struct BlockFmhaFwdV3Pipeline }; auto V_mem_load = [&](auto v_lds_write_idx) { - async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); + async_load_tile(v_lds_window_store(v_lds_write_idx), v_dram_window); /// FIXME: use the future-predicting method to move the window move_tile_window(v_dram_window, {kK1, 0}); @@ -735,24 +629,8 @@ struct BlockFmhaFwdV3Pipeline auto fmha_alu0 = [&](auto sp_reg_idx) { m_old = m; // m{j-1} - static_assert(m.thread_buf_.size() == 1, - "assuming that each thread holds 1 rowmax value"); - auto m_latest = block_tile_reduce( - sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]); -#if defined(__gfx950__) - // assuming that we are using 32x32 mfma - int32x2_t swapped_regs = - __builtin_amdgcn_permlane32_swap(bit_cast(m_latest.thread_buf_[0]), - bit_cast(m_latest.thread_buf_[0]), - false, - false); - /// TODO: eliminate 2 redudant v_max_f32 instructions generated by the compiler - m_latest.thread_buf_[0] = f_max(bit_cast(swapped_regs.x), - bit_cast(swapped_regs.y)); -#else - block_tile_reduce_sync(m_latest, f_max, bool_constant{}); -#endif - m = m_latest; + block_tile_reduce(m, sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max); + block_tile_reduce_sync(m, f_max, bool_constant{}, bool_constant{}); constexpr auto p_spans = std::decay_t::get_distributed_spans(); @@ -771,7 +649,8 @@ struct BlockFmhaFwdV3Pipeline } }); }); - /// TODO: move some fmha_alu1() code here if necessary + /// NOTE: moving exp2(sp_delta) here was explored and reverted (~1.1% regression). + /// See session.md for details. 
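+            // For context, fmha_alu0 is the running-max half of online softmax;
+            // roughly, in flash-attention terms:
+            //   m_new = max(m_old, rowmax(S_j))            (reduction above)
+            //   p     = exp2(scale * (s - m_new))          (element sweep above)
+            // while the matching o_acc rescale by exp2(scale * (m_old - m_new))
+            // (sp_delta) is issued later, in fmha_alu1 / fmha_alu_D_upd, so that
+            // its VALU work lands in the phase the scheduler budgets for it.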
}; auto fmha_alu1 = [&](auto sp_reg_idx) { @@ -790,20 +669,7 @@ struct BlockFmhaFwdV3Pipeline sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - static_assert(rowsum_p.thread_buf_.size() == 1, - "assuming that each thread holds 1 rowsum value"); -#if defined(__gfx950__) - // assuming that we are using 32x32 mfma - int32x2_t swapped_regs = - __builtin_amdgcn_permlane32_swap(bit_cast(rowsum_p.thread_buf_[0]), - bit_cast(rowsum_p.thread_buf_[0]), - false, - false); - rowsum_p.thread_buf_[0] = f_sum(bit_cast(swapped_regs.x), - bit_cast(swapped_regs.y)); -#else - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); -#endif + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}, bool_constant{}); // l{j} /// Note: The compiler keeps moving the following instructions elsewhere because 'l' @@ -845,12 +711,26 @@ struct BlockFmhaFwdV3Pipeline sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } - else + else if constexpr(std::is_same_v) { auto casted = ck_tile::cvt_pk_bf16_f32(x, y); sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } + else if constexpr(std::is_same_v) + { + // Use asm volatile wrapper to prevent code sinking + // v_cvt_pk_fp8_f32 packs two FP8 into lower 16 bits of 32-bit result + uint32_t packed = detail::cvt_pk_fp8_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = + bit_cast(static_cast(packed & 0xFF)); + sp(sp_reg_idx).p.thread_buf_[idx + 1] = + bit_cast(static_cast((packed >> 8) & 0xFF)); + } + else + { + static_assert(false, "unsupported data type for P"); + } }); /// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly @@ -907,7 +787,14 @@ struct BlockFmhaFwdV3Pipeline } }; - auto fmha_alu_D_upd = [&] { + // Number of o_acc registers rescaled with unpacked (scalar) v_mul_f32 before the + // scheduler, so the compiler can interleave them with MFMA tail slots. The remaining + // registers are rescaled with packed v_pk_mul_f32 (asm volatile, invisible to the + // scheduler) after the scheduler. Set to 0 to use packed multiply for all registers + // beyond fmha_alu_D_reg_cnt; increase to feed the scheduler more visible VALU work. + constexpr index_t num_unpack_insts = 0; + fp32x2_t pk_o_acc_scale; + auto fmha_alu_D_upd_unpack = [&] { o_acc_scale = [&] { if constexpr(kHasLogitsSoftCap) { @@ -919,28 +806,20 @@ struct BlockFmhaFwdV3Pipeline } }(); - fp32x2_t pk_o_acc_scale; + static_assert(num_unpack_insts % 2 == 0 && + (fmha_alu_D_reg_cnt + num_unpack_insts) <= o_acc.thread_buf_.size()); + static_for{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); pk_o_acc_scale.x = o_acc_scale; pk_o_acc_scale.y = o_acc_scale; + }; - static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0); -#if CK_TILE_DISABLE_PACKED_FP32 - static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size()); - static_for{}( - [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); -#endif - - constexpr auto issued_D_reg_cnt = -#if CK_TILE_DISABLE_PACKED_FP32 - fmha_alu_D_reg_cnt + 2 -#else - fmha_alu_D_reg_cnt -#endif - ; + auto fmha_alu_D_upd_pack = [&] { + constexpr index_t issued_unpack_insts = fmha_alu_D_reg_cnt + num_unpack_insts; /// NOTICE: Use inline asm v_pk_mul_f32 to reduce latency. The fmha_alu_D_upd() call /// should be placed at the end of a phase. 
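+        /// For reference, v_pk_mul_f32 multiplies a packed pair of fp32 lanes in
+        /// one instruction ({x0*y0, x1*y1}), i.e. two v_mul_f32 in one issue slot.
+        /// Keeping it behind asm volatile hides it from the scheduler, so the only
+        /// schedulable VALU work from this update is the scalar multiplies issued
+        /// by fmha_alu_D_upd_unpack ahead of the phase-2/phase-3 schedules.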
- // update partial o_acc after [issued_D_reg_cnt] - static_for{}([&](auto idx) { + // update partial o_acc after [issued_unpack_insts] + static_for{}([&](auto idx) { fp32x2_t input; input.x = o_acc.thread_buf_[idx]; input.y = o_acc.thread_buf_[idx + 1]; @@ -952,6 +831,11 @@ struct BlockFmhaFwdV3Pipeline }); }; + auto fmha_alu_D_upd = [&] { + fmha_alu_D_upd_unpack(); + fmha_alu_D_upd_pack(); + }; + auto fmha_mask = [&](auto sp_reg_idx) { if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -996,7 +880,7 @@ struct BlockFmhaFwdV3Pipeline auto memV = number<0>{}; auto memK = number<1>{}; - using Scheduler = CoreLoopScheduler; + using Scheduler = CoreLoopScheduler; auto iteration = [&](auto pi) { auto xdl_SP_p01_reg_idx = number<1>{} - pi; @@ -1030,7 +914,7 @@ struct BlockFmhaFwdV3Pipeline { ASM_MARKER("phase0 Wave0-3 (pi=1)"); } - s_waitcnt_lgkmcnt<0>(); + s_waitcnt(); __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p01_reg_idx, gemm0); fmha_alu1(xdl_SP_p23_reg_idx); @@ -1040,7 +924,7 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); // phase1 ASM_MARKER("phase1 Wave0-3"); - s_waitcnt_vmcnt(); + s_waitcnt(); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); @@ -1051,22 +935,22 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); // phase2 ASM_MARKER("phase2 Wave0-3"); - s_waitcnt_lgkmcnt<0>(); + s_waitcnt(); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); asm volatile("s_nop 0"); __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p23_reg_idx, gemm1); - + fmha_alu_D_upd_unpack(); Scheduler::schedule(cl_p, number<2>{}); __builtin_amdgcn_sched_barrier(0); - fmha_alu_D_upd(); + fmha_alu_D_upd_pack(); __builtin_amdgcn_sched_barrier(0); // phase3 ASM_MARKER("phase3 Wave0-3"); - s_waitcnt_vmcnt(); + s_waitcnt(); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); @@ -1101,7 +985,7 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); // phase1 ASM_MARKER("phase1 Wave4-7"); - s_waitcnt(); + s_waitcnt(); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); @@ -1130,17 +1014,17 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); // phase3 ASM_MARKER("phase3 Wave4-7"); - s_waitcnt(); + s_waitcnt(); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); asm volatile("s_nop 1"); __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p23_reg_idx, gemm1); - + fmha_alu_D_upd_unpack(); Scheduler::schedule(cl_p, number<3>{}); __builtin_amdgcn_sched_barrier(0); - fmha_alu_D_upd(); + fmha_alu_D_upd_pack(); } return result; }; @@ -1153,18 +1037,18 @@ struct BlockFmhaFwdV3Pipeline if(1 < num_total_loop) { - s_waitcnt_vmcnt(); + s_waitcnt(); } else { - s_waitcnt_vmcnt<0>(); + s_waitcnt<0>(); } __builtin_amdgcn_s_barrier(); V_lds_load(V_lds_rd_idx); fmha_alu1(ps_pi); - s_waitcnt_lgkmcnt<0>(); + s_waitcnt(); auto xdl_SP_p23_reg_idx = ps_pi; gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{}); @@ -1176,12 +1060,12 @@ struct BlockFmhaFwdV3Pipeline // (1) load K0 to LDS & VGPR K_mem_load(number<0>{}); // mem_K0 - s_waitcnt_vmcnt<0>(); + s_waitcnt<0>(); __builtin_amdgcn_s_barrier(); K_lds_load(number<0>{}); // lds_K0 - s_waitcnt_lgkmcnt<0>(); + s_waitcnt(); __builtin_amdgcn_s_barrier(); // (2) prefetch K1 and V0 to LDS in parallel with GEMM0 @@ -1209,11 +1093,12 @@ struct BlockFmhaFwdV3Pipeline if(2 < num_total_loop) { 
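+            // Third K tile: only needed when the loop body runs at least three
+            // times. K0 has already been consumed into VGPRs and K1/V0 are in
+            // flight, so issuing mem_K2 here (reusing LDS buffer 0) preserves one
+            // K tile of lookahead for the first core-loop iteration.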
K_mem_load(number<0>{}); // mem_K2 - - s_waitcnt_vmcnt(); - __builtin_amdgcn_s_barrier(); } + // drain K1 + V0 async loads before core_loop reads K1 from LDS + s_waitcnt(); + __builtin_amdgcn_s_barrier(); + ASM_MARKER("end pre-stage"); } @@ -1291,16 +1176,20 @@ struct BlockFmhaFwdV3Pipeline typename LSEDramBlockWindowTmp, typename AttentionVariantParams, typename BlockIndices> - CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - FmhaMask mask, - float scale_s, - const AttentionVariant& variant, - const AttentionVariantParams& variant_params, - const BlockIndices& block_indices, - void* smem_ptr) const + CK_TILE_DEVICE auto + operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + KDataType* __restrict__ smem_k0, + KDataType* __restrict__ smem_k1, + VDataType* __restrict__ smem_v0, + VDataType* __restrict__ smem_v1) const { using namespace ck_tile; @@ -1320,7 +1209,10 @@ struct BlockFmhaFwdV3Pipeline variant, variant_params, block_indices, - smem_ptr); + smem_k0, + smem_k1, + smem_v0, + smem_v1); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp index ce097b6741..a6b21ac555 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - if constexpr(std::is_same_v && - std::is_same_v && + constexpr auto warp_gemm = [] { + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + // Use SwizzleB variant to get 8 contiguous K positions per lane, + // matching the V tile distribution for PV GEMM + return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) { /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here @@ -310,9 +318,8 @@ struct BlockFmhaV3PipelineDefaultPolicy static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords - template - CK_TILE_DEVICE static constexpr auto - MakeKLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + template + CK_TILE_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor() { using namespace ck_tile; @@ -323,7 +330,6 @@ struct BlockFmhaV3PipelineDefaultPolicy constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); - [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK(); // this is for lds 
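+        // kKLdsPadInBytes (16 bytes for K, declared above) is folded into the kPad
+        // row padding below. The usual motivation for such padding is to stagger
+        // row start addresses across LDS banks so wavefront-wide accesses avoid
+        // bank conflicts -- presumably the intent here as well.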
constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = kKLdsPadInBytes / @@ -339,31 +345,28 @@ struct BlockFmhaV3PipelineDefaultPolicy constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( - make_tuple(number{}, // n0 - number{}, // n1 - number{}, // n2 - number{}, // k0 - number{}), // k1 - make_tuple(number{}, - number{}, - number{}, - number{}, - number<1>{}), - number()>{}, - number{}, - number<1>{}); + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); - // TODO this layout is hard coded, and will be used in async copy buffer view load - // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + // CRITICAL: Must match Load descriptor merge pattern (NumIssues, LaneGroups, NumWarps) constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( k_lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_merge_transform(make_tuple( - number{}, number{}, number{}))), - make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + make_tuple(make_merge_transform(make_tuple( + number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return k_lds_block_desc_issues_warps_lanes; } @@ -458,9 +461,8 @@ struct BlockFmhaV3PipelineDefaultPolicy return max(SingleKSize, SingleVSize); } - template - CK_TILE_DEVICE static constexpr auto - MakeVLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + template + CK_TILE_DEVICE static constexpr auto MakeVLdsStoreBlockDescriptor() { using namespace ck_tile; @@ -471,7 +473,6 @@ struct BlockFmhaV3PipelineDefaultPolicy constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); - [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK(); // this is for lds constexpr index_t KVector = GetAlignmentV(); // this is for global load constexpr index_t kPad = kVLdsPadInBytes / @@ -487,31 +488,27 @@ struct BlockFmhaV3PipelineDefaultPolicy constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( - make_tuple(number{}, // n0 - number{}, // n1 - number{}, // n2 - number{}, // k0 - number{}), // k1 - make_tuple(number{}, - number{}, - number{}, - number{}, - number<1>{}), - number<(IBuf + 2) * GetSingleSmemElementSpaceSize()>{}, - number{}, - number<1>{}); + constexpr auto v_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); - // TODO this layout is hard coded, and will be used in async copy buffer view load - // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( 
v_lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_merge_transform(make_tuple( - number{}, number{}, number{}))), - make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + make_tuple(make_merge_transform(make_tuple( + number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return v_lds_block_desc_issues_warps_lanes; } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 108afd9b1c..0ac8efbc8d 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -213,38 +213,38 @@ struct BlockGemmARegBRegCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor - using c_iter_idx = std:: - conditional_t, sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + using c_iter_idx = + std::conditional_t, sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, 
c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } @@ -323,73 +323,69 @@ struct BlockGemmARegBRegCRegV1 // hot loop with MX scaling and pre-packed int32_t scales: // Outer loops iterate over pack groups (scale tile indices) - static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { - static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) { - // Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t) - auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto impack = number{}]>{}; + // Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t) + auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - // Get pre-packed int32_t B scale - auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); + static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { + // Get pre-packed int32_t B scale + auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); - // Inner loops: issue MFMAs within the pack group using OpSel - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - constexpr auto kIter = ikpack * KXdlPack + ikxdl; - constexpr auto mIter = impack * MXdlPack + imxdl; + // Inner loops: issue MFMAs within the pack group using OpSel + static_ford>{}([&](auto jj) { + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto kIter = ikpack * KXdlPack + ikxdl; + constexpr auto mIter = impack * MXdlPack + imxdl; - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = - a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - // OpSel for A: selects byte within packed int32_t - constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; + // OpSel for A: selects byte within packed int32_t + constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto nIter = inpack * NXdlPack + inxdl; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto nIter = inpack * NXdlPack + inxdl; - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, - b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, 
b_warp_y_lengths)); - // OpSel for B: selects byte within packed int32_t - constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; + // OpSel for B: selects byte within packed int32_t + constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; - // read C warp tensor from C block tensor - using c_iter_idx = std::conditional_t, - sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + using c_iter_idx = std::conditional_t, + sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM with MX scaling using pre-packed scale and OpSel - WarpGemm{}.template operator()(c_warp_tensor, - a_warp_tensor, - b_warp_tensor, - a_scale_packed, - b_scale_packed); + // warp GEMM with MX scaling using pre-packed scale and OpSel + WarpGemm{}.template operator()(c_warp_tensor, + a_warp_tensor, + b_warp_tensor, + a_scale_packed, + b_scale_packed); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); }); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp index 960a685792..a559206b98 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp @@ -250,74 +250,74 @@ struct BlockGemmARegBRegCRegV2 // hot loop: if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - 
CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; + static_ford>{}([&](auto mnk) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + 
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); } } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp index 3302d149ca..a7f1cef519 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp @@ -109,13 +109,13 @@ struct BlockGemmARegBSmemCRegOneWarpV1 NIterPerWarp> b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -141,35 +141,35 @@ struct BlockGemmARegBSmemCRegOneWarpV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp 
b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index 14d59ff373..0118258668 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -116,13 +116,13 @@ struct BlockGemmARegBSmemCRegV1 NIterPerWarp> b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -148,35 +148,35 @@ struct BlockGemmARegBSmemCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 0aa7509b1e..d292cade24 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ 
-103,13 +103,13 @@ struct BlockGemmARegBSmemCRegV2 NIterPerWarp> b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -135,36 +135,36 @@ struct BlockGemmARegBSmemCRegV2 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp index 2ba01d91c5..9ffc9f2070 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp @@ -90,13 +90,13 @@ struct BlockGemmARegBSmemCRegV2R1 NIterPerWarp> b_warp_windows; - static_for<0, 
NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); // check C-block-distribution @@ -126,43 +126,43 @@ struct BlockGemmARegBSmemCRegV2R1 NIterPerWarp> b_warp_tensors; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); - }); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); }); // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = b_warp_tensors(nIter)(kIter); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + // read B warp tensor from B Block window + const auto b_warp_tensor = b_warp_tensors(nIter)(kIter); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp index b1223f8755..2b750c75b3 100644 --- 
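// v2r1's distinguishing step, shown in the hunk above: every (n, k) B warp
// tile is read out of LDS once, ahead of the hot loop, so the MFMA loop then
// consumes b_warp_tensors(nIter)(kIter) straight from registers with no
// per-iteration ds_read. A minimal sketch of the hoisting, under the same
// static_ford/multi-index assumption as before:
//
//   static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
//       constexpr auto kIter = number<kn[number<0>{}]>{};
//       constexpr auto nIter = number<kn[number<1>{}]>{};
//       b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
//   });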
a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -116,13 +116,13 @@ struct BlockGemmASmemBRegCRegV1 MIterPerWarp> a_warp_windows; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -148,34 +148,34 @@ struct BlockGemmASmemBRegCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A Block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index 6eedfabaf8..32776b786d 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -85,13 +85,13 @@ struct BlockGemmASmemBSmemCRegV1 MIterPerWarp> a_warp_windows; - static_for<0, 
MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + static_ford>{}([&](auto mk) { + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -120,13 +120,13 @@ struct BlockGemmASmemBSmemCRegV1 NIterPerWarp> b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + static_ford>{}([&](auto nk) { + constexpr auto nIter = number{}]>{}; + constexpr auto kIter = number{}]>{}; + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); #endif @@ -138,31 +138,31 @@ struct BlockGemmASmemBSmemCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp index 5dde03912a..9ad8c4cc97 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp 
@@ -165,61 +165,60 @@ struct BlockGemmMxARegBSmemCRegV1 uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - auto b_warp_window = b_warp_window_tmp; - move_tile_window( - b_warp_window, - {nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)}); - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_window); + static_ford>{}([&](auto kn) { + constexpr auto kIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + auto b_warp_window = b_warp_window_tmp; + move_tile_window( + b_warp_window, + {nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)}); + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_window); - BScaleWarpTensor b_scale_warp_tensor; + BScaleWarpTensor b_scale_warp_tensor; - b_scale_warp_tensor.get_thread_buffer() = - b_scale_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, - b_scale_warp_y_index_zeros), - merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths)); + b_scale_warp_tensor.get_thread_buffer() = b_scale_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + b_scale_warp_y_index_zeros), + merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths)); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - AScaleWarpTensor a_scale_warp_tensor; + AScaleWarpTensor a_scale_warp_tensor; - a_scale_warp_tensor.get_thread_buffer() = - a_scale_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_scale_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths)); + a_scale_warp_tensor.get_thread_buffer() = + a_scale_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_scale_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}.template operator()<0, 0>( - c_warp_tensor, - a_warp_tensor, - b_warp_tensor, - int32_t(a_scale_warp_tensor.get_thread_buffer()[0]), - int32_t(b_scale_warp_tensor.get_thread_buffer()[0])); + // warp GEMM + WarpGemm{}.template operator()<0, 0>( + c_warp_tensor, + a_warp_tensor, + b_warp_tensor, + int32_t(a_scale_warp_tensor.get_thread_buffer()[0]), + int32_t(b_scale_warp_tensor.get_thread_buffer()[0])); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - 
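// For reference: the int32_t operands handed to WarpGemm above carry e8m0
// block-scale exponents (one shared scale per MX element group). A scalar
// model of the scaling, assuming the usual MX convention value = 2^(e - 127);
// mx_apply_scale is illustrative, not a ck_tile API:

#include <cmath>
#include <cstdint>

inline float mx_apply_scale(float elem, std::uint8_t e8m0)
{
    // effective element = stored element * 2^(e8m0 - 127)
    return elem * std::ldexp(1.0f, static_cast<int>(e8m0) - 127);
}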
merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index f7f5cd33db..2b64f6e340 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -239,39 +239,39 @@ struct BlockUniversalGemmAsBsCr "C block tensor data type!"); // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } @@ -392,63 +392,59 @@ struct BlockUniversalGemmAsBsCr 0); // Prevents instruction reordering across this boundary } - static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { - static_for<0, MIterPerWarp, 
1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; + static_ford<sequence<KInnerLoopIter, MIterPerWarp>>{}([&](auto km) { + constexpr auto kInnerIter = number<km[number<0>{}]>{}; + constexpr auto mIter = number<km[number<1>{}]>{}; + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, - b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // The block_sync_lds() here performs double duty: - // A) safeguards against data hazards, because the barrier from - // blockwise_gemm is moved here; B) reduces VMEM FIFO congestion - // by applying small delays to different wavefronts. It is - // performed near the end of the MAC cluster to minimize the - // lgkmcnt penalty. - if constexpr(kIter.value == KRepeat - 1 && - kInnerIter.value == KInnerLoopIter - 1 && - mIter.value == MIterPerWarp - 1 && - nIter.value == NIterPerWarp - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // The block_sync_lds() here performs double duty: + // A) safeguards against data hazards, because the barrier from + // blockwise_gemm is moved here; B) reduces VMEM FIFO congestion + // by applying small delays to different wavefronts. It is + // performed near the end of the MAC cluster to minimize the + // lgkmcnt penalty. + if constexpr(kIter.value == KRepeat - 1 && + kInnerIter.value == KInnerLoopIter - 1 && + mIter.value == MIterPerWarp - 1 && + nIter.value == NIterPerWarp - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), +
c_warp_tensor.get_thread_buffer()); - if constexpr(kInnerIter.value == 0 && mIter.value == 0 && - nIter.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } - }); + if constexpr(kInnerIter.value == 0 && mIter.value == 0 && nIter.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp index 4fc180b42b..45602f3064 100644 --- a/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp +++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp @@ -156,55 +156,54 @@ struct BlockWeightPreshuffleASmemBRegCReg uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - BWarpTensor b_warp_tensor; - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + BWarpTensor b_warp_tensor; + CWarpTensor c_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, - typename sequence_split::right_type{}), - merge_sequences( - sequence<1, 1>{}, - typename sequence_split::right_type{})); + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + typename sequence_split::right_type{}), + merge_sequences( + sequence<1, 1>{}, + typename sequence_split::right_type{})); - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}( - c_warp_tensor, preloaded_a_warp_tensor(number{}), b_warp_tensor); + // warp GEMM + WarpGemm{}( + c_warp_tensor, preloaded_a_warp_tensor(number{}), b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - - load_tile(preloaded_a_warp_tensor(number{}), - a_load_windows[number{}][number{}]); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == 
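// The fencing idiom from the hunk above, isolated. sched_barrier(0) is a full
// scheduling fence (no instruction may be moved across it), so the pair pins
// block_sync_lds() at exactly this point in the MFMA stream; s_setprio(1)
// after the first tile then raises the wave's issue priority for the rest of
// the MAC cluster:
//
//   __builtin_amdgcn_sched_barrier(0); // nothing crosses this marker
//   block_sync_lds();                  // LDS producer/consumer barrier
//   __builtin_amdgcn_sched_barrier(0);
//   __builtin_amdgcn_s_setprio(1);     // bias issue arbitration toward this wave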
MIter_2nd_last)) - { - block_sync_lds(); - } + __builtin_amdgcn_sched_barrier(0x7F6); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + + load_tile(preloaded_a_warp_tensor(number{}), + a_load_windows[number{}][number{}]); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); } }; diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp index 49c26fab6c..08a7e7a3ea 100644 --- a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp @@ -88,28 +88,28 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_ford>{}([&](auto km) { + constexpr auto kIter = number{}]>{}; + constexpr auto mIter = number{}]>{}; + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter)); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index e37af2ef5f..c2ddaa2730 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -369,6 +369,13 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma>>; +template +using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed = + WarpGemmImpl, + 2, + AttrNumAccess>>; + using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl, 2>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 94e0494aac..f59bd61db7 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ 
b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -170,6 +170,8 @@ template struct Dispatcher struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index a068001482..94fabe6f65 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -210,45 +210,45 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase c_acc; auto zero_accumulators = [&] { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) { - c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; - }); // make sure WG::CWarpTensor exposes a clear/zero + static_ford>{}( + [&](auto mni) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + constexpr auto i = number{}]>{}; + c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; }); - }); }; static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { zero_accumulators(); - static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // warp GEMM - WG{}(c_acc(mIter)(nIter), - a_warp_tensor(number{}), - b_warp_tensor(nIter)(number{})); - }); - __builtin_amdgcn_sched_barrier(0x7F6); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - - load_and_convert_tile( - a_warp_tensor(number{}), - a_warp_windows(number{})(number{})); - } - // barrier - // Could be deleted - if constexpr((mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + static_ford>{}([&](auto km) { + constexpr auto kIterInQScale = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // warp GEMM + WG{}(c_acc(mIter)(nIter), + a_warp_tensor(number{}), + b_warp_tensor(nIter)(number{})); }); + __builtin_amdgcn_sched_barrier(0x7F6); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + + load_and_convert_tile( + a_warp_tensor(number{}), + a_warp_windows(number{})(number{})); + } + // barrier + // Could be deleted + if constexpr((mIter == 
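// Intended consumption of the two dispatcher entries added above (their
// specialization parameters are elided in this view, so the argument list
// below is illustrative only): a pipeline needing an fp8 x fp8 -> f32
// 32x32x32 MFMA with the transposed C distribution resolves it through the
// dispatcher instead of naming the warp GEMM type directly.
//
//   using WG = typename Dispatcher<fp8_t, fp8_t, float, 32, 32, 32,
//                                  /*TransposeC=*/true, ...>::Type;
//   // -> WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>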
MIter_2nd_last)) + { + block_sync_lds(); + } }); static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { AQPickerCommon aq_picker(aq_block_tensor); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index d2cfaca7b7..1ee3b227b7 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -127,105 +127,103 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg c_acc; auto zero_accumulators = [&] { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) { - c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; - }); // make sure WG::CWarpTensor exposes a clear/zero + static_ford>{}( + [&](auto mni) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + constexpr auto i = number{}]>{}; + c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; }); - }); }; static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { zero_accumulators(); - static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // warp GEMM - WG{}(c_acc(mIter)(nIter), - a_warp_tensor(number{}), - b_warp_tensor(nIter)(number{})); - }); - __builtin_amdgcn_sched_barrier(0x7F6); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows(number{})(number{})); - } - // barrier - // Could be deleted - if constexpr((mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_ford>{}([&](auto km) { + constexpr auto kIterInQScale = number{}]>{}; + constexpr auto mIter = number{}]>{}; + constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; + // warp GEMM + WG{}(c_acc(mIter)(nIter), + a_warp_tensor(number{}), + b_warp_tensor(nIter)(number{})); + }); + __builtin_amdgcn_sched_barrier(0x7F6); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows(number{})(number{})); + } + // barrier + // Could be deleted + if constexpr((mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + static_ford>{}([&](auto mn) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + constexpr auto tbuf_offset = + number{}, c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; - if constexpr(BPreshuffleQuant) + if constexpr(BPreshuffleQuant) + { + constexpr index_t reg_offset = nIter; + auto pull_from_lane = 
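// Worked example of the A-preload indexing above, which treats (kIter, mIter)
// as the linear index l = kIter * MIterPerWarp + mIter and prefetches the tile
// consumed m_preload steps later. With MIterPerWarp = 4 and m_preload = 2, at
// (kIter, mIter) = (0, 3):
//
//   AmIter = (3 + 2) % 4     = 1
//   AkIter = 0 + (3 + 2) / 4 = 1   // linear index 1*4 + 1 = 5 = l + m_preload
//
// so the LDS read for tile (k = 1, m = 1) is issued while the MFMAs for
// (k = 0, m = 3) are still in flight.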
(__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) { - constexpr index_t reg_offset = nIter; - auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale; - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - // cross lane ops - uint32_t scale_reg_dword; - - if constexpr(std::is_same_v) - { - scale_reg_dword = ck_tile::bit_cast(scale_reg); - } - else - { - scale_reg_dword = static_cast(scale_reg); - } - - // cross lane ops to get the value of scale_reg. - int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( - pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); - - float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg); - - static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { - auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; - const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref = c_ref + acc_val * scale_reg_f; - }); + scale_reg_dword = ck_tile::bit_cast(scale_reg); } else { - index_t reg_offset = [&]() { - if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN)) - { - return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * - KPerBlockBQ + - kQScale; - } - else - { - return nIter * KPerBlockBQ + kQScale; - } - }(); - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float scale_reg_f = cvt_scale_to_fp32(scale_reg); - - static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { - auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; - const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref = c_ref + acc_val * scale_reg_f; - }); + scale_reg_dword = static_cast(scale_reg); } - }); + + // cross lane ops to get the value of scale_reg. 
+ int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * scale_reg_f; + }); + } + else + { + index_t reg_offset = [&]() { + if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float scale_reg_f = cvt_scale_to_fp32(scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * scale_reg_f; + }); + } }); }); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 24d9f9a1e5..cc65d213f1 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -290,121 +290,115 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase constexpr auto warp_size = get_warp_size(); // hot loop: - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto mn) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + CWarpTensor c_warp_tensor; - static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { - static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = - a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - if constexpr(kIterInQScale == 0) - { - c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); - } - else - { - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - } - }); - - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - // a_scale - AQPickerCommon aq_picker( - 
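// ds_bpermute, as used above, in isolation: every lane names a source lane
// and receives that lane's 32-bit value; the first operand is a byte address
// into the lane file, hence the << 2. pull_scale_from is a hypothetical
// helper, not part of this change:

CK_TILE_DEVICE int pull_scale_from(int src_lane, uint32_t my_scale_bits)
{
    // lane i of the result = my_scale_bits as held by lane src_lane(i)
    return __builtin_amdgcn_ds_bpermute(src_lane << 2,
                                        __builtin_bit_cast(int, my_scale_bits));
}

// Here src_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale routes
// each output lane to the lane that owns the B scale for its N column.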
aq_block_tensor); - - if constexpr(BPreshuffleQuant) + if constexpr(kIterInQScale == 0) { - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::BQuantGroupSize::kN > - (NWarp * WarpGemm::kN) && - Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN) - { - return kQScale; - } - else - { - return nIter; - } - }(); - - auto pull_from_lane = - (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; - - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - // cross lane ops - uint32_t scale_reg_dword; - - if constexpr(std::is_same_v) - { - scale_reg_dword = ck_tile::bit_cast(scale_reg); - } - else - { - scale_reg_dword = static_cast(scale_reg); - } - - // cross lane ops to get the value of scale_reg. - int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( - pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); - - float b_scale_reg_f = - Base::cvt_scale_to_fp32( - gathered_scale_reg); - - static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - float a_scale_reg_f = aq_picker.template pick(); - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * - b_scale_reg_f); - }); + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); } else { - // Multiply bquant with accumulated C - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::BQuantGroupSize::kN >= - (NWarp * WarpGemm::kN)) - return (nIter * NWarp * WarpGemm::kN) / - GemmTraits::BQuantGroupSize::kN * - Traits::KQPerBlock + - kQScale; - else - { - return nIter * Traits::KQPerBlock + kQScale; - } - }(); - - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float b_scale_reg_f = - Base::cvt_scale_to_fp32(scale_reg); - - static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - float a_scale_reg_f = aq_picker.template pick(); - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * - b_scale_reg_f); - }); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); } }); + + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + // a_scale + AQPickerCommon aq_picker( + aq_block_tensor); + + if constexpr(BPreshuffleQuant) + { + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN > (NWarp * WarpGemm::kN) && + Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN) + { + return kQScale; + } + else + { + return nIter; + } + }(); + + auto pull_from_lane = + (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + // cross lane ops to get the value of scale_reg. 
+ int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float b_scale_reg_f = Base::cvt_scale_to_fp32( + gathered_scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * + b_scale_reg_f); + }); + } + else + { + // Multiply bquant with accumulated C + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN)) + return (nIter * NWarp * WarpGemm::kN) / + GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_reg_f = + Base::cvt_scale_to_fp32(scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * + b_scale_reg_f); + }); + } }); }); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 8b09530af1..64f8bc7df4 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -268,54 +268,51 @@ struct AQuantBlockUniversalGemmAsBsCr constexpr auto warp_size = get_warp_size(); // hot loop: - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - CWarpTensor c_warp_tensor; + static_ford>{}([&](auto mn) { + constexpr auto mIter = number{}]>{}; + constexpr auto nIter = number{}]>{}; + CWarpTensor c_warp_tensor; - // for every column in AQ - static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { - // for every warp corresponding to a quantization scale - static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + // for every column in AQ + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // for every warp corresponding to a quantization scale + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = - a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - if constexpr(kIterInQScale == 0) - { - c_warp_tensor = 
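// Scalar model of the group-quantized accumulation implemented in the hunks
// above: partial dot products are accumulated unscaled within each K group,
// and the A/B descales are applied once per group when folding into C.
// group_quant_dot is illustrative; the kernel operates on warp tiles:

#include <cstddef>

float group_quant_dot(const float* a, const float* b,
                      const float* a_scale, const float* b_scale,
                      std::size_t groups, std::size_t k_per_group)
{
    float c = 0.0f;
    for(std::size_t g = 0; g < groups; ++g)
    {
        float acc = 0.0f; // raw (unscaled) partial dot product for this group
        for(std::size_t k = g * k_per_group; k < (g + 1) * k_per_group; ++k)
            acc += a[k] * b[k];
        c += acc * a_scale[g] * b_scale[g]; // one scale multiply per group
    }
    return c;
}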
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
index 8b09530af1..64f8bc7df4 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
@@ -268,54 +268,51 @@ struct AQuantBlockUniversalGemmAsBsCr
         constexpr auto warp_size = get_warp_size();
 
         // hot loop:
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                CWarpTensor c_warp_tensor;
+        static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
+            constexpr auto mIter = number<mn[number<0>{}]>{};
+            constexpr auto nIter = number<mn[number<1>{}]>{};
+            CWarpTensor c_warp_tensor;
 
-                // for every column in AQ
-                static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
-                    // for every warp corresponding to a quantization scale
-                    static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
-                        constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
+            // for every column in AQ
+            static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+                // for every warp corresponding to a quantization scale
+                static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
+                    constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
 
-                        AWarpTensor a_warp_tensor;
-                        a_warp_tensor.get_thread_buffer() =
-                            a_warp_tile_.get_y_sliced_thread_data(
-                                merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
-                                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
+                    AWarpTensor a_warp_tensor;
+                    a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
 
-                        BWarpTensor b_warp_tensor;
-                        b_warp_tensor.get_thread_buffer() =
-                            b_warp_tile_.get_y_sliced_thread_data(
-                                merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
-                                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
+                    BWarpTensor b_warp_tensor;
+                    b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
+                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
 
-                        if constexpr(kIterInQScale == 0)
-                        {
-                            c_warp_tensor =
-                                WarpGemm{}(a_warp_tensor, b_warp_tensor);
-                        }
-                        else
-                        {
-                            WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
-                        }
-                    });
+                    if constexpr(kIterInQScale == 0)
+                    {
+                        c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
+                    }
+                    else
+                    {
+                        WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
+                    }
+                });
 
-                constexpr auto tbuf_offset =
-                    number<CBlockTensor{}
-                               .get_tile_distribution()
-                               .get_ys_to_d_descriptor()
-                               .calculate_offset(merge_sequences(sequence<mIter, nIter>{},
-                                                                 c_warp_y_index_zeros)) /
-                           CBlockTensor::PackedSize>{};
+                constexpr auto tbuf_offset =
+                    number<CBlockTensor{}
+                               .get_tile_distribution()
+                               .get_ys_to_d_descriptor()
+                               .calculate_offset(merge_sequences(sequence<mIter, nIter>{},
+                                                                 c_warp_y_index_zeros)) /
+                           CBlockTensor::PackedSize>{};
 
-                AQPickerCommon<decltype(aq_block_tensor), mIter, kQScale> aq_picker(
-                    aq_block_tensor);
+                AQPickerCommon<decltype(aq_block_tensor), mIter, kQScale> aq_picker(
+                    aq_block_tensor);
 
-                static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
-                    [&](auto c_row) {
-                        float scale_reg_f = aq_picker.template pick<c_row>();
+                static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}([&](auto c_row) {
+                    float scale_reg_f = aq_picker.template pick<c_row>();
 
-                        c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
-                            (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
-                    });
+                    c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
+                        (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
                 });
             });
         });
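The recurring refactor in these hunks replaces two stacked `static_for` loops with a single `static_ford<sequence<...>>`, which visits the full Cartesian product of compile-time indices with one lambda and drops one level of nesting while keeping `mIter`/`nIter` as `number<>` constants. A rough stand-alone model of the traversal, not ck_tile's actual implementation, only to show the visiting order:

```cpp
#include <cstdio>
#include <utility>

// Minimal stand-ins for ck_tile::static_for / static_ford. The real versions
// live in ck_tile and pass number<I> constants into the lambda.
template <int Begin, int End, typename F>
void static_for_model(F&& f)
{
    if constexpr(Begin < End)
    {
        f(std::integral_constant<int, Begin>{});
        static_for_model<Begin + 1, End>(std::forward<F>(f));
    }
}

// 2D version: calls f(m, n) for every pair, row-major, mirroring
// static_ford<sequence<M, N>>{} with mn[number<0>{}] and mn[number<1>{}].
template <int M, int N, typename F>
void static_ford_model(F&& f)
{
    static_for_model<0, M>([&](auto m) {
        static_for_model<0, N>([&](auto n) { f(m, n); });
    });
}

int main()
{
    // same traversal as the refactored hot loop (MIterPerWarp=2, NIterPerWarp=3)
    static_ford_model<2, 3>([](auto m, auto n) {
        printf("mIter=%d nIter=%d\n", m(), n());
    });
    return 0;
}
```

The change is behavior-preserving: the (m, n) pairs are visited in the same row-major order as the nested loops they replace.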
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
index f5900fcdec..9851fc917d 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
@@ -290,57 +290,55 @@ struct BQuantBlockUniversalGemmAsBsCr
         using SrcVectorRawType = ext_vector_t<SrcRawType, kVectorSize>;
         using DstVectorType    = ext_vector_t<DstType, kVectorSize>;
 
-        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-            static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
-                // B scale register offset
-                constexpr index_t reg_offset = [&]() {
-                    if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
-                        return ((nIter * NWarp * WarpGemm::kN) /
-                                GemmTraits::BQuantGroupSize::kN) *
-                                   Traits::KQPerBlock +
-                               kQScale;
-                    else
-                    {
-                        return nIter * Traits::KQPerBlock + kQScale;
-                    }
-                }();
+        static_ford<sequence<NIterPerWarp, Traits::QScalesPerBlockRow>>{}([&](auto nk) {
+            constexpr auto nIter   = number<nk[number<0>{}]>{};
+            constexpr auto kQScale = number<nk[number<1>{}]>{};
+            // B scale register offset
+            constexpr index_t reg_offset = [&]() {
+                if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
+                    return ((nIter * NWarp * WarpGemm::kN) / GemmTraits::BQuantGroupSize::kN) *
+                               Traits::KQPerBlock +
+                           kQScale;
+                else
+                {
+                    return nIter * Traits::KQPerBlock + kQScale;
+                }
+            }();
 
-                // Get B scale from thread buffer
-                auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
-                float b_scale_f = float(scale_reg);
+            // Get B scale from thread buffer
+            auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
+            float b_scale_f = float(scale_reg);
 
-                static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
-                    constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
-                    // Thread buffers
-                    using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data(
-                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
-                    using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data(
-                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
+            static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
+                constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
+                // Thread buffers
+                using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data(
+                    merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
+                using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data(
+                    merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
 
-                    BWarpThreadBuffer b_warp_thread_buffer;
-                    BLDSThreadBuffer b_lds_thread_buffer;
+                BWarpThreadBuffer b_warp_thread_buffer;
+                BLDSThreadBuffer b_lds_thread_buffer;
 
-                    // Load thread buffer from tile (LDS type)
-                    b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
-                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
+                // Load thread buffer from tile (LDS type)
+                b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
+                    merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
 
-                    // Apply scale to B thread buffer and cast
-                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
-                        elementwise_op(
-                            b_warp_thread_buffer.template get_as<DstVectorType>()(i),
-                            b_lds_thread_buffer.template get_as<SrcVectorRawType>()[i],
-                            b_scale_f);
-                    });
-
-                    // Store B thread buffer to tile (MMA type)
-                    b_warp_tile_.set_y_sliced_thread_data(
-                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
-                        b_warp_thread_buffer);
+                // Apply scale to B thread buffer and cast
+                static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                    elementwise_op(b_warp_thread_buffer.template get_as<DstVectorType>()(i),
+                                   b_lds_thread_buffer.template get_as<SrcVectorRawType>()[i],
+                                   b_scale_f);
                 });
+
+                // Store B thread buffer to tile (MMA type)
+                b_warp_tile_.set_y_sliced_thread_data(
+                    merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
+                    b_warp_thread_buffer);
             });
         });
     }
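The hunk above folds the B dequantization scale into the LDS-to-MMA type conversion: each short vector of B is read through `get_as<SrcVectorRawType>()`, multiplied by `b_scale_f`, and written back through `get_as<DstVectorType>()`. A minimal sketch of that per-vector step using clang's `ext_vector_type` (the four-wide `float` vectors are placeholders for the kernel's raw-source and MMA-input element types, which this hunk does not show):

```cpp
#include <cstdio>

// Stand-ins for SrcVectorRawType/DstVectorType; element types are assumptions.
typedef float src_vec_t __attribute__((ext_vector_type(4)));
typedef float dst_vec_t __attribute__((ext_vector_type(4)));

// Model of elementwise_op: dequantize one vector of B with a single scale,
// then convert to the MMA input type (the cast is trivial with float here).
void scale_and_cast(dst_vec_t& dst, const src_vec_t& src, float b_scale_f)
{
    dst = src * b_scale_f; // vector-wide multiply
}

int main()
{
    src_vec_t b_lds = {0.5f, -1.0f, 2.0f, 4.0f}; // raw B values from LDS
    dst_vec_t b_mma;                             // what the MFMA will consume
    scale_and_cast(b_mma, b_lds, 0.125f);        // b_scale_f
    for(int i = 0; i < 4; ++i)
        printf("%g\n", b_mma[i]);
    return 0;
}
```

Doing the multiply once per B element here, while B is staged for the MMA, avoids re-scaling inside the accumulation loop.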
@@ -361,113 +359,107 @@ struct BQuantBlockUniversalGemmAsBsCr
         constexpr auto warp_size = get_warp_size();
 
         // hot loop:
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                CWarpTensor c_warp_tensor;
+        static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
+            constexpr auto mIter = number<mn[number<0>{}]>{};
+            constexpr auto nIter = number<mn[number<1>{}]>{};
+            CWarpTensor c_warp_tensor;
 
-                static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
-                    static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
-                        constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
+            static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+                static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
+                    constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
 
-                        AWarpTensor a_warp_tensor;
-                        a_warp_tensor.get_thread_buffer() =
-                            a_warp_tile_.get_y_sliced_thread_data(
-                                merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
-                                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
+                    AWarpTensor a_warp_tensor;
+                    a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
 
-                        BWarpTensor b_warp_tensor;
-                        b_warp_tensor.get_thread_buffer() =
-                            b_warp_tile_.get_y_sliced_thread_data(
-                                merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
-                                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
+                    BWarpTensor b_warp_tensor;
+                    b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
+                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
 
-                        if constexpr(kIterInQScale == 0)
-                        {
-                            c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
-                        }
-                        else
-                        {
-                            WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
-                        }
-                    });
-
-                    constexpr auto tbuf_offset =
-                        number<CBlockTensor{}
-                                   .get_tile_distribution()
-                                   .get_ys_to_d_descriptor()
-                                   .calculate_offset(merge_sequences(sequence<mIter, nIter>{},
-                                                                     c_warp_y_index_zeros)) /
-                               CBlockTensor::PackedSize>{};
-
-                    if constexpr(BPreshuffleQuant)
+                    if constexpr(kIterInQScale == 0)
                     {
-                        constexpr index_t reg_offset = [&]() {
-                            if constexpr(GemmTraits::BQuantGroupSize::kN >
-                                             (NWarp * WarpGemm::kN) &&
-                                         Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
-                            {
-                                return kQScale; // prefill: one quant group per block
-                            }
-                            else
-                            {
-                                return nIter; // decode or multiple groups per warp
-                            }
-                        }();
-
-                        auto pull_from_lane =
-                            (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
-
-                        auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
-                        // cross lane ops
-                        uint32_t scale_reg_dword;
-
-                        if constexpr(std::is_same_v<BQDataType, float>)
-                        {
-                            scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
-                        }
-                        else
-                        {
-                            scale_reg_dword = static_cast<uint32_t>(scale_reg);
-                        }
-
-                        // cross lane ops to get the value of scale_reg.
-                        int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
-                            pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
-
-                        float scale_reg_f =
-                            Base::cvt_scale_to_fp32<BQDataType>(
-                                gathered_scale_reg);
-
-                        static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
-                            [&](auto c_row) {
-                                c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
-                                    (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
-                            });
+                        c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
                     }
                     else
                     {
-                        // Multiply bquant with accumulated C
-                        constexpr index_t reg_offset = [&]() {
-                            if constexpr(GemmTraits::BQuantGroupSize::kN >=
-                                             (NWarp * WarpGemm::kN))
-                                return (nIter * NWarp * WarpGemm::kN) /
-                                           GemmTraits::BQuantGroupSize::kN *
-                                           Traits::KQPerBlock +
-                                       kQScale;
-                            else
-                            {
-                                return nIter * Traits::KQPerBlock + kQScale;
-                            }
-                        }();
-
-                        auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
-                        float scale_reg_f =
-                            Base::cvt_scale_to_fp32<BQDataType>(scale_reg);
-                        static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
-                            [&](auto c_row) {
-                                c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
-                                    (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
-                            });
+                        WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
                     }
                 });
+
+                constexpr auto tbuf_offset =
+                    number<CBlockTensor{}
+                               .get_tile_distribution()
+                               .get_ys_to_d_descriptor()
+                               .calculate_offset(merge_sequences(sequence<mIter, nIter>{},
+                                                                 c_warp_y_index_zeros)) /
+                           CBlockTensor::PackedSize>{};
+
+                if constexpr(BPreshuffleQuant)
+                {
+                    constexpr index_t reg_offset = [&]() {
+                        if constexpr(GemmTraits::BQuantGroupSize::kN > (NWarp * WarpGemm::kN) &&
+                                     Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
+                        {
+                            return kQScale; // prefill: one quant group per block
+                        }
+                        else
+                        {
+                            return nIter; // decode or multiple groups per warp
+                        }
+                    }();
+
+                    auto pull_from_lane =
+                        (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
+
+                    auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
+                    // cross lane ops
+                    uint32_t scale_reg_dword;
+
+                    if constexpr(std::is_same_v<BQDataType, float>)
+                    {
+                        scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
+                    }
+                    else
+                    {
+                        scale_reg_dword = static_cast<uint32_t>(scale_reg);
+                    }
+
+                    // cross lane ops to get the value of scale_reg.
+                    int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
+                        pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
+
+                    float scale_reg_f = Base::cvt_scale_to_fp32<BQDataType>(
+                        gathered_scale_reg);
+
+                    static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
+                        [&](auto c_row) {
+                            c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
+                                (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
+                        });
+                }
+                else
+                {
+                    // Multiply bquant with accumulated C
+                    constexpr index_t reg_offset = [&]() {
+                        if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
+                            return (nIter * NWarp * WarpGemm::kN) /
+                                       GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock +
+                                   kQScale;
+                        else
+                        {
+                            return nIter * Traits::KQPerBlock + kQScale;
+                        }
+                    }();
+
+                    auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
+                    float scale_reg_f =
+                        Base::cvt_scale_to_fp32<BQDataType>(scale_reg);
+                    static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
+                        [&](auto c_row) {
+                            c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
+                                (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
+                        });
+                }
             });
         });
     }
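Both branches of the hot loop above answer the same question: which register slot holds the B scale for a given `(nIter, kQScale)` pair. When one quant group spans at least the N extent covered by all warps per iteration, consecutive `nIter` values advance through groups slowly; otherwise each `nIter` owns its own group. A host-side model of the `reg_offset` arithmetic with made-up sizes (in the kernel these come from `GemmTraits`/`Traits` and the whole computation is `constexpr`):

```cpp
#include <cstdio>

// Illustrative sizes only, not values from the CK headers.
constexpr int kNWarp       = 4;
constexpr int kWarpGemmN   = 32;
constexpr int kQuantGroupN = 128; // GemmTraits::BQuantGroupSize::kN
constexpr int kKQPerBlock  = 2;   // scales per N position along K

int reg_offset(int nIter, int kQScale)
{
    if(kQuantGroupN >= kNWarp * kWarpGemmN)
        // one quant group covers the N span of all warps per iteration
        return (nIter * kNWarp * kWarpGemmN) / kQuantGroupN * kKQPerBlock + kQScale;
    else
        // each nIter owns its own group of scales
        return nIter * kKQPerBlock + kQScale;
}

int main()
{
    for(int n = 0; n < 2; ++n)
        for(int k = 0; k < kKQPerBlock; ++k)
            printf("nIter=%d kQScale=%d -> reg_offset=%d\n", n, k, reg_offset(n, k));
    return 0;
}
```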
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
index f48e12984c..c87a02efe0 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
@@ -288,22 +288,22 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                                  MIterPerWarp>
             a_warp_windows_pong;
 
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
+        static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
+            constexpr auto mIter = number<mk[number<0>{}]>{};
+            constexpr auto kIter = number<mk[number<1>{}]>{};
+            a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
 
-                move_tile_window(a_warp_windows_ping(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
+            move_tile_window(a_warp_windows_ping(mIter)(kIter),
+                             {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
         });
 
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
+        static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
+            constexpr auto mIter = number<mk[number<0>{}]>{};
+            constexpr auto kIter = number<mk[number<1>{}]>{};
+            a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
 
-                move_tile_window(a_warp_windows_pong(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
+            move_tile_window(a_warp_windows_pong(mIter)(kIter),
+                             {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
         });
 
         // Block GEMM
@@ -366,16 +366,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
         move_tile_window(a_copy_dram_window, {0, kKPerBlock});
 
         // prefetch B
-        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+        static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
+            constexpr auto nIter = number<nk[number<0>{}]>{};
+            constexpr auto kIter = number<nk[number<1>{}]>{};
+            b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
 
-                move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
+            move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                             {nIter * flatNPerWarp, kIter * flatKPerWarp});
 
-                load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
-                                      b_flat_dram_windows(nIter)(kIter));
-            });
+            load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
+                                  b_flat_dram_windows(nIter)(kIter));
         });
 
         // move B window to next flat K
         move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -448,15 +448,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                                                bq_block_tile,
                                                a_warp_windows_ping);
 
             // prefetch B(2i+1)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
 
-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                    load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
-                                          b_flat_dram_windows(nIter)(kIter));
-                });
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
             move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
 
             aq_block_tile_2 = load_tile(aq_copy_dram_window);
@@ -473,15 +473,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
 
             // Next K
             // prefetch B(2i+2)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
 
-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                    load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
-                                          b_flat_dram_windows(nIter)(kIter));
-                });
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
             move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
 
             aq_block_tile = load_tile(aq_copy_dram_window);
@@ -520,16 +520,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
         if constexpr(TailNum == TailNumber::Even)
         {
             // prefetch B(loopK)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
 
-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
 
-                    load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
-                                          b_flat_dram_windows(nIter)(kIter));
-                });
+                load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
 
             aq_block_tile_2 = load_tile(aq_copy_dram_window);
             bq_block_tile_2 = load_tile(bq_copy_dram_window);
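The pipeline hunks above only flatten loop nests; the surrounding schedule is the usual ping/pong double buffer, where one set of B warp tensors is consumed by the block GEMM while the other is prefetched for the next K chunk. A toy model of that schedule for an even tail (`TailNumber::Even`); the `prefetch`/`gemm` helpers are placeholders, not CK APIs:

```cpp
#include <cstdio>

// buf values stand in for "which K-chunk of B currently lives in this buffer"
void prefetch(int* buf, int k_chunk) { *buf = k_chunk; }
void gemm(int buf) { printf("gemm on K-chunk %d\n", buf); }

int main()
{
    constexpr int num_k_chunks = 6; // assume an even number of K chunks
    int ping = 0, pong = 0;

    prefetch(&ping, 0);                 // prologue: fill ping
    for(int k = 0; k + 2 <= num_k_chunks; k += 2)
    {
        prefetch(&pong, k + 1);         // prefetch B(2i+1)
        gemm(ping);                     // compute on B(2i)
        if(k + 2 < num_k_chunks)
            prefetch(&ping, k + 2);     // prefetch B(2i+2)
        gemm(pong);                     // compute on B(2i+1)
    }
    return 0;
}
```

Overlapping the `load_and_convert_tile` prefetch with the GEMM on the other buffer is what hides the flat-B global loads; the `static_ford` rewrite leaves this schedule untouched.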
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
index 025ef53dbb..ff98a06662 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
@@ -275,22 +275,22 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
                                  MIterPerWarp>
             a_warp_windows_pong;
 
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
+        static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
+            constexpr auto mIter = number<mk[number<0>{}]>{};
+            constexpr auto kIter = number<mk[number<1>{}]>{};
+            a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
 
-                move_tile_window(a_warp_windows_ping(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
+            move_tile_window(a_warp_windows_ping(mIter)(kIter),
+                             {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
         });
 
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
+        static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
+            constexpr auto mIter = number<mk[number<0>{}]>{};
+            constexpr auto kIter = number<mk[number<1>{}]>{};
+            a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
 
-                move_tile_window(a_warp_windows_pong(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
+            move_tile_window(a_warp_windows_pong(mIter)(kIter),
+                             {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
         });
 
         // Block GEMM
@@ -337,16 +337,16 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
         move_tile_window(a_copy_dram_window, {0, kKPerBlock});
 
         // prefetch B
-        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+        static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
+            constexpr auto nIter = number<nk[number<0>{}]>{};
+            constexpr auto kIter = number<nk[number<1>{}]>{};
+            b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
 
-                move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
+            move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                             {nIter * flatNPerWarp, kIter * flatKPerWarp});
 
-                load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
-                                      b_flat_dram_windows(nIter)(kIter));
-            });
+            load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
+                                  b_flat_dram_windows(nIter)(kIter));
         });
 
         // move B window to next flat K
         move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -424,15 +424,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
                                 bq_block_tile,
                                 a_warp_windows_ping);
 
             // prefetch B(2i+1)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
 
-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                    load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
-                                          b_flat_dram_windows(nIter)(kIter));
-                });
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
             move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -461,15 +461,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
 
             // Next K
             // prefetch B(2i+2)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
 
-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                    load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
-                                          b_flat_dram_windows(nIter)(kIter));
-                });
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
             move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -518,16 +518,16 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
         if constexpr(TailNum == TailNumber::Even)
         {
             // prefetch B(loopK)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
 
-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
 
-                    load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
-                                          b_flat_dram_windows(nIter)(kIter));
-                });
+                load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
 
             bq_block_tile_2 = load_tile(bq_copy_dram_window);
diff --git a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
index 07d97ec4ff..da9c5c4d57 100644
--- a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
+++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
@@ -303,11 +303,11 @@ struct BlockNormReduceCrossWarpSync
         index_t local_warp_id = warp_id / num_reduce_warps;
         index_t local_smem_os = local_warp_id * num_reduce_warps;
         smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
-        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
-            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
-                all_scratch[i_0 * num_reduce_warps + i_1] =
-                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];
-            });
+        static_ford<sequence<thread_buf_size, num_reduce_warps>>{}([&](auto ii) {
+            constexpr auto i_0 = number<ii[number<0>{}]>{};
+            constexpr auto i_1 = number<ii[number<1>{}]>{};
+            all_scratch[i_0 * num_reduce_warps + i_1] =
+                smem_ptr[i_0 * num_warps + local_smem_os + i_1];
         });
 
         block_sync_lds(); // TODO: we don't need sync here
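In both cross-warp-sync reducers, partial results sit in shared memory strided by `num_warps`, and each group of `num_reduce_warps` warps gathers its slice into a dense scratch array before reducing. A host-side model of the gather indexing, with made-up sizes:

```cpp
#include <cstdio>

// Illustrative sizes; in the kernel these are compile-time constants.
constexpr int thread_buf_size  = 2;
constexpr int num_warps        = 8;
constexpr int num_reduce_warps = 4;

int main()
{
    float smem[thread_buf_size * num_warps];
    for(int i = 0; i < thread_buf_size * num_warps; ++i)
        smem[i] = static_cast<float>(i); // pretend partial sums

    int warp_id       = 5;                          // this thread's warp
    int local_warp_id = warp_id / num_reduce_warps; // which reduce group
    int local_smem_os = local_warp_id * num_reduce_warps;

    // same indexing as the kernel: strided smem -> dense scratch
    float all_scratch[thread_buf_size * num_reduce_warps];
    for(int i_0 = 0; i_0 < thread_buf_size; ++i_0)
        for(int i_1 = 0; i_1 < num_reduce_warps; ++i_1)
            all_scratch[i_0 * num_reduce_warps + i_1] =
                smem[i_0 * num_warps + local_smem_os + i_1];

    for(float v : all_scratch)
        printf("%g ", v);
    printf("\n");
    return 0;
}
```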
diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp
index ccbdb20793..abad5ed031 100644
--- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp
+++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp
@@ -631,17 +631,17 @@ struct BlockReduce2dLinearCrossWarpSync
                      IndexDataType>
             all_indices;
 
         // Load data from shared memory
-        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
-            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
-                all_scratch[i_0 * num_reduce_warps + i_1] =
-                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];
+        static_ford<sequence<thread_buf_size, num_reduce_warps>>{}([&](auto ii) {
+            constexpr auto i_0 = number<ii[number<0>{}]>{};
+            constexpr auto i_1 = number<ii[number<1>{}]>{};
+            all_scratch[i_0 * num_reduce_warps + i_1] =
+                smem_ptr[i_0 * num_warps + local_smem_os + i_1];
 
-            if constexpr(kProcessIndex)
-            {
-                all_indices[i_0 * num_reduce_warps + i_1] =
-                    smem_indices[i_0 * num_warps + local_smem_os + i_1];
-            }
-            });
+            if constexpr(kProcessIndex)
+            {
+                all_indices[i_0 * num_reduce_warps + i_1] =
+                    smem_indices[i_0 * num_warps + local_smem_os + i_1];
+            }
         });
 
         block_sync_lds(); // TODO: we don't need sync here