From 4fdcfab53d61fb4e40af67655f9855d02946a104 Mon Sep 17 00:00:00 2001 From: Christopher Millette <63608002+cgmillette@users.noreply.github.com> Date: Wed, 18 Mar 2026 08:45:22 -0600 Subject: [PATCH] [CK] Replace nested static_for with static_ford to reduce device IR function emissions [1B] (#5031) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary ### Rationale CK's GPU kernels are among the slowest files in the ROCm build, with a single translation unit taking up to 10+ minutes. Profiling with `-ftime-trace` identified nested `static_for` loops as the root cause: each nesting level multiplies the number of unique lambda IR functions the compiler must process. A 2-level nest of `static_for<0, M, 1>` / `static_for<0, N, 1>` produces M×N unique lambda types. With typical GEMM dimensions (M=16, N=4), a single nest generates 64 unique functions — and these nests appear hundreds of times across the codebase. The LLVM backend's CGSCC (Call Graph Strongly Connected Components) framework processes each function independently, so reducing function count directly reduces backend time. ### What changed 393 nested compile-time loop patterns across 73 files are converted to `static_ford`, which flattens multi-dimensional compile-time iteration into a single `static_for` with index decomposition. This eliminates 994 `static_for` nesting levels (42% reduction). Three pattern categories were converted: - **Category A**: `static_for` wrapping `static_ford` — fold outer dimension into ford - **Category B**: nested `static_ford` — merge into single higher-dimensional ford - **Category C**: nested `static_for` chains — convert to single `static_ford` ### Verification **ASM equivalence: PASS — 51/51 device assembly files identical (gfx942 + gfx1100)** | Architecture | Files compared | Largest file | Result | |---|---|---|---| | gfx942 | 36 | 386,685 lines | ALL MATCH | | gfx1100 | 15 | 47,769 lines | ALL MATCH | **Build time (Wilcoxon signed-rank test, 7 paired trials):** | Target | Pre (s) | Post (s) | Delta | p-value | |---|---|---|---|---| | bscale | 169 | 152 | **-9.8%** | 0.016 \* | | xdl_v1234 | 207 | 194 | **-6.6%** | 0.016 \* | | preshuffle | 275 | 264 | **-3.9%** | 0.016 \* | | xdl_base | 142 | 137 | **-3.2%** | 0.031 \* | **IR function counts (device backend, gfx942):** | Target | InstFunc Δ | CodeGen Δ | Compiler Δ | |---|---|---|---| | bscale | -13,043 (-8.2%) | -2,103 (-3.5%) | -10.7% | | xdl_v1234 | -9,431 (-5.7%) | +59 (+0.1%) | -5.2% | | xdl_base | -6,162 (-4.9%) | -1,141 (-2.5%) | -2.2% | | xdl_old | -3,234 (-3.7%) | -963 (-8.7%) | -3.3% | ### Value - **994 fewer `static_for` nesting levels** (-42%) across 73 files - **393 `static_ford` sites** created (from 4 pre-existing) - **Up to 9.8% compile-time reduction** on representative targets (statistically significant, p < 0.05) - **Up to 13K fewer IR function instantiations** per translation unit - Net -849 LOC from reduced indentation - **Zero ASM changes** — identical device code output verified on gfx942 and gfx1100 - All scheduling barriers, `if constexpr` guards, and MFMA/WMMA accumulation order preserved ### Files changed (73) - `block/`: 47 files (GEMM pipelines — xdlops, wmma, moe, preshuffle, blockscale variants) - `grid/`: 20 files (softmax, normalization, reduction, attention, layernorm) - `thread/`: 5 files (tensor slice transfer, contraction, GEMM dlops, reduction) - `tensor_description/`: 1 file (tensor_adaptor) ## Test plan - [x] `static_ford` tested with 21 unit tests in `test/util/unit_ford.cpp` (1D-4D, custom orders, compile-time verification) - [x] All conversions preserve iteration order, `block_sync_lds()` placement, `if constexpr` scheduling guards, and MFMA/WMMA accumulation order - [x] ASM equivalence verified: 51 device `.s` files across gfx942 + gfx1100 - [x] Build-time improvement statistically confirmed (Wilcoxon, p < 0.05, 4 targets) - [x] IR function count reduction confirmed via `-ftime-trace` on 7 targets - [x] Detection script reports 0 remaining safe patterns (180 blocked with structural reasons) - [x] Existing CI tests (GEMM, softmax, normalization, batch norm, reduction, attention) exercise all converted code paths ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Claude Opus 4.6 --- .../ck/tensor_description/tensor_adaptor.hpp | 20 +- .../blockwise_gemm_pipeline_wmmaops_base.hpp | 28 +- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 849 ++++++++-------- .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 504 ++++------ .../block/blockwise_gemm_pipeline_xdlops.hpp | 518 +++++----- ...ipeline_xdlops_b_preshuffle_dequant_v1.hpp | 237 +++-- ...ipeline_xdlops_b_preshuffle_dequant_v3.hpp | 193 ++-- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 297 +++--- ...peline_xdlops_b_preshuffle_gufusion_v1.hpp | 308 +++--- ...peline_xdlops_b_preshuffle_gufusion_v3.hpp | 649 ++++++------ ...xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp | 722 +++++++------ ...pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp | 650 ++++++------ ...pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp | 603 ++++++----- ...e_gemm_pipeline_xdlops_b_preshuffle_v1.hpp | 249 +++-- ...e_gemm_pipeline_xdlops_b_preshuffle_v2.hpp | 277 +++-- ...e_gemm_pipeline_xdlops_b_preshuffle_v3.hpp | 565 +++++------ ...line_xdlops_blockscale_b_preshuffle_v1.hpp | 593 +++++------ ...line_xdlops_blockscale_b_preshuffle_v3.hpp | 180 ++-- ...oe_blockscale_b_preshuffle_gufusion_v1.hpp | 809 +++++++-------- ...oe_blockscale_b_preshuffle_gufusion_v3.hpp | 305 +++--- ..._xdlops_moe_blockscale_b_preshuffle_v1.hpp | 593 +++++------ ..._xdlops_moe_blockscale_b_preshuffle_v3.hpp | 305 +++--- ...emm_pipeline_xdlops_mx_moe_gufusion_v3.hpp | 948 +++++++++--------- ...pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp | 948 +++++++++--------- ...ise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp | 443 ++++---- ...ise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp | 707 +++++++------ ...ockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp | 710 +++++++------ .../blockwise_gemm_pipeline_xdlops_v1.hpp | 497 +++++---- ...kwise_gemm_pipeline_xdlops_v1_ab_scale.hpp | 410 ++++---- ...ckwise_gemm_pipeline_xdlops_v1_b_scale.hpp | 114 +-- .../blockwise_gemm_pipeline_xdlops_v1_mx.hpp | 453 +++++---- .../blockwise_gemm_pipeline_xdlops_v2.hpp | 494 +++++---- ...kwise_gemm_pipeline_xdlops_v2_ab_scale.hpp | 333 +++--- ...ckwise_gemm_pipeline_xdlops_v2_b_scale.hpp | 648 ++++++------ .../blockwise_gemm_pipeline_xdlops_v3.hpp | 91 +- ...kwise_gemm_pipeline_xdlops_v3_ab_scale.hpp | 135 ++- ...ckwise_gemm_pipeline_xdlops_v3_b_scale.hpp | 91 +- .../blockwise_gemm_pipeline_xdlops_v3_mx.hpp | 710 +++++++------ ...gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp | 110 +- .../blockwise_gemm_pipeline_xdlops_v4.hpp | 480 +++++---- ...ckwise_gemm_pipeline_xdlops_v4_b_scale.hpp | 274 +++-- .../blockwise_gemm_pipeline_xdlops_v5.hpp | 39 +- .../gpu/block/blockwise_gemm_wmma.hpp | 177 ++-- .../gpu/block/blockwise_gemm_xdlops.hpp | 161 ++- .../gpu/block/blockwise_softmax.hpp | 14 +- ...cond_half_multiblock_reduce_first_half.hpp | 18 +- .../grid/epilogue_cshuffle_v3_reduce_wmma.hpp | 16 +- ...idwise_welford_second_half_layernorm2d.hpp | 20 +- ...dwise_2d_multiple_reduction_multiblock.hpp | 14 +- ...dwise_2d_multiple_reduction_threadwise.hpp | 14 +- .../grid/gridwise_2d_reduction_multiblock.hpp | 12 +- .../grid/gridwise_2d_reduction_threadwise.hpp | 42 +- ...idwise_2d_reduction_threadwise_multi_d.hpp | 12 +- ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 16 +- ...ultiple_d_softmax_gemm_xdl_cshuffle_v1.hpp | 24 +- ...atched_gemm_softmax_gemm_wmma_cshuffle.hpp | 22 +- ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 24 +- ...e_batchnorm_backward_blockwise_welford.hpp | 18 +- ...elementwise_layernorm_welford_variance.hpp | 54 +- ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 7 +- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 14 +- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 7 +- ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 32 +- .../gpu/grid/gridwise_softmax.hpp | 46 +- ...se_sparse_embeddings_forward_layernorm.hpp | 59 +- ..._embeddings_forward_layernorm_builtins.hpp | 59 +- .../gridwise_normalization_naive_variance.hpp | 136 +-- .../gridwise_normalization_splitk_2nd.hpp | 52 +- ...ridwise_normalization_welford_variance.hpp | 104 +- .../thread/reduction_functions_threadwise.hpp | 28 +- .../gpu/thread/threadwise_contraction_dl.hpp | 92 +- .../gpu/thread/threadwise_gemm_dlops_v3.hpp | 107 +- .../threadwise_tensor_slice_transfer.hpp | 155 ++- 73 files changed, 9398 insertions(+), 10247 deletions(-) diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index ee8c7ed71b..0804e07fc3 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -350,16 +350,16 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); // match hidden id - static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { - static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { - // if this low dim is bottom dim, then do id matching - if constexpr(low_dim_hidden_ids_1[idim_low_1] == - TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) - { - low_dim_hidden_ids_1_mod_(idim_low_1) = - TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; - } - }); + static_ford>{}([&](auto ii) { + constexpr auto idim_low_1 = Number{}]>{}; + constexpr auto idim_bottom_1 = Number{}]>{}; + // if this low dim is bottom dim, then do id matching + if constexpr(low_dim_hidden_ids_1[idim_low_1] == + TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) + { + low_dim_hidden_ids_1_mod_(idim_low_1) = + TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; + } }); return low_dim_hidden_ids_1_mod_; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index e41cf8c82d..fb1ae8c543 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -202,22 +202,22 @@ struct BlockwiseGemmWmmaops_pipeline_base using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc); using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_bufs(I0)(Number{}) = - a_scale_struct.scale_thread_bufs(I0)[Number{}] * - b_scale_struct.scale_thread_bufs(I0)[Number{}]; - }); + c_scale_thread_bufs(I0)(Number{}) = + a_scale_struct.scale_thread_bufs(I0)[Number{}] * + b_scale_struct.scale_thread_bufs(I0)[Number{}]; }); - }); } __device__ void Clear() diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 4f884b1df3..ee36f75164 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -224,87 +224,75 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0, I0, I0), - a_thread_buf); - if constexpr(m0 == I0) + static_ford>{}([&](auto km) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0, I0, I0, I0), + a_thread_buf); + if constexpr(m0 == I0) + { + if constexpr(ck::is_same::value == true) { - if constexpr(ck::is_same::value == true) - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, n0, k0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0, I0), - b_thread_buf); - }); - } - else - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, n0, k0, I0, I0, I0, I0), - b_block_buf, - b_scale_struct.scale_thread_bufs( - I0)[Number{}], - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0, I0), - b_thread_buf); - }); - } - } - - static_for<0, KInner, 1>{}([&](auto k_inner) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - I0, - I0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - I0, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0, I0), + b_thread_buf); }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0, I0), + b_thread_buf); + }); + } + } + + static_ford>{}([&](auto kn) { + constexpr auto k_inner = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, I0, I0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, I0, I0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }; @@ -341,20 +329,17 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto) { __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read }); - static_for<0, KRepeat, 1>{}([&](auto) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - if constexpr(m0 == I0) - { - static_for<0, NRepeat, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - } - static_for<0, KInner, 1>{}([&](auto) { - static_for<0, NRepeat, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA - }); + static_ford>{}([&](auto km) { + constexpr auto m0 = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + if constexpr(m0 == I0) + { + static_for<0, NRepeat, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); + } + static_ford>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA }); }); static_for<0, num_ds_write_inst, 1>{}([&](auto) { @@ -464,59 +449,55 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( - Number<0>{})); - }); + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); }); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); }; @@ -850,73 +831,71 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_inner) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}( + [&](auto kkmn) { + constexpr auto k0_inner = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0_inner, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0_inner, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard. - // B) reduce VMEM FIFO congestion by applying small delays to - // different wavefronts. - // It is performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0_offset + k0_inner == KRepeat - 1 && - m0 == MRepeat - 1 && n0 == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard. + // B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts. + // It is performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 && + n0 == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); - }); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_setprio(0); @@ -1249,15 +1228,15 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); // Initialize C @@ -1287,66 +1266,64 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[wmma_reg_buf] - [Number{}, - I0, - I0, - n0, - I0, - k0, - Number{}))>{}]; - }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[wmma_reg_buf] + [Number{}, + I0, + I0, + n0, + I0, + k0, + Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); block_sync_lds(); // loop prefetch copy - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); HotLoopScheduler(); @@ -1373,112 +1350,86 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}, - I0, - I0, - n0, - I0, - k0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}, I0, I0, n0, I0, k0, Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); block_sync_lds(); // tail Local Prefetch A1 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}, - I0, - I0, - n0, - I0, - k0, - Number{}))>{}]; - }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}, I0, I0, n0, I0, k0, Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle @@ -1487,49 +1438,36 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}, - I0, - I0, - n0, - I0, - k0, - Number{}))>{}]; - }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}, I0, I0, n0, I0, k0, Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); } @@ -1590,70 +1528,65 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[reg_buf] - [Number{}, - I0, - I0, - n0, - I0, - k_index, - Number{}))>{}]; - }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( - Number<0>{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[reg_buf][Number{}, + I0, + I0, + n0, + I0, + k_index, + Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); }); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); }; auto a_local_prefetch_func = [&]() { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index b8d451363e..03146f22ee 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -434,53 +434,38 @@ struct BlockwiseGemmWmmaops_pipeline_v3((i + 2) % num_loop_per_scale == 0); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmnk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, k0, I0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -506,52 +491,35 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop % num_loop_per_scale == 0); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmnk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, k0, I0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -564,52 +532,35 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmnk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, k0, I0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle // latency @@ -747,58 +698,55 @@ struct BlockwiseGemmWmmaops_pipeline_v3((i + 2) % num_loop_per_scale == 0); b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_ford>{}([&](auto kk_id) { + constexpr auto k0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale - .GetVectorTypeReference(Number<0>{})); - }); - }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); c_scale_struct.Load(a_scale_struct, b_scale_struct); @@ -825,59 +773,55 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop % num_loop_per_scale == 0); b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_ford>{}([&](auto kk_id) { + constexpr auto k0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( - Number<0>{})); - }); - }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); c_scale_struct.Load(a_scale_struct, b_scale_struct); @@ -891,58 +835,54 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( - Number<0>{})); - }); - }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_ford>{}([&](auto kk_id) { + constexpr auto k0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle // latency diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index fa0a69ed1f..812e14d73f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -512,22 +512,22 @@ struct BlockwiseGemmXdlops_pipeline_v4 // Local prefetch 1th, Fill Ping Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(I0)); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(I0), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(I0)); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); }); }); @@ -566,22 +566,22 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Pong LDS to Pong Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP1{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP1{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP1{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP1{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); }); }); @@ -594,33 +594,31 @@ struct BlockwiseGemmXdlops_pipeline_v4 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP1{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP1{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -634,22 +632,22 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Ping LDS to Ping Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP2{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP2{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP2{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP2{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); }); }); @@ -662,33 +660,31 @@ struct BlockwiseGemmXdlops_pipeline_v4 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP2{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP2{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -708,54 +704,52 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Pong LDS to Pong Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP1{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP1{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP1{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP1{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); }); }); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP1{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP1{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); TailScheduler<1>(); @@ -769,82 +763,78 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Ping LDS to Ping Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP2{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP2{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP2{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP2{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP2{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP2{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); TailScheduler<2>(); __builtin_amdgcn_sched_barrier(0); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PongP2{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PongP2{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PongP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PongP2{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // 64 v_mfma @@ -860,51 +850,49 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Pong LDS to Pong Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP1{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP1{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP1{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP1{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP1{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP1{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); TailScheduler<2>(); @@ -916,32 +904,30 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_WRITE: To Pong LDS // DS_READ: Ping LDS to Ping Reg - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP2{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP2{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // 64 v_mfma diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp index b2cca08cf5..6d45539e35 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp @@ -275,15 +275,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -318,47 +318,43 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[mfma_reg_buf] - [Number{}]; - }); - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}]; }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -390,45 +386,42 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -440,32 +433,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency @@ -473,32 +463,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp index 682760a3df..2c43998a36 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp @@ -593,39 +593,38 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == MRepeat - 1) @@ -710,30 +709,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == MRepeat - 1) @@ -781,30 +780,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value != (MRepeat - 1)) @@ -837,30 +836,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value != (MRepeat - 1)) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index 64ea0f9eab..7285685404 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -289,15 +289,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< // Local prefetch A1 block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -345,57 +345,51 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_dequant_bufs_up - [mfma_reg_buf][Number{}]; - }); - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}]; }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -439,52 +433,49 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< block_sync_lds(); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_dequant_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -502,39 +493,36 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< b_thread_dequant_bufs_up(I1)); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_dequant_bufs_up[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency @@ -542,39 +530,36 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< } else { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_dequant_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp index 31a86199c8..c5e040cfb4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -298,17 +298,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -342,60 +341,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[mfma_reg_buf] - [Number{}]; - }); - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}]; }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); @@ -425,93 +417,83 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = b_thread_bufs_up + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = b_thread_bufs_up + [I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency @@ -519,39 +501,35 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = b_thread_bufs_up + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp index 1b661b29ca..3fc204e6a2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp @@ -255,54 +255,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3= 3 ? 1 : 0; // B global read - static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + static_ford>{}([&](auto ii) { + constexpr auto i = Number{}]>{}; + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(((i < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == - buffer_load_issue_point_b)) || - ((i >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == - buffer_load_issue_point_b))) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // A global read + A local write - static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == - ds_write_issue_point)) || - (((i + buffer_load_b_stages) >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == - ds_write_issue_point))) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } - if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == - buffer_load_issue_point_a)) || - (((i + buffer_load_b_stages) >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == - buffer_load_issue_point_a))) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto i = Number{}]>{}; + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // lds synchronization, prefetch next loop local A @@ -511,17 +510,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -554,130 +552,129 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == MRepeat - 2) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -706,100 +703,100 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == (MRepeat - 2)) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else if constexpr(m0.value == MRepeat - 1) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); @@ -807,58 +804,58 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -870,53 +867,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp index 3ed035e1cc..b064889a8a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp @@ -259,25 +259,26 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< // Stage 1 // global read more - static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + static_ford< + Sequence<(num_total_stages - 2 - buffer_load_stages_more), num_mfma_perstage>>{}( + [&](auto ii) { + constexpr auto imfma = Number{}]>{}; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA if constexpr(imfma % buffer_load_issue_point_interval_less == 0) { @@ -288,22 +289,20 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } }); - }); // Stage 2, Sync // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); } @@ -463,25 +462,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< // Local prefetch 1, sync the async load __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); block_sync_lds(); - static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple( - I0, I0, Number{}, I0, Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple( - I0, I0, Number{}, k, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(I0, I0, Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); }); // Global prefetch 2 @@ -583,105 +581,97 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< static_for<0, MRepeat, 1>{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset( - make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset( - make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; - vector_type - b_scale_thread_vec_up; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - // B Gate scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - // B Up scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up( - scale_comp_buf)[Number{}]; - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + // B Gate scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + // B Up scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = b_thread_bufs + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[scale_comp_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up [scale_comp_buf][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up - [scale_comp_buf][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - // MFMA accumulation A * Up - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * Up + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == SwitchM) @@ -798,91 +788,91 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< static_for<0, MRepeat, 1>{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - // B Gate scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - // B Up scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + // B Gate scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + // B Up scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == SwitchM) { @@ -920,89 +910,89 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< static_for<0, MRepeat, 1>{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; - }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I1)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I1][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - // MFMA accumulation A * Up - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * Up + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) { @@ -1040,91 +1030,91 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< static_for<0, MRepeat, 1>{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - // B Gate scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - // B Up scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - // MFMA accumulation A * up - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + // B Gate scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + // B Up scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * up + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp index a8ceddc1a3..e0a9b43986 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp @@ -357,31 +357,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { + static_ford< + Sequence>{}( + [&](auto mkc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - }); // Initialize C c_thread_buf.Clear(); @@ -448,118 +448,107 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { + static_ford>{}([&](auto mkn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset( - make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset( - make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = b_thread_bufs - [scale_comp_buf][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(scale_comp_buf)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[scale_comp_buf][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { + static_ford>{}( + [&](auto mkc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, - xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), - 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - }); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); }; @@ -611,257 +600,246 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { + static_ford>{}([&](auto mkn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); - // constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { + static_ford>{}( + [&](auto mkc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - }); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_ford>{}([&](auto mkn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_ford>{}([&](auto mkn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp index 0de34baa42..b18ea372a1 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -261,54 +261,49 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_less == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // Stage 2, Sync // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); } else @@ -536,25 +531,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple( - I0, I0, Number{}, I0, Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple( - I0, I0, Number{}, k, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(I0, I0, Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); }); // Global prefetch 2 @@ -628,83 +622,76 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset( - make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset( - make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = b_thread_bufs - [scale_comp_buf][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[scale_comp_buf] + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == SwitchM) @@ -802,73 +789,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == SwitchM) { @@ -906,73 +893,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) { @@ -1010,73 +997,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index 9dccd9b4e6..b4c64b718b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -280,17 +280,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -318,51 +317,46 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); // loop prefetch copy - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); @@ -387,79 +381,71 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); // tail Local Prefetch A1 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency @@ -467,32 +453,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp index 3da80ba25e..eff2c09571 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp @@ -281,17 +281,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_bufs(I0)); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(I0)); }); // Local prefill A2 @@ -323,18 +322,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_bufs(local_read_buf)); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_buf)); }); a_blockwise_copy.RunWrite( @@ -343,36 +341,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -398,48 +391,44 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_bufs(local_read_reg)); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_reg)); }); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg), mfma_reg); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -455,46 +444,42 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_bufs(local_read_reg)); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_reg)); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -502,32 +487,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 4c20f21d22..c7b7948e41 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -258,52 +258,50 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3= 3 ? 1 : 0; // B global read - static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + static_ford>{}([&](auto ii) { + constexpr auto i = Number{}]>{}; + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); - if constexpr(((i < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == 0)) || - ((i >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == 0))) - { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); - } + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } }); // A global read + A local write - static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); - if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == 0)) || - (((i + buffer_load_b_stages) >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == 0))) - { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0); - } - if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == - buffer_load_issue_point_a)) || - (((i + buffer_load_b_stages) >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == - buffer_load_issue_point_a))) - { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); - } - }); + static_ford>{}([&](auto ii) { + constexpr auto i = Number{}]>{}; + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0); + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } }); // lds synchronization, prefetch next loop local A @@ -379,17 +377,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -416,119 +413,114 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == (MRepeat - 2)) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<0>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number<0>{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -552,88 +544,87 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == (MRepeat - 2)) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number<0>{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<0>{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); @@ -641,50 +632,50 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -694,46 +685,46 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp index 19cd141a3e..bfeeb715e3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp @@ -347,22 +347,19 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}); constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); - }); // Local prefill A1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); @@ -409,18 +406,16 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -448,114 +443,104 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = - CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number< - b_thread_desc_.CalculateOffset(make_tuple( - n0, - I0, - kscale0 * KRepeat / num_scale_k_block + k0, - ik))>{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( - [&](auto t) { - using pk_fma_type = - typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = - __builtin_elementwise_fma( - c_thread_buf_per_scale - .GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec - .template AsType()[Number<0>{}], - c_thread_buf - .GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; }); - }); static_for<0, MRepeat, 1>{}([&](auto m0) { a_scale_thread_copy.Run(a_scale_grid_desc, @@ -606,231 +591,207 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp index 16bb7d0e2f..48f815c509 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp @@ -538,18 +538,16 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); #if 1 @@ -717,28 +715,28 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(Number{}), - a_thread_desc_, - make_tuple(Number<(m0 + LocalPrefetchStages + - HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(Number{}), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages + + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); }); @@ -841,26 +839,25 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3= (MRepeat - LocalPrefetchStages) ? I1 : I0; - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + LocalPrefetchStages) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(Number{}), - a_thread_desc_, - make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + LocalPrefetchStages) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(Number{}), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); }); @@ -943,28 +940,27 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + LocalPrefetchStages + HotloopLocalBufSwitch) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + LocalPrefetchStages + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -1042,22 +1038,22 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp index 735891c556..7de7376e2a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp @@ -375,25 +375,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - c_scale_thread_buf_up(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf_up[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; }); - }); // Local prefill A1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); @@ -450,18 +447,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< // Local prefetch A1 block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -496,149 +491,134 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - vector_type c_scale_thread_vec_up; - constexpr index_t cscale_offset = - CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<0>{}) = - c_scale_thread_buf_up[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<1>{}) = - c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number< - b_thread_desc_.CalculateOffset(make_tuple( - n0, - I0, - kscale0 * KRepeat / num_scale_k_block + k0, - ik))>{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[mfma_reg_buf][Number< - b_thread_desc_.CalculateOffset(make_tuple( - n0, - I0, - kscale0 * KRepeat / num_scale_k_block + k0, - ik))>{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_per_scale_up.GetVectorTypeReference( - Number<0>{})); - }); - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( - [&](auto t) { - using pk_fma_type = - typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = - __builtin_elementwise_fma( - c_thread_buf_per_scale - .GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec - .template AsType()[Number<0>{}], - c_thread_buf - .GetVectorTypeReference(Number{}) - .template AsType()[t]); - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()(t) = - __builtin_elementwise_fma( - c_thread_buf_per_scale_up - .GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec_up - .template AsType()[Number<0>{}], - c_thread_buf_up - .GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up + [mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - c_scale_thread_buf_up(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf_up[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; }); - }); a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, @@ -699,310 +679,277 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< block_sync_lds(); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - vector_type c_scale_thread_vec_up; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<0>{}) = - c_scale_thread_buf_up[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<1>{}) = - c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec_up.template AsType()[Number<0>{}], - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - c_scale_thread_buf_up(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf_up[Number{}]; - }); - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - vector_type c_scale_thread_vec_up; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<0>{}) = - c_scale_thread_buf_up[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<1>{}) = - c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec_up.template AsType()[Number<0>{}], - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - vector_type c_scale_thread_vec_up; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<0>{}) = - c_scale_thread_buf_up[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<1>{}) = - c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec_up.template AsType()[Number<0>{}], - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp index e5cff43bf9..cbdf56f2f5 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp @@ -569,17 +569,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< // Local prefetch A1 block_sync_lds(); - static_for<0, 2, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); __builtin_amdgcn_sched_barrier(0); @@ -729,80 +728,80 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -916,62 +915,62 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); @@ -1058,22 +1057,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -1156,18 +1155,18 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp index a76be40753..cc4dc0ed5c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp @@ -343,22 +343,19 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); - }); __builtin_amdgcn_sched_barrier(0); // Local prefill A1 @@ -402,18 +399,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< // Local prefetch A1 block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -441,115 +436,105 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = - CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number< - b_thread_desc_.CalculateOffset(make_tuple( - n0, - I0, - kscale0 * KRepeat / num_scale_k_block + k0, - ik))>{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( - [&](auto t) { - using pk_fma_type = - typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = - __builtin_elementwise_fma( - c_thread_buf_per_scale - .GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec - .template AsType()[Number<0>{}], - c_thread_buf - .GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; }); - }); __builtin_amdgcn_sched_barrier(0); a_scale_thread_copy.Run(a_scale_grid_desc, @@ -597,233 +582,209 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< block_sync_lds(); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); block_sync_lds(); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp index f2d85150d6..36fd54fce6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp @@ -526,17 +526,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< // Local prefetch A1 block_sync_lds(); - static_for<0, 2, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); #if 0 @@ -672,80 +671,80 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -830,62 +829,62 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); @@ -947,22 +946,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -1023,18 +1022,18 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp index 1f8d3b28b5..43a405bcaa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp @@ -496,71 +496,74 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I0), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs_up(I0), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf_up); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs_up(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); }); // Global prefetch 2 @@ -664,116 +667,97 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; - vector_type - b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up( - scale_comp_buf)[Number{}]; - }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - b_thread_vec_up.template AsType()( - ik) = b_thread_buf_up - [Number{}]; - }); + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up - .template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference( - Number{})); - }); - }); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); @@ -927,105 +911,97 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); __builtin_amdgcn_s_waitcnt(3952); block_sync_lds(); @@ -1033,277 +1009,261 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I1), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I1), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs_up(I1), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf_up); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs_up(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I1)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp index e25136afb1..c426a0e1f4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp @@ -529,71 +529,74 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf_up, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf_up); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); }); // Initialize C @@ -689,116 +692,97 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; - vector_type - b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up( - scale_comp_buf)[Number{}]; - }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - b_thread_vec_up.template AsType()( - ik) = b_thread_buf_up - [Number{}]; - }); + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up - .template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference( - Number{})); - }); - }); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); @@ -957,382 +941,358 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf_up, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf_up); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I1)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp index 0be289287d..d598d281de 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp @@ -324,133 +324,128 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_buf[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_buf[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); // Prefetch a_scales static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { @@ -510,132 +505,126 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_buf[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_buf[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp index cf6fa231e1..2d295c881b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp @@ -492,49 +492,51 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); // Initialize C @@ -603,91 +605,78 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - }); + vector_type a_thread_vec; + vector_type b_thread_vec; - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); @@ -805,299 +794,281 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp index 6f086eed05..1d3f3c8bff 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp @@ -459,49 +459,51 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I0), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); // Global prefetch 2 @@ -577,91 +579,85 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford< + Sequence>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_ford>{}([&](auto + kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); // k indexes mapping to threads for 32x32x64: // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. @@ -774,83 +770,81 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); __builtin_amdgcn_s_waitcnt(3952); block_sync_lds(); @@ -858,210 +852,206 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I1), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I1), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index ae4504d6ba..723ef9cd1e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -220,69 +220,9 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - }); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - i += 1; - } while(i < (num_loop - 1)); - } - - // tail - if constexpr(TailNum == TailNumber::Full) - { - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, make_tuple(m0, I0, I0, Number{}), a_block_buf, @@ -298,34 +238,85 @@ struct BlockwiseGemmXdlops_pipeline_v1>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } @@ -553,51 +544,51 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from - // blockwise_gemm is moved here B) reduce VMEM FIFO congestion by - // applying small delays to different wavefronts It is performed - // near the end of MAC cluster to minimize lgkmcnt penalty - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion by + // applying small delays to different wavefronts It is performed + // near the end of MAC cluster to minimize lgkmcnt penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -642,46 +633,43 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -942,73 +930,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - }); - - block_sync_lds(); - a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); - b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds_direct_load(); - - i += 1; - } while(i < (num_loop - 1)); - } - - // tail - if constexpr(TailNum == TailNumber::Full) - { - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, make_tuple(m0, I0, I0, Number{}), a_block_buf, @@ -1024,34 +948,89 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + block_sync_lds_direct_load(); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp index 6da21fcec7..cb5ef0e700 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -406,22 +406,19 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}); constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); - }); // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); @@ -512,74 +509,64 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - constexpr index_t cscale_offset = - CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert( - c_scale_thread_buf[Number{}]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[Number{}]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; }); - }); block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { @@ -642,72 +629,59 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert( - c_scale_thread_buf[Number{}]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[Number{}]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); block_sync_lds(); @@ -733,108 +707,90 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert( - c_scale_thread_buf[Number{}]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[Number{}]); }); }); __builtin_amdgcn_sched_barrier(0); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert( - c_scale_thread_buf[Number{}]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[Number{}]); }); }); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp index 6b95cd6ec7..79a72f4b5e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp @@ -277,38 +277,37 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(b_scale_thread_buf[n0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); }); }); @@ -358,37 +357,34 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(b_scale_thread_buf[n0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp index 189d0ad2c3..4c2d11c8e6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp @@ -331,56 +331,60 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx 47 96 --> 111| 160 --> 175 224 --> 239| etc. // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. // k = 0 k = 1 - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + // LDS reads for A + static_ford>{}( + [&](auto km_chunk) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); + // LDS reads for B + static_ford>{}( + [&](auto kn_chunk) { + constexpr auto k = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); }); - }); // load for next k loop block_sync_lds(); @@ -389,82 +393,78 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThreadA && - 0 < ScalesPerXdlopsRunPerThreadB, - "Must have at least one scale per Xdlops per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThreadA && + 0 < ScalesPerXdlopsRunPerThreadB, + "Must have at least one scale per Xdlops per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_buf[Number{}]; - }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_buf[Number{}]; + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); // Prefetch a_scales static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { @@ -519,131 +519,130 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + // LDS reads for A + static_ford< + Sequence>{}( + [&](auto km_chunk) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); + // LDS reads for B + static_ford< + Sequence>{}( + [&](auto kn_chunk) { + constexpr auto k = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); }); - }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThreadA && - 0 < ScalesPerXdlopsRunPerThreadB, - "Must have at least one scale per Xdlops per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThreadA && + 0 < ScalesPerXdlopsRunPerThreadB, + "Must have at least one scale per Xdlops per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_buf[Number{}]; - }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_buf[Number{}]; + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 18bc1e130b..e823a1f573 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -283,34 +283,31 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -354,34 +351,29 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -409,32 +401,28 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; @@ -460,32 +448,28 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } else if constexpr(TailNum == TailNumber::Two) @@ -788,52 +772,52 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from - // blockwise_gemm is moved here B) reduce VMEM FIFO congestion - // by applying small delays to different wavefronts It is - // performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -887,46 +871,46 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -963,46 +947,43 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -1039,46 +1020,43 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index 712d26c897..96f683e60e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -349,39 +349,39 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[m0]) * - type_convert(b_scale_thread_buf[I0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); }); }); @@ -436,58 +436,57 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto iprefetch) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[m0]) * - type_convert(b_scale_thread_buf[I0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); }); }); @@ -526,57 +525,54 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[m0]) * - type_convert(b_scale_thread_buf[I0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); }); }); }; @@ -584,57 +580,54 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[m0]) * - type_convert(b_scale_thread_buf[I0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp index 43ff439e0d..cb56450721 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -264,54 +264,50 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto iprefetch) { // ------------------------------------------------------------------------------------------- block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -336,53 +332,48 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto iprefetch) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -391,102 +382,94 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; if constexpr(TailNum == TailNumber::One) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } else if constexpr(TailNum == TailNumber::Two) @@ -823,61 +806,52 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from - // blockwise_gemm is moved here B) reduce VMEM FIFO congestion - // by applying small delays to different wavefronts It is - // performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) - // { - // constexpr index_t c_offset = - // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - // c_thread_buf(Number{}) += - // c_thread_buf_per_scale[Number{}] * - // type_convert(b_scale_thread_buf[n0]); - // }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -944,54 +918,46 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - // constexpr index_t c_offset = - // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - // c_thread_buf(Number{}) += - // c_thread_buf_per_scale[Number{}] * - // type_convert(b_scale_thread_buf[n0]); - // }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -1041,54 +1007,43 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - // constexpr index_t c_offset = - // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - // c_thread_buf(Number{}) += - // c_thread_buf_per_scale[Number{}] * - // type_convert(b_scale_thread_buf[n0]); - // }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -1125,54 +1080,43 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - // constexpr index_t c_offset = - // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - // c_thread_buf(Number{}) += - // c_thread_buf_per_scale[Number{}] * - // type_convert(b_scale_thread_buf[n0]); - // }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index b4f90065dd..82d388ef9a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -363,34 +363,29 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -423,32 +418,28 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index b78a9d3199..c14a06597c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -471,42 +471,41 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert(c_scale_thread_buf[m0]); - }); + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); }); }); @@ -573,41 +572,39 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert(c_scale_thread_buf[m0]); - }); + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); }); }); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp index 63177c8d45..f7e88d75dd 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp @@ -428,34 +428,29 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale{})); } - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -490,32 +485,28 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); __builtin_amdgcn_sched_barrier(0); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp index 8d98a36dcd..a179c6c3bd 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp @@ -459,49 +459,51 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I0), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); // Global prefetch 2 @@ -577,91 +579,85 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford< + Sequence>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_ford>{}([&](auto + kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); // k indexes mapping to threads for 32x32x64: // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. @@ -774,83 +770,81 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); __builtin_amdgcn_s_waitcnt(3952); block_sync_lds(); @@ -858,210 +852,206 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I1), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I1), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp index ff09fd011f..67a9769aca 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp @@ -261,54 +261,49 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_less == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // Stage 2, Sync // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); } else @@ -537,25 +532,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple( - I0, I0, Number{}, I0, Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple( - I0, I0, Number{}, k, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(I0, I0, Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); }); // Global prefetch 2 diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp index 7b94ed5086..d659f2c2e0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -368,79 +368,10 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - HotLoopScheduler(); - }; - - LoopFunc(I1, I1, I0, I0); - LoopFunc(I0, I0, I1, I1); - - i += HotloopUnroll; - } while(i < (num_loop - PrefetchStages)); - } - - auto ReadWriteCompFunc = [&](auto lds_read_buf, - auto lds_read_reg_buf, - auto lds_write_buf, - auto mfma_reg_buf) { - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(lds_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(lds_read_reg_buf)); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); - }); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; vector_type a_thread_vec; vector_type b_thread_vec; @@ -463,11 +394,72 @@ struct BlockwiseGemmXdlops_pipeline_v4(), c_thread_buf.GetVectorTypeReference(Number{})); }); - }); - }); - HotLoopScheduler(); - }; + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = + [&](auto lds_read_buf, auto lds_read_reg_buf, auto lds_write_buf, auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + HotLoopScheduler(); + }; auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { block_sync_lds(); @@ -491,64 +483,60 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); }; auto CompFunc = [&](auto mfma_reg_buf) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; // tail @@ -918,81 +906,10 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v4{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - HotLoopScheduler(); - }; - - LoopFunc(I1, I1, I0, I0); - LoopFunc(I0, I0, I1, I1); - - i += HotloopUnroll; - } while(i < (num_loop - PrefetchStages)); - } - - auto ReadWriteCompFunc = [&](auto lds_read_buf, - auto lds_read_reg_buf, - auto lds_write_buf, - auto mfma_reg_buf) { - block_sync_lds_direct_load(); - - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(lds_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(lds_read_reg_buf)); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); - }); - - a_blockwise_copy.Run( - a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf)); - b_blockwise_copy.Run( - b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf)); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; vector_type a_thread_vec; vector_type b_thread_vec; @@ -1015,11 +932,74 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v4(), c_thread_buf.GetVectorTypeReference(Number{})); }); - }); - }); - HotLoopScheduler(); - }; + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = + [&](auto lds_read_buf, auto lds_read_reg_buf, auto lds_write_buf, auto mfma_reg_buf) { + block_sync_lds_direct_load(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf)); + + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + HotLoopScheduler(); + }; auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { block_sync_lds_direct_load(); @@ -1043,64 +1023,60 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v4{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); }; auto CompFunc = [&](auto mfma_reg_buf) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; // tail diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp index 3bbf4bb690..b48102b023 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp @@ -356,23 +356,23 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(I0)); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(I0), - b_scale_thread_bufs(I0)[n0], - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(I0)); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_scale_thread_bufs(I0)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); }); }); @@ -477,80 +477,10 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - HotLoopScheduler(); - }; - - LoopFunc(I1, I1, I0, I0); - LoopFunc(I0, I0, I1, I1); - - i += HotloopUnroll; - } while(i < (num_loop - PrefetchStages)); - } - - auto ReadWriteCompFunc = [&](auto lds_read_buf, - auto lds_read_reg_buf, - auto lds_write_buf, - auto mfma_reg_buf) { - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(lds_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(lds_read_reg_buf)); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_scale_thread_bufs(lds_read_buf)[n0], - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); - }); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; vector_type a_thread_vec; vector_type b_thread_vec; @@ -573,11 +503,73 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale(), c_thread_buf.GetVectorTypeReference(Number{})); }); - }); - }); - HotLoopScheduler(); - }; + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = + [&](auto lds_read_buf, auto lds_read_reg_buf, auto lds_write_buf, auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + HotLoopScheduler(); + }; auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { block_sync_lds(); @@ -602,64 +594,60 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); }; auto CompFunc = [&](auto mfma_reg_buf) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp index 2dfeb1d0cb..059544b239 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -590,27 +590,26 @@ struct BlockwiseGemmXdlops_pipeline_v5{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = a_thread_buf - [Number{}]; - }); - static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = b_thread_buf - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index e3ce5e29c8..8bbf809521 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -365,56 +365,50 @@ struct BlockwiseGemmWMMA } else { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of - // k=0,kpack*1, .. - // read B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); - // read A - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0, I0), - a_thread_buf); + static_ford>{}([&](auto nmk) { + constexpr auto n0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; // k=0,1,2 instead of k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - }); - - static_for<0, KPack / B_KRow, 1>{}([&](auto i) { - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } @@ -862,60 +856,47 @@ struct BlockwiseGemmWMMA } else { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of - // k=0,kpack*1, .. - // read B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); - // read A - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0, I0), - a_thread_buf); + static_ford>{}([&](auto nmk) { + constexpr auto n0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; // k=0,1,2 instead of k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto i) { - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index f81e93d82a..53646a4eba 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -514,54 +514,54 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __builtin_amdgcn_sched_barrier(0); } static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from blockwise_gemm is - // moved here B) reduce VMEM FIFO congestion by applying small delays to - // different wavefronts It is performed near the end of MAC cluster to - // minimize lgkmcnt penalty - if constexpr(k.value == KPerThread - KPerInnerLoop && - k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && - n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - - // TODO: insert setprio in more precise manner since we - // could have more than >1 MFMA instructions in single call - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from blockwise_gemm is + // moved here B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts It is performed near the end of MAC cluster to + // minimize lgkmcnt penalty + if constexpr(k.value == KPerThread - KPerInnerLoop && + k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && + n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + + // TODO: insert setprio in more precise manner since we + // could have more than >1 MFMA instructions in single call + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -953,44 +953,43 @@ struct BlockwiseGemmXdlops_v2 auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, make_tuple(I0, I0, I0, I0), - a_thread_buf); + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp index e3d5afe6d8..ba061b20fa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp @@ -98,13 +98,13 @@ struct BlockwiseSoftmax }); // calculate exp for elements, P=exp(s-max) - static_for<0, MRepeat, 1>{}([&](auto iM) { - static_for<0, KRepeat, 1>{}([&](auto iK) { - auto offset = Number{}; - in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset]) - ? 0 - : math::exp(in_thread_buf[offset] - max_value_buf(iM)); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + auto offset = Number{}; + in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset]) + ? 0 + : math::exp(in_thread_buf[offset] - max_value_buf(iM)); }); // sum data diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp index 6d1b454282..d234b9f846 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp @@ -480,19 +480,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf make_tuple(I0, I0), dy_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - dy_elementwise_op(dy_thread_buf(Number{}), - dy_thread_buf[Number{}]); + dy_elementwise_op(dy_thread_buf(Number{}), dy_thread_buf[Number{}]); - AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * - inv_var_thread_buf[iM]; + AccDataType norm_x = + (x_thread_buf[Number{}] - mean_thread_buf[iM]) * inv_var_thread_buf[iM]; - tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; - }); + tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; }); ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp index d1e7f35607..d516fd248c 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp @@ -502,15 +502,15 @@ struct EpilogueReduceCShuffle [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); // reduce in VGPR - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { - constexpr auto offset = - Number{}; + static_ford>{}([&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; + constexpr auto offset = + Number{}; - reduce_in_element_op(c_reduce_thread_buf(offset), - c_reduce_thread_buf(offset)); - }); + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); }); ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp index 87f3d50e10..f735124539 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp @@ -362,11 +362,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d make_tuple(I0), gamma_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto m) { - static_for<0, NThreadSliceSize, 1>{}([&](auto n) { - constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); - h_thread_buf(Number{}) = h_thread_buf(Number{}) * gamma_thread_buf(n); - }); + static_ford>{}([&](auto mn) { + constexpr auto m = Number{}]>{}; + constexpr auto n = Number{}]>{}; + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = h_thread_buf(Number{}) * gamma_thread_buf(n); }); threadwise_beta_load_n.Run(beta_grid_desc_n, @@ -375,11 +375,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d make_tuple(I0), beta_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto m) { - static_for<0, NThreadSliceSize, 1>{}([&](auto n) { - constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); - h_thread_buf(Number{}) = h_thread_buf(Number{}) + beta_thread_buf(n); - }); + static_ford>{}([&](auto mn) { + constexpr auto m = Number{}]>{}; + constexpr auto n = Number{}]>{}; + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = h_thread_buf(Number{}) + beta_thread_buf(n); }); threadwise_h_store_m_n.Run(thread_buffer_desc_m_n, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp index 949edb35f6..35f74fed95 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp @@ -218,14 +218,12 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock in_thread_buf); static_for<0, NumReduction, 1>{}([&](auto iR) { - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), + in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR)); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp index 5ad0ef3117..191a6942c4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp @@ -173,14 +173,12 @@ struct GridwiseMultipleReduction_mk_to_m_threadwise in_thread_buf); static_for<0, NumReduction, 1>{}([&](auto iR) { - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), + in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR)); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp index f72aca8605..9ad1038251 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -212,13 +212,11 @@ struct GridwiseReduction_mk_to_m_multiblock make_tuple(I0, I0), in_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_buf(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp index fb20531133..a688878026 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -162,13 +162,11 @@ struct GridwiseReduction_mk_to_m_threadwise make_tuple(I0, I0), in_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_buf(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); @@ -340,15 +338,13 @@ struct GridwiseReduction_mk_to_m_threadwise make_tuple(I0, I0), in_thread_idx_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_val_buf(Number{}), - in_thread_val_buf(Number{})); - }); + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); }); ThreadwiseReduceWithIndex::Reduce( @@ -371,17 +367,15 @@ struct GridwiseReduction_mk_to_m_threadwise make_tuple(I0, I0), in_thread_val_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_thread_idx_buf(Number{}) = indexStart + iK(); + in_thread_idx_buf(Number{}) = indexStart + iK(); - in_elementwise_op(in_thread_val_buf(Number{}), - in_thread_val_buf(Number{})); - }); + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); }); ThreadwiseReduceWithIndex::Reduce( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp index 637ba27cef..e679592766 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp @@ -160,13 +160,11 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d make_tuple(I0, I0), in_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_buf(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index e30ddf5c1c..10d83c4b32 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -1297,15 +1297,15 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); - static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { - static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { - auto I = Number{}; - Acc1DataType acc1 = acc1_thread_buf[I]; // P*V - Acc1DataType c = c_thread_buf[I]; // O - Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax. + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iN = Number{}]>{}; + auto I = Number{}; + Acc1DataType acc1 = acc1_thread_buf[I]; // P*V + Acc1DataType c = c_thread_buf[I]; // O + Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax. - c_thread_buf(I) = c_new; // O_new - }); + c_thread_buf(I) = c_new; // O_new }); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 5722bbc146..98eba4fd2e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1234,19 +1234,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); - static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { - static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { - auto I = Number{}; - FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V - FloatGemmAcc c = c_thread_buf[I]; // O - FloatGemmAcc c_new = - (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + - math::exp(max[iM] - running_max_new[iM]) * acc1) / - running_sum_new[iM]; // Formula by Dao et al., - // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iN = Number{}]>{}; + auto I = Number{}; + FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V + FloatGemmAcc c = c_thread_buf[I]; // O + FloatGemmAcc c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; // Formula by Dao et al., + // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 - c_thread_buf(I) = c_new; // O_new - }); + c_thread_buf(I) = c_new; // O_new }); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 231acc7e4f..fd48c42b26 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -1410,18 +1410,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); - static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { - static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { - auto I = Number{}; - Acc1DataType acc1 = acc1_thread_buf[I]; // P*V - Acc1DataType c = c_thread_buf[I]; // O - Acc1DataType c_new = - (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + - math::exp(max[iM] - running_max_new[iM]) * acc1) / - running_sum_new[iM]; + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iN = Number{}]>{}; + auto I = Number{}; + Acc1DataType acc1 = acc1_thread_buf[I]; // P*V + Acc1DataType c = c_thread_buf[I]; // O + Acc1DataType c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; - c_thread_buf(I) = c_new; // O_new - }); + c_thread_buf(I) = c_new; // O_new }); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index 5cb9eac548..ffa3e464b9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1048,19 +1048,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); - static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { - static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { - auto I = Number{}; - FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V - FloatGemmAcc c = c_thread_buf[I]; // O - FloatGemmAcc c_new = - (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + - math::exp(max[iM] - running_max_new[iM]) * acc1) / - running_sum_new[iM]; // Formula by Dao et al., - // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iN = Number{}]>{}; + auto I = Number{}; + FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V + FloatGemmAcc c = c_thread_buf[I]; // O + FloatGemmAcc c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; // Formula by Dao et al., + // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 - c_thread_buf(I) = c_new; // O_new - }); + c_thread_buf(I) = c_new; // O_new }); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp index 1a9bbcb603..196cfceb4b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp @@ -437,19 +437,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford make_tuple(I0, I0), dy_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - dy_elementwise_op(dy_thread_buf(Number{}), - dy_thread_buf[Number{}]); + dy_elementwise_op(dy_thread_buf(Number{}), dy_thread_buf[Number{}]); - AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * - inv_var_thread_buf[iM]; + AccDataType norm_x = + (x_thread_buf[Number{}] - mean_thread_buf[iM]) * inv_var_thread_buf[iM]; - tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; - }); + tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; }); ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index f2cd6fca5c..2d8f9dd3ab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -355,8 +355,10 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + static_ford>{}( + [&](auto mk) { // input add loop + constexpr auto iM = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; constexpr auto offset_m_k = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); @@ -376,7 +378,6 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk unpack2(x_elementwise_op, out_data_refs, in_data_refs); }); - }); threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf); if constexpr(!SweepOnce) @@ -435,21 +436,20 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}([&](auto ii) { + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * divisor; - // gamma - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); }); @@ -463,19 +463,19 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto mii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index ccf4b04a6c..26692594da 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -937,8 +937,10 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); // reduce in VGPR - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + static_ford>{}( + [&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; constexpr auto offset = Number{}; @@ -946,7 +948,6 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 reduce_in_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset)); }); - }); ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 3d9095bffb..c2af166e85 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -881,14 +881,14 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ThreadReduceOperation::template GetIdentityValue(); static_for<0, mreduce_per_thread, 1>{}( [&](auto I) { r_thread_buf(I) = reduce_identityVal; }); - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { - constexpr auto offset = - Number{}; + static_ford>{}([&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; + constexpr auto offset = + Number{}; - qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset)); - }); + qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset)); }); ThreadwiseReduce::Reduce(e_thread_buf, r_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index f1e97455b7..530194ee22 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -820,8 +820,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); // reduce in VGPR - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + static_ford>{}( + [&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; constexpr auto offset = Number{}; @@ -829,7 +831,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 reduce_in_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset)); }); - }); ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 9569cab98b..ec5c449a78 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -997,26 +997,26 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0] .GetUpperLengths()[I1]; // TODO: proper handle - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { - constexpr auto dst_offset = - Number{}; + static_ford>{}([&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; + constexpr auto dst_offset = + Number{}; - constexpr auto src_offset = - Number{}; + constexpr auto src_offset = + Number{}; - FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw; - FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw; + FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw; + FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw; - FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; - FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; - FloatReduceAcc divisor_sqrt; - tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor); + FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; + FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; + FloatReduceAcc divisor_sqrt; + tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor); - c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; - }); + c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; }); // scaling diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp index a6fa04a824..a3835ed7fd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -290,12 +290,12 @@ struct GridwiseSoftmax_mk_to_mk } // do element-wise pre-reduction operation - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - out_thread_buf(Number{}) = - math::exp(in_thread_buf(Number{}) - max_value_buf(iM)); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + math::exp(in_thread_buf(Number{}) - max_value_buf(iM)); }); ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf); @@ -330,15 +330,13 @@ struct GridwiseSoftmax_mk_to_mk in_thread_buf); } - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - out_thread_buf(Number{}) = - alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / - accu_value_buf(iM); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM); }); threadwise_dst_store.Run(thread_buffer_desc, @@ -376,16 +374,14 @@ struct GridwiseSoftmax_mk_to_mk make_tuple(I0, I0), in_prior_dst_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - out_thread_buf(Number{}) = - alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / - accu_value_buf(iM) + - beta * in_prior_dst_buf(Number{}); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM) + + beta * in_prior_dst_buf(Number{}); }); threadwise_dst_store.Run(thread_buffer_desc, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp index 5036f4ae7f..88719c831e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp @@ -175,36 +175,36 @@ struct GridwiseSparseEmbeddingsForwardLayernorm }; auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { - static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { - static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { - constexpr auto register_offset = thread_buf_desc.CalculateOffset( - make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - auto in_data_refs = generate_tie( - [&](auto i_embedding_) -> const auto& { - return in_thread_bufs(i_embedding_)(Number{}); - }, - Number{}); - auto out_data_refs = generate_tie( - [&](auto) -> auto& { return acc_thread_buf(Number{}); }, - Number<1>{}); - unpack2(emb_elementwise_op, out_data_refs, in_data_refs); - }); + static_ford>{}([&](auto ii) { + constexpr auto i_dim_vec_ = Number{}]>{}; + constexpr auto i_row_vec_ = Number{}]>{}; + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + auto in_data_refs = generate_tie( + [&](auto i_embedding_) -> const auto& { + return in_thread_bufs(i_embedding_)(Number{}); + }, + Number{}); + auto out_data_refs = generate_tie( + [&](auto) -> auto& { return acc_thread_buf(Number{}); }, + Number<1>{}); + unpack2(emb_elementwise_op, out_data_refs, in_data_refs); }); }; auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { - static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { - static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { - constexpr auto register_offset = thread_buf_desc.CalculateOffset( - make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - constexpr auto mean_var_offset = - mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + static_ford>{}([&](auto ii) { + constexpr auto i_dim_vec_ = Number{}]>{}; + constexpr auto i_row_vec_ = Number{}]>{}; + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); - threadwise_welford.cur_count_++; - threadwise_welford.Update(mean_thread_buf(Number{}), - var_thread_buf(Number{}), - acc_thread_buf(Number{})); - }); + threadwise_welford.cur_count_++; + threadwise_welford.Update(mean_thread_buf(Number{}), + var_thread_buf(Number{}), + acc_thread_buf(Number{})); }); }; @@ -246,12 +246,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm }; // first load index - ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) { + ck::static_ford>{}([&](auto ie) { + constexpr auto i_idx_ = Number{}]>{}; + constexpr auto i_embedding_ = Number{}]>{}; // prefer use s_load - ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { - index_bufs(i_embedding_)(i_idx_) = - p_indexes[i_embedding_][index_start + i_idx_.value]; - }); + index_bufs(i_embedding_)(i_idx_) = p_indexes[i_embedding_][index_start + i_idx_.value]; }); // load gamma/beta diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp index a97114d48d..6aa046a43a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp @@ -176,36 +176,36 @@ struct GridwiseSparseEmbeddingsForwardLayernorm }; auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { - static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { - static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { - constexpr auto register_offset = thread_buf_desc.CalculateOffset( - make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - auto in_data_refs = generate_tie( - [&](auto i_embedding_) -> const auto& { - return in_thread_bufs(i_embedding_)(Number{}); - }, - Number{}); - auto out_data_refs = generate_tie( - [&](auto) -> auto& { return acc_thread_buf(Number{}); }, - Number<1>{}); - unpack2(emb_elementwise_op, out_data_refs, in_data_refs); - }); + static_ford>{}([&](auto ii) { + constexpr auto i_dim_vec_ = Number{}]>{}; + constexpr auto i_row_vec_ = Number{}]>{}; + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + auto in_data_refs = generate_tie( + [&](auto i_embedding_) -> const auto& { + return in_thread_bufs(i_embedding_)(Number{}); + }, + Number{}); + auto out_data_refs = generate_tie( + [&](auto) -> auto& { return acc_thread_buf(Number{}); }, + Number<1>{}); + unpack2(emb_elementwise_op, out_data_refs, in_data_refs); }); }; auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { - static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { - static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { - constexpr auto register_offset = thread_buf_desc.CalculateOffset( - make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - constexpr auto mean_var_offset = - mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + static_ford>{}([&](auto ii) { + constexpr auto i_dim_vec_ = Number{}]>{}; + constexpr auto i_row_vec_ = Number{}]>{}; + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); - threadwise_welford.cur_count_++; - threadwise_welford.Update(mean_thread_buf(Number{}), - var_thread_buf(Number{}), - acc_thread_buf(Number{})); - }); + threadwise_welford.cur_count_++; + threadwise_welford.Update(mean_thread_buf(Number{}), + var_thread_buf(Number{}), + acc_thread_buf(Number{})); }); }; @@ -247,12 +247,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm }; // first load index - ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) { + ck::static_ford>{}([&](auto ie) { + constexpr auto i_idx_ = Number{}]>{}; + constexpr auto i_embedding_ = Number{}]>{}; // prefer use s_load - ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { - index_bufs(i_embedding_)(i_idx_) = - p_indexes[i_embedding_][index_start + i_idx_.value]; - }); + index_bufs(i_embedding_)(i_idx_) = p_indexes[i_embedding_][index_start + i_idx_.value]; }); // load gamma/beta diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp index 6c42dc33f4..f23744bf67 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp @@ -328,14 +328,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk make_tuple(I0, I0), gamma_thread_buf(i)); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - x_square_thread_buf(i)(Number{}) = - x_thread_buf(i)(Number{}) * - x_thread_buf(i)(Number{}); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); }); ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); @@ -391,24 +391,24 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk } // normalization - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma & beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma & beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -422,19 +422,19 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, @@ -460,14 +460,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk x_thread_buf(i)); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - x_square_thread_buf(i)(Number{}) = - x_thread_buf(i)(Number{}) * - x_thread_buf(i)(Number{}); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); }); ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); @@ -544,24 +544,24 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -573,19 +573,19 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp index 9deb5a5f48..0a33d08a5e 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp @@ -441,24 +441,24 @@ struct GridwiseNormalizationSplitK2nd thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -470,19 +470,19 @@ struct GridwiseNormalizationSplitK2nd thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp index d57871f331..7ca6c9ce7a 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -365,24 +365,24 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk } // normalization - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma & beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma & beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -396,19 +396,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, @@ -496,24 +496,24 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -525,19 +525,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp index f4a2bc399d..44d14e4c04 100644 --- a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -35,14 +35,13 @@ struct ThreadwiseReduction template __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf) { - static_for<0, src_length_m, 1>{}([&](auto iM) { + static_ford>{}([&](auto mk) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - static_for<0, src_length_k, 1>{}([&](auto iK) { - constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - - Accumulation::Calculate(dst_buf(Number{}), src_buf[Number{}]); - }); + Accumulation::Calculate(dst_buf(Number{}), src_buf[Number{}]); }); }; }; @@ -81,17 +80,16 @@ struct ThreadwiseReductionWithIndex DstValueBufferType& dst_val_buf, DstIndexBufferType& dst_idx_buf) { - static_for<0, src_length_m, 1>{}([&](auto iM) { + static_ford>{}([&](auto mk) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - static_for<0, src_length_k, 1>{}([&](auto iK) { - constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - - Accumulation::Calculate(dst_val_buf(Number{}), - src_val_buf[Number{}], - dst_idx_buf(Number{}), - src_idx_buf[Number{}]); - }); + Accumulation::Calculate(dst_val_buf(Number{}), + src_val_buf[Number{}], + dst_idx_buf(Number{}), + src_idx_buf[Number{}]); }); }; }; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp index 06aca9c922..3a4d8f5d80 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -81,28 +81,21 @@ struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1 constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - static_for<0, TK, 1>{}([&](auto tk) { - static_for<0, TM0, 1>{}([&](auto tm0) { - static_for<0, TM1, 1>{}([&](auto tm1) { - static_for<0, TN0, 1>{}([&](auto tn0) { - static_for<0, TN1, 1>{}([&](auto tn1) { - constexpr index_t a_offset = - AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( - a_origin_idx + make_multi_index(tk, tm0, tm1)); - constexpr index_t b_offset = - BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( - b_origin_idx + make_multi_index(tk, tn0, tn1)); - constexpr index_t c_offset = - CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( - c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + static_ford>{}([&](auto tkmn) { + constexpr auto tk = Number{}]>{}; + constexpr auto tm0 = Number{}]>{}; + constexpr auto tm1 = Number{}]>{}; + constexpr auto tn0 = Number{}]>{}; + constexpr auto tn1 = Number{}]>{}; + constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk, tm0, tm1)); + constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk, tn0, tn1)); + constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); - inner_product(a_buf[Number{}], - b_buf[Number{}], - c_buf(Number{})); - }); - }); - }); - }); + inner_product( + a_buf[Number{}], b_buf[Number{}], c_buf(Number{})); }); } }; @@ -181,42 +174,35 @@ struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0 constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - static_for<0, TK0, 1>{}([&](auto tk0) { - static_for<0, TM0, 1>{}([&](auto tm0) { - static_for<0, TM1, 1>{}([&](auto tm1) { - static_for<0, TN0, 1>{}([&](auto tn0) { - static_for<0, TN1, 1>{}([&](auto tn1) { - vector_type a_vec; - vector_type b_vec; + static_ford>{}([&](auto tkmn) { + constexpr auto tk0 = Number{}]>{}; + constexpr auto tm0 = Number{}]>{}; + constexpr auto tm1 = Number{}]>{}; + constexpr auto tn0 = Number{}]>{}; + constexpr auto tn1 = Number{}]>{}; + vector_type a_vec; + vector_type b_vec; - static_for<0, TK1, 1>{}([&](auto tk1) { - constexpr index_t a_offset = - AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( - a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1)); + static_for<0, TK1, 1>{}([&](auto tk1) { + constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1)); - constexpr index_t b_offset = - BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( - b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1)); + constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1)); - a_vec.template AsType()(tk1) = a_buf[Number{}]; - b_vec.template AsType()(tk1) = b_buf[Number{}]; - }); - - using a_vector_t = typename vector_type::type; - using b_vector_t = typename vector_type::type; - - constexpr index_t c_offset = - CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( - c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); - - inner_product( - a_vec.template AsType()[I0], - b_vec.template AsType()[I0], - c_buf(Number{})); - }); - }); - }); + a_vec.template AsType()(tk1) = a_buf[Number{}]; + b_vec.template AsType()(tk1) = b_buf[Number{}]; }); + + using a_vector_t = typename vector_type::type; + using b_vector_t = typename vector_type::type; + + constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + inner_product(a_vec.template AsType()[I0], + b_vec.template AsType()[I0], + c_buf(Number{})); }); } }; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp index 2896375636..89ded8ee7c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp @@ -81,53 +81,49 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 static_for<0, K, 1>{}([&](auto k) { static_for<0, Ho, SubHW>{}([&](auto h) { static_for<0, Wo, SubHW>{}([&](auto w) { - static_for<0, E1, 1>{}([&](auto e1) { - static_for<0, E2, 1>{}([&](auto e2) { - constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( - a_origin_idx + make_tuple(e1, k, e2)); + static_ford>{}([&](auto ee) { + constexpr auto e1 = Number{}]>{}; + constexpr auto e2 = Number{}]>{}; + constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( + a_origin_idx + make_tuple(e1, k, e2)); - constexpr index_t b0_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h, w, e2)); + constexpr index_t b0_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); - constexpr index_t b1_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h, w + 1, e2)); + constexpr index_t b1_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w + 1, e2)); - constexpr index_t b2_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h + 1, w, e2)); + constexpr index_t b2_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w, e2)); - constexpr index_t b3_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2)); + constexpr index_t b3_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2)); - constexpr index_t c0_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + - make_tuple(k, 0, h, w)); + constexpr index_t c0_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h, w)); - constexpr index_t c1_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( - c_origin_idx + make_tuple(k, 0, h, w + 1)); + constexpr index_t c1_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h, w + 1)); - constexpr index_t c2_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( - c_origin_idx + make_tuple(k, 0, h + 1, w)); + constexpr index_t c2_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w)); - constexpr index_t c3_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( - c_origin_idx + make_tuple(k, 0, h + 1, w + 1)); + constexpr index_t c3_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w + 1)); - amd_assembly_outer_product_1x4(a_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{})); - }); + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); }); }); }); @@ -136,29 +132,24 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 else { - static_for<0, K, 1>{}([&](auto k) { - static_for<0, Ho, 1>{}([&](auto h) { - static_for<0, Wo, 1>{}([&](auto w) { - static_for<0, E1, 1>{}([&](auto e1) { - static_for<0, E2, 1>{}([&](auto e2) { - constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( - a_origin_idx + make_tuple(e1, k, e2)); + static_ford>{}([&](auto khwe) { + constexpr auto k = Number{}]>{}; + constexpr auto h = Number{}]>{}; + constexpr auto w = Number{}]>{}; + constexpr auto e1 = Number{}]>{}; + constexpr auto e2 = Number{}]>{}; + constexpr index_t a_offset = + AThreadDesc_E1_K_E2{}.CalculateOffset(a_origin_idx + make_tuple(e1, k, e2)); - constexpr index_t b_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h, w, e2)); + constexpr index_t b_offset = BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); - constexpr index_t c_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + - make_tuple(k, 0, h, w)); + constexpr index_t c_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w)); - inner_product(a_buf[Number{}], - b_buf[Number{}], - c_buf(Number{})); - }); - }); - }); - }); + inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index afd1e67bd0..91bf2d5832 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1827,25 +1827,24 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic } else { - static_for<0, num_access, 1>{}([&](auto idx_1d) { + static_ford>{}([&](auto access_idx) { + constexpr auto idx_1d = Number{}]>{}; + constexpr auto i = Number{}]>{}; constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - DstData v; + DstData v; - // apply element-wise operation - element_op_(v, src_buf[Number{}]); + // apply element-wise operation + element_op_(v, src_buf[Number{}]); - // apply type convert - dst_buf(Number{}) = v; - }); + // apply type convert + dst_buf(Number{}) = v; }); } } @@ -1933,57 +1932,56 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - static_for<0, num_access, 1>{}([&](auto idx_1d) { + static_ford>{}([&](auto access_idx) { + constexpr auto idx_1d = Number{}]>{}; + constexpr auto i = Number{}]>{}; constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - // src_desc error, non constexpr, caused by merge transform - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md + + i * dst_scalar_step_in_vector); - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + constexpr index_t dst_offset = dst_desc.CalculateOffset(dst_slice_origin_idx + idx_md + + i * dst_scalar_step_in_vector); - SrcData v_this_row, v_theother_row; - // int type temp value due to intrinsic requirement - int temp = 0; + SrcData v_this_row, v_theother_row; + // int type temp value due to intrinsic requirement + int temp = 0; - // apply element-wise operation - element_op_(v_this_row, src_buf[Number{}]); + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); - // apply intra-row permute. - if constexpr(IntraRowSwizzlePerm) - { - temp = __builtin_amdgcn_permlane16( - temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); - v_this_row = type_convert_sp(temp); - } + // apply intra-row permute. + if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } - // apply inter-row permute. - temp = __builtin_amdgcn_permlanex16(temp, - type_convert_sp(v_this_row), - LowEightRowlaneIdx, - HighEightRowLaneIdx, - 1, - 0); - v_theother_row = type_convert_sp(temp); + // apply inter-row permute. + temp = __builtin_amdgcn_permlanex16(temp, + type_convert_sp(v_this_row), + LowEightRowlaneIdx, + HighEightRowLaneIdx, + 1, + 0); + v_theother_row = type_convert_sp(temp); - if(get_thread_local_1d_id() % 32 < 16) - { - // apply type convert - dst_buf(Number{}) = type_convert_sp(v_this_row); - dst_buf(Number{}) = - type_convert_sp(v_theother_row); - } - else - { - // apply type convert - dst_buf(Number{}) = - type_convert_sp(v_this_row); - dst_buf(Number{}) = type_convert_sp(v_theother_row); - } - }); + if(get_thread_local_1d_id() % 32 < 16) + { + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + dst_buf(Number{}) = + type_convert_sp(v_theother_row); + } + else + { + // apply type convert + dst_buf(Number{}) = + type_convert_sp(v_this_row); + dst_buf(Number{}) = type_convert_sp(v_theother_row); + } }); } }; @@ -2060,36 +2058,35 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - static_for<0, num_access, 1>{}([&](auto idx_1d) { + static_ford>{}([&](auto access_idx) { + constexpr auto idx_1d = Number{}]>{}; + constexpr auto i = Number{}]>{}; constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - // src_desc error, non constexpr, caused by merge transform - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md + + i * dst_scalar_step_in_vector); - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + constexpr index_t dst_offset = dst_desc.CalculateOffset(dst_slice_origin_idx + idx_md + + i * dst_scalar_step_in_vector); - SrcData v_this_row; - // int type temp value due to intrinsic requirement - int temp = 0; + SrcData v_this_row; + // int type temp value due to intrinsic requirement + int temp = 0; - // apply element-wise operation - element_op_(v_this_row, src_buf[Number{}]); + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); - // apply intra-row permute. - if constexpr(IntraRowSwizzlePerm) - { - temp = __builtin_amdgcn_permlane16( - temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); - v_this_row = type_convert_sp(temp); - } + // apply intra-row permute. + if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } - // apply type convert - dst_buf(Number{}) = type_convert_sp(v_this_row); - }); + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); }); } ElementwiseOperation element_op_{};