diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index ee8c7ed71b..0804e07fc3 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -350,16 +350,16 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); // match hidden id - static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { - static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { - // if this low dim is bottom dim, then do id matching - if constexpr(low_dim_hidden_ids_1[idim_low_1] == - TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) - { - low_dim_hidden_ids_1_mod_(idim_low_1) = - TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; - } - }); + static_ford>{}([&](auto ii) { + constexpr auto idim_low_1 = Number{}]>{}; + constexpr auto idim_bottom_1 = Number{}]>{}; + // if this low dim is bottom dim, then do id matching + if constexpr(low_dim_hidden_ids_1[idim_low_1] == + TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) + { + low_dim_hidden_ids_1_mod_(idim_low_1) = + TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; + } }); return low_dim_hidden_ids_1_mod_; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index e41cf8c82d..fb1ae8c543 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -202,22 +202,22 @@ struct BlockwiseGemmWmmaops_pipeline_base using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc); using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_bufs(I0)(Number{}) = - a_scale_struct.scale_thread_bufs(I0)[Number{}] * - b_scale_struct.scale_thread_bufs(I0)[Number{}]; - }); + c_scale_thread_bufs(I0)(Number{}) = + a_scale_struct.scale_thread_bufs(I0)[Number{}] * + b_scale_struct.scale_thread_bufs(I0)[Number{}]; }); - }); } __device__ void Clear() diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 4f884b1df3..ee36f75164 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -224,87 +224,75 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0, I0, I0), - a_thread_buf); - if constexpr(m0 == I0) + static_ford>{}([&](auto km) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0, I0, I0, I0), + a_thread_buf); + if constexpr(m0 == I0) + { + if constexpr(ck::is_same::value == true) { - if constexpr(ck::is_same::value == true) - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, n0, k0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0, I0), - b_thread_buf); - }); - } - else - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, n0, k0, I0, I0, I0, I0), - b_block_buf, - b_scale_struct.scale_thread_bufs( - I0)[Number{}], - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0, I0), - b_thread_buf); - }); - } - } - - static_for<0, KInner, 1>{}([&](auto k_inner) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - I0, - I0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - I0, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0, I0), + b_thread_buf); }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0, I0), + b_thread_buf); + }); + } + } + + static_ford>{}([&](auto kn) { + constexpr auto k_inner = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, I0, I0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, I0, I0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }; @@ -341,20 +329,17 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto) { __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read }); - static_for<0, KRepeat, 1>{}([&](auto) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - if constexpr(m0 == I0) - { - static_for<0, NRepeat, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - } - static_for<0, KInner, 1>{}([&](auto) { - static_for<0, NRepeat, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA - }); + static_ford>{}([&](auto km) { + constexpr auto m0 = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + if constexpr(m0 == I0) + { + static_for<0, NRepeat, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); + } + static_ford>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA }); }); static_for<0, num_ds_write_inst, 1>{}([&](auto) { @@ -464,59 +449,55 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( - Number<0>{})); - }); + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); }); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); }; @@ -850,73 +831,71 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_inner) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}( + [&](auto kkmn) { + constexpr auto k0_inner = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0_inner, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0_inner, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard. - // B) reduce VMEM FIFO congestion by applying small delays to - // different wavefronts. - // It is performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0_offset + k0_inner == KRepeat - 1 && - m0 == MRepeat - 1 && n0 == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard. + // B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts. + // It is performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 && + n0 == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); - }); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_setprio(0); @@ -1249,15 +1228,15 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); // Initialize C @@ -1287,66 +1266,64 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[wmma_reg_buf] - [Number{}, - I0, - I0, - n0, - I0, - k0, - Number{}))>{}]; - }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[wmma_reg_buf] + [Number{}, + I0, + I0, + n0, + I0, + k0, + Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); block_sync_lds(); // loop prefetch copy - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); HotLoopScheduler(); @@ -1373,112 +1350,86 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}, - I0, - I0, - n0, - I0, - k0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}, I0, I0, n0, I0, k0, Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); block_sync_lds(); // tail Local Prefetch A1 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}, - I0, - I0, - n0, - I0, - k0, - Number{}))>{}]; - }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}, I0, I0, n0, I0, k0, Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle @@ -1487,49 +1438,36 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}, - I0, - I0, - n0, - I0, - k0, - Number{}))>{}]; - }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}, I0, I0, n0, I0, k0, Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); } @@ -1590,70 +1528,65 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[reg_buf] - [Number{}, - I0, - I0, - n0, - I0, - k_index, - Number{}))>{}]; - }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( - Number<0>{})); - }); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[reg_buf][Number{}, + I0, + I0, + n0, + I0, + k_index, + Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); }); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); }; auto a_local_prefetch_func = [&]() { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index b8d451363e..03146f22ee 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -434,53 +434,38 @@ struct BlockwiseGemmWmmaops_pipeline_v3((i + 2) % num_loop_per_scale == 0); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmnk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, k0, I0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -506,52 +491,35 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop % num_loop_per_scale == 0); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmnk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, k0, I0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -564,52 +532,35 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmnk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, I0, Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, k0, I0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle // latency @@ -747,58 +698,55 @@ struct BlockwiseGemmWmmaops_pipeline_v3((i + 2) % num_loop_per_scale == 0); b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_ford>{}([&](auto kk_id) { + constexpr auto k0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale - .GetVectorTypeReference(Number<0>{})); - }); - }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); c_scale_struct.Load(a_scale_struct, b_scale_struct); @@ -825,59 +773,55 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop % num_loop_per_scale == 0); b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - static_for<0, KInner, 1>{}([&](auto k_inner) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_ford>{}([&](auto kk_id) { + constexpr auto k0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( - Number<0>{})); - }); - }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); c_scale_struct.Load(a_scale_struct, b_scale_struct); @@ -891,58 +835,54 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { - c_scale_struct.Clear(); - static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - static_for<0, KInner, 1>{}([&](auto k_inner) { - static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { - constexpr index_t kk = ik + k_inner * KPerWaveBlock; - constexpr index_t k_index = - kscale0 * (KRepeat / NumScaleKBlock) + k0; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k_index, - I0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - wmma_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( - Number<0>{})); - }); - }); - c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + c_scale_struct.Clear(); + static_ford>{}([&](auto kk_id) { + constexpr auto k0 = Number{}]>{}; + constexpr auto k_inner = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); }); // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle // latency diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index fa0a69ed1f..812e14d73f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -512,22 +512,22 @@ struct BlockwiseGemmXdlops_pipeline_v4 // Local prefetch 1th, Fill Ping Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(I0)); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(I0), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(I0)); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); }); }); @@ -566,22 +566,22 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Pong LDS to Pong Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP1{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP1{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP1{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP1{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); }); }); @@ -594,33 +594,31 @@ struct BlockwiseGemmXdlops_pipeline_v4 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP1{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP1{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -634,22 +632,22 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Ping LDS to Ping Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP2{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP2{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP2{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP2{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); }); }); @@ -662,33 +660,31 @@ struct BlockwiseGemmXdlops_pipeline_v4 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP2{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP2{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -708,54 +704,52 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Pong LDS to Pong Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP1{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP1{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP1{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP1{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); }); }); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP1{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP1{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); TailScheduler<1>(); @@ -769,82 +763,78 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Ping LDS to Ping Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP2{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP2{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP2{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP2{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP2{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP2{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); TailScheduler<2>(); __builtin_amdgcn_sched_barrier(0); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PongP2{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PongP2{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PongP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PongP2{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // 64 v_mfma @@ -860,51 +850,49 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_READ: Pong LDS to Pong Reg block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(PongP1{}), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(PongP1{})); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(PongP1{}), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(PongP1{})); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP1{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP1{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); TailScheduler<2>(); @@ -916,32 +904,30 @@ struct BlockwiseGemmXdlops_pipeline_v4 // DS_WRITE: To Pong LDS // DS_READ: Ping LDS to Ping Reg - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[PingP2{}][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[PingP2{}][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // 64 v_mfma diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp index b2cca08cf5..6d45539e35 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp @@ -275,15 +275,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -318,47 +318,43 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[mfma_reg_buf] - [Number{}]; - }); - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}]; }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -390,45 +386,42 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -440,32 +433,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency @@ -473,32 +463,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp index 682760a3df..2c43998a36 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp @@ -593,39 +593,38 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == MRepeat - 1) @@ -710,30 +709,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == MRepeat - 1) @@ -781,30 +780,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value != (MRepeat - 1)) @@ -837,30 +836,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value != (MRepeat - 1)) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index 64ea0f9eab..7285685404 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -289,15 +289,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< // Local prefetch A1 block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -345,57 +345,51 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_dequant_bufs_up - [mfma_reg_buf][Number{}]; - }); - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}]; }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -439,52 +433,49 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< block_sync_lds(); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_dequant_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); // B VGPR->VGPR dequant b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, @@ -502,39 +493,36 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< b_thread_dequant_bufs_up(I1)); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_dequant_bufs_up[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency @@ -542,39 +530,36 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< } else { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_dequant_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_dequant_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_dequant_bufs + [I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp index 31a86199c8..c5e040cfb4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -298,17 +298,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -342,60 +341,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[mfma_reg_buf] - [Number{}]; - }); - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}]; }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); @@ -425,93 +417,83 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = b_thread_bufs_up + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = b_thread_bufs_up + [I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency @@ -519,39 +501,35 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = b_thread_bufs_up + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp index 1b661b29ca..3fc204e6a2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp @@ -255,54 +255,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3= 3 ? 1 : 0; // B global read - static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + static_ford>{}([&](auto ii) { + constexpr auto i = Number{}]>{}; + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(((i < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == - buffer_load_issue_point_b)) || - ((i >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == - buffer_load_issue_point_b))) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // A global read + A local write - static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == - ds_write_issue_point)) || - (((i + buffer_load_b_stages) >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == - ds_write_issue_point))) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } - if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == - buffer_load_issue_point_a)) || - (((i + buffer_load_b_stages) >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == - buffer_load_issue_point_a))) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto i = Number{}]>{}; + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // lds synchronization, prefetch next loop local A @@ -511,17 +510,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -554,130 +552,129 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == MRepeat - 2) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -706,100 +703,100 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == (MRepeat - 2)) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else if constexpr(m0.value == MRepeat - 1) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); @@ -807,58 +804,58 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -870,53 +867,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp index 3ed035e1cc..b064889a8a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp @@ -259,25 +259,26 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< // Stage 1 // global read more - static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + static_ford< + Sequence<(num_total_stages - 2 - buffer_load_stages_more), num_mfma_perstage>>{}( + [&](auto ii) { + constexpr auto imfma = Number{}]>{}; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA if constexpr(imfma % buffer_load_issue_point_interval_less == 0) { @@ -288,22 +289,20 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } }); - }); // Stage 2, Sync // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); } @@ -463,25 +462,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< // Local prefetch 1, sync the async load __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); block_sync_lds(); - static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple( - I0, I0, Number{}, I0, Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple( - I0, I0, Number{}, k, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(I0, I0, Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); }); // Global prefetch 2 @@ -583,105 +581,97 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< static_for<0, MRepeat, 1>{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset( - make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset( - make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; - vector_type - b_scale_thread_vec_up; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - // B Gate scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - // B Up scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up( - scale_comp_buf)[Number{}]; - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + // B Gate scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + // B Up scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = b_thread_bufs + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[scale_comp_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up [scale_comp_buf][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up - [scale_comp_buf][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - // MFMA accumulation A * Up - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * Up + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == SwitchM) @@ -798,91 +788,91 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< static_for<0, MRepeat, 1>{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - // B Gate scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - // B Up scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + // B Gate scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + // B Up scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == SwitchM) { @@ -920,89 +910,89 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< static_for<0, MRepeat, 1>{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; - }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I1)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I1][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - // MFMA accumulation A * Up - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * Up + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) { @@ -1040,91 +1030,91 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< static_for<0, MRepeat, 1>{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - // B Gate scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - // B Up scale - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation A * Gate - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - // MFMA accumulation A * up - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up.template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + // B Gate scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + // B Up scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * up + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp index a8ceddc1a3..e0a9b43986 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp @@ -357,31 +357,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { + static_ford< + Sequence>{}( + [&](auto mkc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - }); // Initialize C c_thread_buf.Clear(); @@ -448,118 +448,107 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { + static_ford>{}([&](auto mkn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset( - make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset( - make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = b_thread_bufs - [scale_comp_buf][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(scale_comp_buf)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[scale_comp_buf][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { + static_ford>{}( + [&](auto mkc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, - xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), - 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - }); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); }; @@ -611,257 +600,246 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { + static_ford>{}([&](auto mkn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); - // constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { + static_ford>{}( + [&](auto mkc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - }); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_ford>{}([&](auto mkn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_ford>{}([&](auto mkn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp index 0de34baa42..b18ea372a1 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -261,54 +261,49 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_less == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // Stage 2, Sync // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); } else @@ -536,25 +531,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple( - I0, I0, Number{}, I0, Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple( - I0, I0, Number{}, k, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(I0, I0, Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); }); // Global prefetch 2 @@ -628,83 +622,76 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset( - make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset( - make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = b_thread_bufs - [scale_comp_buf][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[scale_comp_buf] + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == SwitchM) @@ -802,73 +789,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == SwitchM) { @@ -906,73 +893,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) { @@ -1010,73 +997,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; constexpr auto ik_major = k0 / KXdlPack; constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index 9dccd9b4e6..b4c64b718b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -280,17 +280,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -318,51 +317,46 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); // loop prefetch copy - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); @@ -387,79 +381,71 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); // tail Local Prefetch A1 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency @@ -467,32 +453,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp index 3da80ba25e..eff2c09571 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp @@ -281,17 +281,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_bufs(I0)); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(I0)); }); // Local prefill A2 @@ -323,18 +322,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_bufs(local_read_buf)); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_buf)); }); a_blockwise_copy.RunWrite( @@ -343,36 +341,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -398,48 +391,44 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_bufs(local_read_reg)); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_reg)); }); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg), mfma_reg); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -455,46 +444,42 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_bufs(local_read_reg)); - }); - }); + static_ford>{}([&](auto mkg) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_reg)); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); @@ -502,32 +487,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 4c20f21d22..c7b7948e41 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -258,52 +258,50 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3= 3 ? 1 : 0; // B global read - static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + static_ford>{}([&](auto ii) { + constexpr auto i = Number{}]>{}; + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); - if constexpr(((i < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == 0)) || - ((i >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == 0))) - { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); - } + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } }); // A global read + A local write - static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); - if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == 0)) || - (((i + buffer_load_b_stages) >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == 0))) - { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0); - } - if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_more == - buffer_load_issue_point_a)) || - (((i + buffer_load_b_stages) >= buffer_load_stages_more) && - (imfma % buffer_load_issue_point_interval_less == - buffer_load_issue_point_a))) - { - __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); - } - }); + static_ford>{}([&](auto ii) { + constexpr auto i = Number{}]>{}; + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0); + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } }); // lds synchronization, prefetch next loop local A @@ -379,17 +377,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -416,119 +413,114 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == (MRepeat - 2)) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<0>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number<0>{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -552,88 +544,87 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value == (MRepeat - 2)) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number<0>{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<0>{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); @@ -641,50 +632,50 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -694,46 +685,46 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp index 19cd141a3e..bfeeb715e3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp @@ -347,22 +347,19 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}); constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); - }); // Local prefill A1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); @@ -409,18 +406,16 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -448,114 +443,104 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = - CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number< - b_thread_desc_.CalculateOffset(make_tuple( - n0, - I0, - kscale0 * KRepeat / num_scale_k_block + k0, - ik))>{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( - [&](auto t) { - using pk_fma_type = - typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = - __builtin_elementwise_fma( - c_thread_buf_per_scale - .GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec - .template AsType()[Number<0>{}], - c_thread_buf - .GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; }); - }); static_for<0, MRepeat, 1>{}([&](auto m0) { a_scale_thread_copy.Run(a_scale_grid_desc, @@ -606,231 +591,207 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp index 16bb7d0e2f..48f815c509 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp @@ -538,18 +538,16 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); #if 1 @@ -717,28 +715,28 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(Number{}), - a_thread_desc_, - make_tuple(Number<(m0 + LocalPrefetchStages + - HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(Number{}), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages + + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); }); @@ -841,26 +839,25 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3= (MRepeat - LocalPrefetchStages) ? I1 : I0; - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + LocalPrefetchStages) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(Number{}), - a_thread_desc_, - make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + LocalPrefetchStages) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(Number{}), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); }); @@ -943,28 +940,27 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + LocalPrefetchStages + HotloopLocalBufSwitch) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + LocalPrefetchStages + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -1042,22 +1038,22 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp index 735891c556..7de7376e2a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp @@ -375,25 +375,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - c_scale_thread_buf_up(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf_up[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; }); - }); // Local prefill A1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); @@ -450,18 +447,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< // Local prefetch A1 block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -496,149 +491,134 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - vector_type c_scale_thread_vec_up; - constexpr index_t cscale_offset = - CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<0>{}) = - c_scale_thread_buf_up[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<1>{}) = - c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number< - b_thread_desc_.CalculateOffset(make_tuple( - n0, - I0, - kscale0 * KRepeat / num_scale_k_block + k0, - ik))>{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[mfma_reg_buf][Number< - b_thread_desc_.CalculateOffset(make_tuple( - n0, - I0, - kscale0 * KRepeat / num_scale_k_block + k0, - ik))>{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_per_scale_up.GetVectorTypeReference( - Number<0>{})); - }); - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( - [&](auto t) { - using pk_fma_type = - typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = - __builtin_elementwise_fma( - c_thread_buf_per_scale - .GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec - .template AsType()[Number<0>{}], - c_thread_buf - .GetVectorTypeReference(Number{}) - .template AsType()[t]); - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()(t) = - __builtin_elementwise_fma( - c_thread_buf_per_scale_up - .GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec_up - .template AsType()[Number<0>{}], - c_thread_buf_up - .GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up + [mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - c_scale_thread_buf_up(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf_up[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; }); - }); a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, @@ -699,310 +679,277 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< block_sync_lds(); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - vector_type c_scale_thread_vec_up; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<0>{}) = - c_scale_thread_buf_up[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<1>{}) = - c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec_up.template AsType()[Number<0>{}], - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - c_scale_thread_buf_up(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf_up[Number{}]; - }); - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - vector_type c_scale_thread_vec_up; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<0>{}) = - c_scale_thread_buf_up[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<1>{}) = - c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec_up.template AsType()[Number<0>{}], - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - vector_type c_scale_thread_vec_up; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<0>{}) = - c_scale_thread_buf_up[Number{}]; - c_scale_thread_vec_up.template AsType()(Number<1>{}) = - c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_bufs_up[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec_up.template AsType(), - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec_up.template AsType()[Number<0>{}], - c_thread_buf_up.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp index e5cff43bf9..cbdf56f2f5 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp @@ -569,17 +569,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< // Local prefetch A1 block_sync_lds(); - static_for<0, 2, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); __builtin_amdgcn_sched_barrier(0); @@ -729,80 +728,80 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -916,62 +915,62 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); @@ -1058,22 +1057,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -1156,18 +1155,18 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp index a76be40753..cc4dc0ed5c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp @@ -343,22 +343,19 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); - }); __builtin_amdgcn_sched_barrier(0); // Local prefill A1 @@ -402,18 +399,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< // Local prefetch A1 block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // Initialize C @@ -441,115 +436,105 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = - CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number< - b_thread_desc_.CalculateOffset(make_tuple( - n0, - I0, - kscale0 * KRepeat / num_scale_k_block + k0, - ik))>{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( - [&](auto t) { - using pk_fma_type = - typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = - __builtin_elementwise_fma( - c_thread_buf_per_scale - .GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec - .template AsType()[Number<0>{}], - c_thread_buf - .GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; }); - }); __builtin_amdgcn_sched_barrier(0); a_scale_thread_copy.Run(a_scale_grid_desc, @@ -597,233 +582,209 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< block_sync_lds(); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); block_sync_lds(); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); // __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - vector_type c_scale_thread_vec; - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[Number{}]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp index f2d85150d6..36fd54fce6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp @@ -526,17 +526,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< // Local prefetch A1 block_sync_lds(); - static_for<0, 2, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mkk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); #if 0 @@ -672,80 +671,80 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -830,62 +829,62 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else if constexpr(m0.value == (MRepeat - 1)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } else { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); @@ -947,22 +946,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, - I0, - I0, - k0, - I0, - Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); } }); @@ -1023,18 +1022,18 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< if constexpr(m0.value < (MRepeat - 2)) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple( - Number{}, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + static_ford>{}([&](auto kk) { + constexpr auto k0 = Number{}]>{}; + constexpr auto kg0 = Number{}]>{}; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); }); } }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp index 1f8d3b28b5..43a405bcaa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp @@ -496,71 +496,74 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I0), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs_up(I0), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf_up); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs_up(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); }); // Global prefetch 2 @@ -664,116 +667,97 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; - vector_type - b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up( - scale_comp_buf)[Number{}]; - }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - b_thread_vec_up.template AsType()( - ik) = b_thread_buf_up - [Number{}]; - }); + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up - .template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference( - Number{})); - }); - }); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); @@ -927,105 +911,97 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); __builtin_amdgcn_s_waitcnt(3952); block_sync_lds(); @@ -1033,277 +1009,261 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I1), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I1), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs_up(I1), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf_up); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs_up(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I1)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp index e25136afb1..c426a0e1f4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp @@ -529,71 +529,74 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf_up, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf_up); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); }); // Initialize C @@ -689,116 +692,97 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; - vector_type - b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up( - scale_comp_buf)[Number{}]; - }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - b_thread_vec_up.template AsType()( - ik) = b_thread_buf_up - [Number{}]; - }); + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up - .template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference( - Number{})); - }); - }); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); @@ -957,382 +941,358 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf_up, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf_up); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I1)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - vector_type b_scale_thread_vec_up; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec_up.template AsType()(s) = - b_scale_thread_bufs_up(I0)[Number{}]; - }); + using mfma_input_type_b = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; - vector_type b_thread_vec_up; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - b_thread_vec_up.template AsType()(ik) = - b_thread_buf_up[Number{}]; - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec_up.template AsType(), - b_scale_thread_vec_up - .template AsType(), - c_thread_buf_up.GetVectorTypeReference(Number{})); - }); - }); - }); + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp index 0be289287d..d598d281de 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp @@ -324,133 +324,128 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_buf[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_buf[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); // Prefetch a_scales static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { @@ -510,132 +505,126 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_buf[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_buf[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp index cf6fa231e1..2d295c881b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp @@ -492,49 +492,51 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); // Initialize C @@ -603,91 +605,78 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(scale_comp_buf)[Number{}]; + }); - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - }); + vector_type a_thread_vec; + vector_type b_thread_vec; - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); @@ -805,299 +794,281 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = + typename vector_type::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = + typename vector_type::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp index 6f086eed05..1d3f3c8bff 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp @@ -459,49 +459,51 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I0), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); // Global prefetch 2 @@ -577,91 +579,85 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford< + Sequence>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_ford>{}([&](auto + kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); // k indexes mapping to threads for 32x32x64: // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. @@ -774,83 +770,81 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); __builtin_amdgcn_s_waitcnt(3952); block_sync_lds(); @@ -858,210 +852,206 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I1), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I1), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index ae4504d6ba..723ef9cd1e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -220,69 +220,9 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - }); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - i += 1; - } while(i < (num_loop - 1)); - } - - // tail - if constexpr(TailNum == TailNumber::Full) - { - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, make_tuple(m0, I0, I0, Number{}), a_block_buf, @@ -298,34 +238,85 @@ struct BlockwiseGemmXdlops_pipeline_v1>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } @@ -553,51 +544,51 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from - // blockwise_gemm is moved here B) reduce VMEM FIFO congestion by - // applying small delays to different wavefronts It is performed - // near the end of MAC cluster to minimize lgkmcnt penalty - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion by + // applying small delays to different wavefronts It is performed + // near the end of MAC cluster to minimize lgkmcnt penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -642,46 +633,43 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -942,73 +930,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - }); - - block_sync_lds(); - a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); - b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds_direct_load(); - - i += 1; - } while(i < (num_loop - 1)); - } - - // tail - if constexpr(TailNum == TailNumber::Full) - { - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, make_tuple(m0, I0, I0, Number{}), a_block_buf, @@ -1024,34 +948,89 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + block_sync_lds_direct_load(); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp index 6da21fcec7..cb5ef0e700 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -406,22 +406,19 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}); constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); - static_for<0, num_scale_m_block, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); - }); // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); @@ -512,74 +509,64 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - constexpr index_t cscale_offset = - CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert( - c_scale_thread_buf[Number{}]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[Number{}]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; }); - }); block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { @@ -642,72 +629,59 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert( - c_scale_thread_buf[Number{}]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[Number{}]); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, num_scale_n_block, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto k0) { - constexpr index_t c_offset = - CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); - constexpr index_t a_offset = - AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); - constexpr index_t b_offset = - BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - c_scale_thread_buf(Number{}) = - a_scale_thread_buf[Number{}] * - b_scale_thread_buf[Number{}]; - }); - }); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * b_scale_thread_buf[Number{}]; }); block_sync_lds(); @@ -733,108 +707,90 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert( - c_scale_thread_buf[Number{}]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[Number{}]); }); }); __builtin_amdgcn_sched_barrier(0); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto kscale0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( - make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert( - c_scale_thread_buf[Number{}]); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[Number{}]); }); }); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp index 6b95cd6ec7..79a72f4b5e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp @@ -277,38 +277,37 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(b_scale_thread_buf[n0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); }); }); @@ -358,37 +357,34 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(b_scale_thread_buf[n0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp index 189d0ad2c3..4c2d11c8e6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp @@ -331,56 +331,60 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx 47 96 --> 111| 160 --> 175 224 --> 239| etc. // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. // k = 0 k = 1 - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + // LDS reads for A + static_ford>{}( + [&](auto km_chunk) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); + // LDS reads for B + static_ford>{}( + [&](auto kn_chunk) { + constexpr auto k = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); }); - }); // load for next k loop block_sync_lds(); @@ -389,82 +393,78 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThreadA && - 0 < ScalesPerXdlopsRunPerThreadB, - "Must have at least one scale per Xdlops per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThreadA && + 0 < ScalesPerXdlopsRunPerThreadB, + "Must have at least one scale per Xdlops per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_buf[Number{}]; - }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_buf[Number{}]; + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); // Prefetch a_scales static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { @@ -519,131 +519,130 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); + // LDS reads for A + static_ford< + Sequence>{}( + [&](auto km_chunk) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); + // LDS reads for B + static_ford< + Sequence>{}( + [&](auto kn_chunk) { + constexpr auto k = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); }); - }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + constexpr auto k0 = Number{}]>{}; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThreadA && - 0 < ScalesPerXdlopsRunPerThreadB, - "Must have at least one scale per Xdlops per Thread."); + static_assert(0 < ScalesPerXdlopsRunPerThreadA && + 0 < ScalesPerXdlopsRunPerThreadB, + "Must have at least one scale per Xdlops per Thread."); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_buf[Number{}]; - }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_buf[Number{}]; + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 18bc1e130b..e823a1f573 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -283,34 +283,31 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -354,34 +351,29 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -409,32 +401,28 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; @@ -460,32 +448,28 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } else if constexpr(TailNum == TailNumber::Two) @@ -788,52 +772,52 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from - // blockwise_gemm is moved here B) reduce VMEM FIFO congestion - // by applying small delays to different wavefronts It is - // performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -887,46 +871,46 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -963,46 +947,43 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -1039,46 +1020,43 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index 712d26c897..96f683e60e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -349,39 +349,39 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[m0]) * - type_convert(b_scale_thread_buf[I0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); }); }); @@ -436,58 +436,57 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto iprefetch) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[m0]) * - type_convert(b_scale_thread_buf[I0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); }); }); @@ -526,57 +525,54 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[m0]) * - type_convert(b_scale_thread_buf[I0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); }); }); }; @@ -584,57 +580,54 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[m0]) * - type_convert(b_scale_thread_buf[I0]); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp index 43ff439e0d..cb56450721 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -264,54 +264,50 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto iprefetch) { // ------------------------------------------------------------------------------------------- block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -336,53 +332,48 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto iprefetch) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -391,102 +382,94 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; if constexpr(TailNum == TailNumber::One) { block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } else if constexpr(TailNum == TailNumber::Two) @@ -823,61 +806,52 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from - // blockwise_gemm is moved here B) reduce VMEM FIFO congestion - // by applying small delays to different wavefronts It is - // performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) - // { - // constexpr index_t c_offset = - // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - // c_thread_buf(Number{}) += - // c_thread_buf_per_scale[Number{}] * - // type_convert(b_scale_thread_buf[n0]); - // }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -944,54 +918,46 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - // constexpr index_t c_offset = - // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - // c_thread_buf(Number{}) += - // c_thread_buf_per_scale[Number{}] * - // type_convert(b_scale_thread_buf[n0]); - // }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -1041,54 +1007,43 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - // constexpr index_t c_offset = - // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - // c_thread_buf(Number{}) += - // c_thread_buf_per_scale[Number{}] * - // type_convert(b_scale_thread_buf[n0]); - // }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -1125,54 +1080,43 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - // constexpr index_t c_offset = - // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - // c_thread_buf(Number{}) += - // c_thread_buf_per_scale[Number{}] * - // type_convert(b_scale_thread_buf[n0]); - // }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index b4f90065dd..82d388ef9a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -363,34 +363,29 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -423,32 +418,28 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index b78a9d3199..c14a06597c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -471,42 +471,41 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert(c_scale_thread_buf[m0]); - }); + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); }); }); @@ -573,41 +572,39 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()(Number{}) = 0; - }); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) - .template AsType()[Number{}] * - type_convert(c_scale_thread_buf[m0]); - }); + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); }); }); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp index 63177c8d45..f7e88d75dd 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp @@ -428,34 +428,29 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale{})); } - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); block_sync_lds(); @@ -490,32 +485,28 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); __builtin_amdgcn_sched_barrier(0); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp index 8d98a36dcd..a179c6c3bd 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp @@ -459,49 +459,51 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I0), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); // Global prefetch 2 @@ -577,91 +579,85 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford< + Sequence>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_ford>{}([&](auto + kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()( - ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()( - ik) = b_thread_buf - [Number{}]; - }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference( - Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); // k indexes mapping to threads for 32x32x64: // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. @@ -774,83 +770,81 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); __builtin_amdgcn_s_waitcnt(3952); block_sync_lds(); @@ -858,210 +852,206 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I1), - a_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, - make_tuple(Number{}, - I0, - Number{}, - I0, - Number{}), - b_block_bufs(I1), - b_thread_desc_, - make_tuple(Number{}, - I0, - Number{}, - k, - Number{}), - b_thread_buf); - }); - }); + static_ford< + Sequence>{}( + [&](auto mc) { + constexpr auto m0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + // read block data in chunks to assemble correct thread vectors + static_ford< + Sequence>{}( + [&](auto nc) { + constexpr auto n0 = Number{}]>{}; + constexpr auto chunk = Number{}]>{}; + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + static_ford>{}( + [&](auto mnk) { + constexpr auto m0 = mnk[Number<0>{}]; + constexpr auto n0 = mnk[Number<1>{}]; + constexpr auto k0 = mnk[Number<2>{}]; - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_ford>{}([&](auto kmn_xdl) { + constexpr auto ikxdl = Number{}]>{}; + constexpr auto imxdl = Number{}]>{}; + constexpr auto inxdl = Number{}]>{}; + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { - static_for<0, MXdlPack, 1>{}([&](auto imxdl) { - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto kxdl = ikxdl + k0 * KXdlPack; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0)); - using mfma_input_type_a = typename vector_type< // - ComputeTypeA, - xdlops_gemm.K1PerXdlops / APackedSize>::type; - - using mfma_input_type_b = typename vector_type< // - ComputeTypeB, - xdlops_gemm.K1PerXdlops / BPackedSize>::type; - - using mfma_scale_input_type_a = typename vector_type< // - AScaleDataType, - a_scale_thread_vec_size>::type; - using mfma_scale_input_type_b = typename vector_type< // - BScaleDataType, - b_scale_thread_vec_size>::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, imxdl, inxdl, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec - .template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec - .template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // MFMA accumulation + xdlops_gemm + .template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp index ff09fd011f..67a9769aca 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp @@ -261,54 +261,49 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_less == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); // Stage 2, Sync // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier( - 0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); + static_ford>{}([&](auto ii) { + constexpr auto imfma = Number{}]>{}; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); } else @@ -537,25 +532,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple( - I0, I0, Number{}, I0, Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple( - I0, I0, Number{}, k, Number{}), - a_thread_buf); - }); - }); + static_ford>{}([&](auto mk) { + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(I0, I0, Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); }); // Global prefetch 2 diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp index 7b94ed5086..d659f2c2e0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -368,79 +368,10 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - HotLoopScheduler(); - }; - - LoopFunc(I1, I1, I0, I0); - LoopFunc(I0, I0, I1, I1); - - i += HotloopUnroll; - } while(i < (num_loop - PrefetchStages)); - } - - auto ReadWriteCompFunc = [&](auto lds_read_buf, - auto lds_read_reg_buf, - auto lds_write_buf, - auto mfma_reg_buf) { - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(lds_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(lds_read_reg_buf)); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); - }); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; vector_type a_thread_vec; vector_type b_thread_vec; @@ -463,11 +394,72 @@ struct BlockwiseGemmXdlops_pipeline_v4(), c_thread_buf.GetVectorTypeReference(Number{})); }); - }); - }); - HotLoopScheduler(); - }; + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = + [&](auto lds_read_buf, auto lds_read_reg_buf, auto lds_write_buf, auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + HotLoopScheduler(); + }; auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { block_sync_lds(); @@ -491,64 +483,60 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); }; auto CompFunc = [&](auto mfma_reg_buf) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; // tail @@ -918,81 +906,10 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v4{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - HotLoopScheduler(); - }; - - LoopFunc(I1, I1, I0, I0); - LoopFunc(I0, I0, I1, I1); - - i += HotloopUnroll; - } while(i < (num_loop - PrefetchStages)); - } - - auto ReadWriteCompFunc = [&](auto lds_read_buf, - auto lds_read_reg_buf, - auto lds_write_buf, - auto mfma_reg_buf) { - block_sync_lds_direct_load(); - - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(lds_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(lds_read_reg_buf)); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); - }); - - a_blockwise_copy.Run( - a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf)); - b_blockwise_copy.Run( - b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf)); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; vector_type a_thread_vec; vector_type b_thread_vec; @@ -1015,11 +932,74 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v4(), c_thread_buf.GetVectorTypeReference(Number{})); }); - }); - }); - HotLoopScheduler(); - }; + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = + [&](auto lds_read_buf, auto lds_read_reg_buf, auto lds_write_buf, auto mfma_reg_buf) { + block_sync_lds_direct_load(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf)); + + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + HotLoopScheduler(); + }; auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { block_sync_lds_direct_load(); @@ -1043,64 +1023,60 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v4{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); }; auto CompFunc = [&](auto mfma_reg_buf) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; // tail diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp index 3bbf4bb690..b48102b023 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp @@ -356,23 +356,23 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(I0)); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(I0), - b_scale_thread_bufs(I0)[n0], - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(I0)); - }); + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_scale_thread_bufs(I0)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); }); }); @@ -477,80 +477,10 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - HotLoopScheduler(); - }; - - LoopFunc(I1, I1, I0, I0); - LoopFunc(I0, I0, I1, I1); - - i += HotloopUnroll; - } while(i < (num_loop - PrefetchStages)); - } - - auto ReadWriteCompFunc = [&](auto lds_read_buf, - auto lds_read_reg_buf, - auto lds_write_buf, - auto mfma_reg_buf) { - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(lds_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(lds_read_reg_buf)); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_scale_thread_bufs(lds_read_buf)[n0], - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); - }); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; vector_type a_thread_vec; vector_type b_thread_vec; @@ -573,11 +503,73 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale(), c_thread_buf.GetVectorTypeReference(Number{})); }); - }); - }); - HotLoopScheduler(); - }; + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = + [&](auto lds_read_buf, auto lds_read_reg_buf, auto lds_write_buf, auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + HotLoopScheduler(); + }; auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { block_sync_lds(); @@ -602,64 +594,60 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); }; auto CompFunc = [&](auto mfma_reg_buf) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto kmn) { + constexpr auto k0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp index 2dfeb1d0cb..059544b239 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -590,27 +590,26 @@ struct BlockwiseGemmXdlops_pipeline_v5{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = a_thread_buf - [Number{}]; - }); - static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = b_thread_buf - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); HotLoopScheduler(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index e3ce5e29c8..8bbf809521 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -365,56 +365,50 @@ struct BlockwiseGemmWMMA } else { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of - // k=0,kpack*1, .. - // read B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); - // read A - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0, I0), - a_thread_buf); + static_ford>{}([&](auto nmk) { + constexpr auto n0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; // k=0,1,2 instead of k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - }); - - static_for<0, KPack / B_KRow, 1>{}([&](auto i) { - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } @@ -862,60 +856,47 @@ struct BlockwiseGemmWMMA } else { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of - // k=0,kpack*1, .. - // read B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); - // read A - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0, I0), - a_thread_buf); + static_ford>{}([&](auto nmk) { + constexpr auto n0 = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + constexpr auto k = Number{}]>{}; // k=0,1,2 instead of k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto i) { - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); + static_for<0, KPack, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index f81e93d82a..53646a4eba 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -514,54 +514,54 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __builtin_amdgcn_sched_barrier(0); } static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_ford>{}([&](auto mn) { + constexpr auto m0 = Number{}]>{}; + constexpr auto n0 = Number{}]>{}; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from blockwise_gemm is - // moved here B) reduce VMEM FIFO congestion by applying small delays to - // different wavefronts It is performed near the end of MAC cluster to - // minimize lgkmcnt penalty - if constexpr(k.value == KPerThread - KPerInnerLoop && - k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && - n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - - // TODO: insert setprio in more precise manner since we - // could have more than >1 MFMA instructions in single call - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from blockwise_gemm is + // moved here B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts It is performed near the end of MAC cluster to + // minimize lgkmcnt penalty + if constexpr(k.value == KPerThread - KPerInnerLoop && + k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && + n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + + // TODO: insert setprio in more precise manner since we + // could have more than >1 MFMA instructions in single call + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); }); __builtin_amdgcn_sched_barrier(0); @@ -953,44 +953,43 @@ struct BlockwiseGemmXdlops_v2 auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, + static_ford>{}([&](auto km) { + constexpr auto k = Number{}]>{}; + constexpr auto m0 = Number{}]>{}; + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, make_tuple(I0, I0, I0, I0), - a_thread_buf); + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp index e3d5afe6d8..ba061b20fa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp @@ -98,13 +98,13 @@ struct BlockwiseSoftmax }); // calculate exp for elements, P=exp(s-max) - static_for<0, MRepeat, 1>{}([&](auto iM) { - static_for<0, KRepeat, 1>{}([&](auto iK) { - auto offset = Number{}; - in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset]) - ? 0 - : math::exp(in_thread_buf[offset] - max_value_buf(iM)); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + auto offset = Number{}; + in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset]) + ? 0 + : math::exp(in_thread_buf[offset] - max_value_buf(iM)); }); // sum data diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp index 6d1b454282..d234b9f846 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp @@ -480,19 +480,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf make_tuple(I0, I0), dy_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - dy_elementwise_op(dy_thread_buf(Number{}), - dy_thread_buf[Number{}]); + dy_elementwise_op(dy_thread_buf(Number{}), dy_thread_buf[Number{}]); - AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * - inv_var_thread_buf[iM]; + AccDataType norm_x = + (x_thread_buf[Number{}] - mean_thread_buf[iM]) * inv_var_thread_buf[iM]; - tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; - }); + tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; }); ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp index d1e7f35607..d516fd248c 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp @@ -502,15 +502,15 @@ struct EpilogueReduceCShuffle [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); // reduce in VGPR - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { - constexpr auto offset = - Number{}; + static_ford>{}([&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; + constexpr auto offset = + Number{}; - reduce_in_element_op(c_reduce_thread_buf(offset), - c_reduce_thread_buf(offset)); - }); + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); }); ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp index 87f3d50e10..f735124539 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp @@ -362,11 +362,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d make_tuple(I0), gamma_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto m) { - static_for<0, NThreadSliceSize, 1>{}([&](auto n) { - constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); - h_thread_buf(Number{}) = h_thread_buf(Number{}) * gamma_thread_buf(n); - }); + static_ford>{}([&](auto mn) { + constexpr auto m = Number{}]>{}; + constexpr auto n = Number{}]>{}; + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = h_thread_buf(Number{}) * gamma_thread_buf(n); }); threadwise_beta_load_n.Run(beta_grid_desc_n, @@ -375,11 +375,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d make_tuple(I0), beta_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto m) { - static_for<0, NThreadSliceSize, 1>{}([&](auto n) { - constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); - h_thread_buf(Number{}) = h_thread_buf(Number{}) + beta_thread_buf(n); - }); + static_ford>{}([&](auto mn) { + constexpr auto m = Number{}]>{}; + constexpr auto n = Number{}]>{}; + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = h_thread_buf(Number{}) + beta_thread_buf(n); }); threadwise_h_store_m_n.Run(thread_buffer_desc_m_n, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp index 949edb35f6..35f74fed95 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp @@ -218,14 +218,12 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock in_thread_buf); static_for<0, NumReduction, 1>{}([&](auto iR) { - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), + in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR)); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp index 5ad0ef3117..191a6942c4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp @@ -173,14 +173,12 @@ struct GridwiseMultipleReduction_mk_to_m_threadwise in_thread_buf); static_for<0, NumReduction, 1>{}([&](auto iR) { - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), + in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR)); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp index f72aca8605..9ad1038251 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -212,13 +212,11 @@ struct GridwiseReduction_mk_to_m_multiblock make_tuple(I0, I0), in_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_buf(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp index fb20531133..a688878026 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -162,13 +162,11 @@ struct GridwiseReduction_mk_to_m_threadwise make_tuple(I0, I0), in_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_buf(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); @@ -340,15 +338,13 @@ struct GridwiseReduction_mk_to_m_threadwise make_tuple(I0, I0), in_thread_idx_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_val_buf(Number{}), - in_thread_val_buf(Number{})); - }); + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); }); ThreadwiseReduceWithIndex::Reduce( @@ -371,17 +367,15 @@ struct GridwiseReduction_mk_to_m_threadwise make_tuple(I0, I0), in_thread_val_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_thread_idx_buf(Number{}) = indexStart + iK(); + in_thread_idx_buf(Number{}) = indexStart + iK(); - in_elementwise_op(in_thread_val_buf(Number{}), - in_thread_val_buf(Number{})); - }); + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); }); ThreadwiseReduceWithIndex::Reduce( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp index 637ba27cef..e679592766 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp @@ -160,13 +160,11 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d make_tuple(I0, I0), in_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_buf(Number{}), - in_thread_buf(Number{})); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), in_thread_buf(Number{})); }); ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index e30ddf5c1c..10d83c4b32 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -1297,15 +1297,15 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); - static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { - static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { - auto I = Number{}; - Acc1DataType acc1 = acc1_thread_buf[I]; // P*V - Acc1DataType c = c_thread_buf[I]; // O - Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax. + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iN = Number{}]>{}; + auto I = Number{}; + Acc1DataType acc1 = acc1_thread_buf[I]; // P*V + Acc1DataType c = c_thread_buf[I]; // O + Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax. - c_thread_buf(I) = c_new; // O_new - }); + c_thread_buf(I) = c_new; // O_new }); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 5722bbc146..98eba4fd2e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1234,19 +1234,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); - static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { - static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { - auto I = Number{}; - FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V - FloatGemmAcc c = c_thread_buf[I]; // O - FloatGemmAcc c_new = - (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + - math::exp(max[iM] - running_max_new[iM]) * acc1) / - running_sum_new[iM]; // Formula by Dao et al., - // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iN = Number{}]>{}; + auto I = Number{}; + FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V + FloatGemmAcc c = c_thread_buf[I]; // O + FloatGemmAcc c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; // Formula by Dao et al., + // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 - c_thread_buf(I) = c_new; // O_new - }); + c_thread_buf(I) = c_new; // O_new }); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 231acc7e4f..fd48c42b26 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -1410,18 +1410,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); - static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { - static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { - auto I = Number{}; - Acc1DataType acc1 = acc1_thread_buf[I]; // P*V - Acc1DataType c = c_thread_buf[I]; // O - Acc1DataType c_new = - (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + - math::exp(max[iM] - running_max_new[iM]) * acc1) / - running_sum_new[iM]; + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iN = Number{}]>{}; + auto I = Number{}; + Acc1DataType acc1 = acc1_thread_buf[I]; // P*V + Acc1DataType c = c_thread_buf[I]; // O + Acc1DataType c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; - c_thread_buf(I) = c_new; // O_new - }); + c_thread_buf(I) = c_new; // O_new }); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index 5cb9eac548..ffa3e464b9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1048,19 +1048,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); - static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { - static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { - auto I = Number{}; - FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V - FloatGemmAcc c = c_thread_buf[I]; // O - FloatGemmAcc c_new = - (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + - math::exp(max[iM] - running_max_new[iM]) * acc1) / - running_sum_new[iM]; // Formula by Dao et al., - // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iN = Number{}]>{}; + auto I = Number{}; + FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V + FloatGemmAcc c = c_thread_buf[I]; // O + FloatGemmAcc c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; // Formula by Dao et al., + // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 - c_thread_buf(I) = c_new; // O_new - }); + c_thread_buf(I) = c_new; // O_new }); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp index 1a9bbcb603..196cfceb4b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp @@ -437,19 +437,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford make_tuple(I0, I0), dy_thread_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - dy_elementwise_op(dy_thread_buf(Number{}), - dy_thread_buf[Number{}]); + dy_elementwise_op(dy_thread_buf(Number{}), dy_thread_buf[Number{}]); - AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * - inv_var_thread_buf[iM]; + AccDataType norm_x = + (x_thread_buf[Number{}] - mean_thread_buf[iM]) * inv_var_thread_buf[iM]; - tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; - }); + tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; }); ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index f2cd6fca5c..2d8f9dd3ab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -355,8 +355,10 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + static_ford>{}( + [&](auto mk) { // input add loop + constexpr auto iM = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; constexpr auto offset_m_k = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); @@ -376,7 +378,6 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk unpack2(x_elementwise_op, out_data_refs, in_data_refs); }); - }); threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf); if constexpr(!SweepOnce) @@ -435,21 +436,20 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}([&](auto ii) { + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * divisor; - // gamma - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); }); @@ -463,19 +463,19 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto mii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index ccf4b04a6c..26692594da 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -937,8 +937,10 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); // reduce in VGPR - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + static_ford>{}( + [&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; constexpr auto offset = Number{}; @@ -946,7 +948,6 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 reduce_in_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset)); }); - }); ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 3d9095bffb..c2af166e85 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -881,14 +881,14 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ThreadReduceOperation::template GetIdentityValue(); static_for<0, mreduce_per_thread, 1>{}( [&](auto I) { r_thread_buf(I) = reduce_identityVal; }); - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { - constexpr auto offset = - Number{}; + static_ford>{}([&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; + constexpr auto offset = + Number{}; - qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset)); - }); + qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset)); }); ThreadwiseReduce::Reduce(e_thread_buf, r_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index f1e97455b7..530194ee22 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -820,8 +820,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); // reduce in VGPR - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + static_ford>{}( + [&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; constexpr auto offset = Number{}; @@ -829,7 +831,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 reduce_in_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset)); }); - }); ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 9569cab98b..ec5c449a78 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -997,26 +997,26 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0] .GetUpperLengths()[I1]; // TODO: proper handle - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { - constexpr auto dst_offset = - Number{}; + static_ford>{}([&](auto ii) { + constexpr auto im = Number{}]>{}; + constexpr auto in = Number{}]>{}; + constexpr auto dst_offset = + Number{}; - constexpr auto src_offset = - Number{}; + constexpr auto src_offset = + Number{}; - FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw; - FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw; + FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw; + FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw; - FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; - FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; - FloatReduceAcc divisor_sqrt; - tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor); + FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; + FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; + FloatReduceAcc divisor_sqrt; + tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor); - c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; - }); + c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; }); // scaling diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp index a6fa04a824..a3835ed7fd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -290,12 +290,12 @@ struct GridwiseSoftmax_mk_to_mk } // do element-wise pre-reduction operation - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - out_thread_buf(Number{}) = - math::exp(in_thread_buf(Number{}) - max_value_buf(iM)); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + math::exp(in_thread_buf(Number{}) - max_value_buf(iM)); }); ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf); @@ -330,15 +330,13 @@ struct GridwiseSoftmax_mk_to_mk in_thread_buf); } - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - out_thread_buf(Number{}) = - alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / - accu_value_buf(iM); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM); }); threadwise_dst_store.Run(thread_buffer_desc, @@ -376,16 +374,14 @@ struct GridwiseSoftmax_mk_to_mk make_tuple(I0, I0), in_prior_dst_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = - thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - out_thread_buf(Number{}) = - alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / - accu_value_buf(iM) + - beta * in_prior_dst_buf(Number{}); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM) + + beta * in_prior_dst_buf(Number{}); }); threadwise_dst_store.Run(thread_buffer_desc, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp index 5036f4ae7f..88719c831e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp @@ -175,36 +175,36 @@ struct GridwiseSparseEmbeddingsForwardLayernorm }; auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { - static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { - static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { - constexpr auto register_offset = thread_buf_desc.CalculateOffset( - make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - auto in_data_refs = generate_tie( - [&](auto i_embedding_) -> const auto& { - return in_thread_bufs(i_embedding_)(Number{}); - }, - Number{}); - auto out_data_refs = generate_tie( - [&](auto) -> auto& { return acc_thread_buf(Number{}); }, - Number<1>{}); - unpack2(emb_elementwise_op, out_data_refs, in_data_refs); - }); + static_ford>{}([&](auto ii) { + constexpr auto i_dim_vec_ = Number{}]>{}; + constexpr auto i_row_vec_ = Number{}]>{}; + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + auto in_data_refs = generate_tie( + [&](auto i_embedding_) -> const auto& { + return in_thread_bufs(i_embedding_)(Number{}); + }, + Number{}); + auto out_data_refs = generate_tie( + [&](auto) -> auto& { return acc_thread_buf(Number{}); }, + Number<1>{}); + unpack2(emb_elementwise_op, out_data_refs, in_data_refs); }); }; auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { - static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { - static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { - constexpr auto register_offset = thread_buf_desc.CalculateOffset( - make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - constexpr auto mean_var_offset = - mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + static_ford>{}([&](auto ii) { + constexpr auto i_dim_vec_ = Number{}]>{}; + constexpr auto i_row_vec_ = Number{}]>{}; + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); - threadwise_welford.cur_count_++; - threadwise_welford.Update(mean_thread_buf(Number{}), - var_thread_buf(Number{}), - acc_thread_buf(Number{})); - }); + threadwise_welford.cur_count_++; + threadwise_welford.Update(mean_thread_buf(Number{}), + var_thread_buf(Number{}), + acc_thread_buf(Number{})); }); }; @@ -246,12 +246,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm }; // first load index - ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) { + ck::static_ford>{}([&](auto ie) { + constexpr auto i_idx_ = Number{}]>{}; + constexpr auto i_embedding_ = Number{}]>{}; // prefer use s_load - ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { - index_bufs(i_embedding_)(i_idx_) = - p_indexes[i_embedding_][index_start + i_idx_.value]; - }); + index_bufs(i_embedding_)(i_idx_) = p_indexes[i_embedding_][index_start + i_idx_.value]; }); // load gamma/beta diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp index a97114d48d..6aa046a43a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp @@ -176,36 +176,36 @@ struct GridwiseSparseEmbeddingsForwardLayernorm }; auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { - static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { - static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { - constexpr auto register_offset = thread_buf_desc.CalculateOffset( - make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - auto in_data_refs = generate_tie( - [&](auto i_embedding_) -> const auto& { - return in_thread_bufs(i_embedding_)(Number{}); - }, - Number{}); - auto out_data_refs = generate_tie( - [&](auto) -> auto& { return acc_thread_buf(Number{}); }, - Number<1>{}); - unpack2(emb_elementwise_op, out_data_refs, in_data_refs); - }); + static_ford>{}([&](auto ii) { + constexpr auto i_dim_vec_ = Number{}]>{}; + constexpr auto i_row_vec_ = Number{}]>{}; + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + auto in_data_refs = generate_tie( + [&](auto i_embedding_) -> const auto& { + return in_thread_bufs(i_embedding_)(Number{}); + }, + Number{}); + auto out_data_refs = generate_tie( + [&](auto) -> auto& { return acc_thread_buf(Number{}); }, + Number<1>{}); + unpack2(emb_elementwise_op, out_data_refs, in_data_refs); }); }; auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { - static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { - static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { - constexpr auto register_offset = thread_buf_desc.CalculateOffset( - make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - constexpr auto mean_var_offset = - mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + static_ford>{}([&](auto ii) { + constexpr auto i_dim_vec_ = Number{}]>{}; + constexpr auto i_row_vec_ = Number{}]>{}; + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); - threadwise_welford.cur_count_++; - threadwise_welford.Update(mean_thread_buf(Number{}), - var_thread_buf(Number{}), - acc_thread_buf(Number{})); - }); + threadwise_welford.cur_count_++; + threadwise_welford.Update(mean_thread_buf(Number{}), + var_thread_buf(Number{}), + acc_thread_buf(Number{})); }); }; @@ -247,12 +247,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm }; // first load index - ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) { + ck::static_ford>{}([&](auto ie) { + constexpr auto i_idx_ = Number{}]>{}; + constexpr auto i_embedding_ = Number{}]>{}; // prefer use s_load - ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { - index_bufs(i_embedding_)(i_idx_) = - p_indexes[i_embedding_][index_start + i_idx_.value]; - }); + index_bufs(i_embedding_)(i_idx_) = p_indexes[i_embedding_][index_start + i_idx_.value]; }); // load gamma/beta diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp index 6c42dc33f4..f23744bf67 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp @@ -328,14 +328,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk make_tuple(I0, I0), gamma_thread_buf(i)); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - x_square_thread_buf(i)(Number{}) = - x_thread_buf(i)(Number{}) * - x_thread_buf(i)(Number{}); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); }); ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); @@ -391,24 +391,24 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk } // normalization - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma & beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma & beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -422,19 +422,19 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, @@ -460,14 +460,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk x_thread_buf(i)); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - x_square_thread_buf(i)(Number{}) = - x_thread_buf(i)(Number{}) * - x_thread_buf(i)(Number{}); - }); + static_ford>{}([&](auto ii) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); }); ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); @@ -544,24 +544,24 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -573,19 +573,19 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp index 9deb5a5f48..0a33d08a5e 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp @@ -441,24 +441,24 @@ struct GridwiseNormalizationSplitK2nd thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -470,19 +470,19 @@ struct GridwiseNormalizationSplitK2nd thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp index d57871f331..7ca6c9ce7a 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -365,24 +365,24 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk } // normalization - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma & beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma & beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -396,19 +396,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, @@ -496,24 +496,24 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - inv_std_thread_buf(iM); + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); - // gamma - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, @@ -525,19 +525,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_copy_fwd_step_m_k); }); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + static_ford>{}( + [&](auto idx) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK0 = Number{}]>{}; + constexpr auto iK1 = Number{}]>{}; + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); }); - }); static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp index f4a2bc399d..44d14e4c04 100644 --- a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -35,14 +35,13 @@ struct ThreadwiseReduction template __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf) { - static_for<0, src_length_m, 1>{}([&](auto iM) { + static_ford>{}([&](auto mk) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - static_for<0, src_length_k, 1>{}([&](auto iK) { - constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - - Accumulation::Calculate(dst_buf(Number{}), src_buf[Number{}]); - }); + Accumulation::Calculate(dst_buf(Number{}), src_buf[Number{}]); }); }; }; @@ -81,17 +80,16 @@ struct ThreadwiseReductionWithIndex DstValueBufferType& dst_val_buf, DstIndexBufferType& dst_idx_buf) { - static_for<0, src_length_m, 1>{}([&](auto iM) { + static_ford>{}([&](auto mk) { + constexpr auto iM = Number{}]>{}; + constexpr auto iK = Number{}]>{}; constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - static_for<0, src_length_k, 1>{}([&](auto iK) { - constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - - Accumulation::Calculate(dst_val_buf(Number{}), - src_val_buf[Number{}], - dst_idx_buf(Number{}), - src_idx_buf[Number{}]); - }); + Accumulation::Calculate(dst_val_buf(Number{}), + src_val_buf[Number{}], + dst_idx_buf(Number{}), + src_idx_buf[Number{}]); }); }; }; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp index 06aca9c922..3a4d8f5d80 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -81,28 +81,21 @@ struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1 constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - static_for<0, TK, 1>{}([&](auto tk) { - static_for<0, TM0, 1>{}([&](auto tm0) { - static_for<0, TM1, 1>{}([&](auto tm1) { - static_for<0, TN0, 1>{}([&](auto tn0) { - static_for<0, TN1, 1>{}([&](auto tn1) { - constexpr index_t a_offset = - AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( - a_origin_idx + make_multi_index(tk, tm0, tm1)); - constexpr index_t b_offset = - BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( - b_origin_idx + make_multi_index(tk, tn0, tn1)); - constexpr index_t c_offset = - CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( - c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + static_ford>{}([&](auto tkmn) { + constexpr auto tk = Number{}]>{}; + constexpr auto tm0 = Number{}]>{}; + constexpr auto tm1 = Number{}]>{}; + constexpr auto tn0 = Number{}]>{}; + constexpr auto tn1 = Number{}]>{}; + constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk, tm0, tm1)); + constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk, tn0, tn1)); + constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); - inner_product(a_buf[Number{}], - b_buf[Number{}], - c_buf(Number{})); - }); - }); - }); - }); + inner_product( + a_buf[Number{}], b_buf[Number{}], c_buf(Number{})); }); } }; @@ -181,42 +174,35 @@ struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0 constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - static_for<0, TK0, 1>{}([&](auto tk0) { - static_for<0, TM0, 1>{}([&](auto tm0) { - static_for<0, TM1, 1>{}([&](auto tm1) { - static_for<0, TN0, 1>{}([&](auto tn0) { - static_for<0, TN1, 1>{}([&](auto tn1) { - vector_type a_vec; - vector_type b_vec; + static_ford>{}([&](auto tkmn) { + constexpr auto tk0 = Number{}]>{}; + constexpr auto tm0 = Number{}]>{}; + constexpr auto tm1 = Number{}]>{}; + constexpr auto tn0 = Number{}]>{}; + constexpr auto tn1 = Number{}]>{}; + vector_type a_vec; + vector_type b_vec; - static_for<0, TK1, 1>{}([&](auto tk1) { - constexpr index_t a_offset = - AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( - a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1)); + static_for<0, TK1, 1>{}([&](auto tk1) { + constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1)); - constexpr index_t b_offset = - BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( - b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1)); + constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1)); - a_vec.template AsType()(tk1) = a_buf[Number{}]; - b_vec.template AsType()(tk1) = b_buf[Number{}]; - }); - - using a_vector_t = typename vector_type::type; - using b_vector_t = typename vector_type::type; - - constexpr index_t c_offset = - CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( - c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); - - inner_product( - a_vec.template AsType()[I0], - b_vec.template AsType()[I0], - c_buf(Number{})); - }); - }); - }); + a_vec.template AsType()(tk1) = a_buf[Number{}]; + b_vec.template AsType()(tk1) = b_buf[Number{}]; }); + + using a_vector_t = typename vector_type::type; + using b_vector_t = typename vector_type::type; + + constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + inner_product(a_vec.template AsType()[I0], + b_vec.template AsType()[I0], + c_buf(Number{})); }); } }; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp index 2896375636..89ded8ee7c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp @@ -81,53 +81,49 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 static_for<0, K, 1>{}([&](auto k) { static_for<0, Ho, SubHW>{}([&](auto h) { static_for<0, Wo, SubHW>{}([&](auto w) { - static_for<0, E1, 1>{}([&](auto e1) { - static_for<0, E2, 1>{}([&](auto e2) { - constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( - a_origin_idx + make_tuple(e1, k, e2)); + static_ford>{}([&](auto ee) { + constexpr auto e1 = Number{}]>{}; + constexpr auto e2 = Number{}]>{}; + constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( + a_origin_idx + make_tuple(e1, k, e2)); - constexpr index_t b0_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h, w, e2)); + constexpr index_t b0_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); - constexpr index_t b1_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h, w + 1, e2)); + constexpr index_t b1_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w + 1, e2)); - constexpr index_t b2_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h + 1, w, e2)); + constexpr index_t b2_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w, e2)); - constexpr index_t b3_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2)); + constexpr index_t b3_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2)); - constexpr index_t c0_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + - make_tuple(k, 0, h, w)); + constexpr index_t c0_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h, w)); - constexpr index_t c1_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( - c_origin_idx + make_tuple(k, 0, h, w + 1)); + constexpr index_t c1_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h, w + 1)); - constexpr index_t c2_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( - c_origin_idx + make_tuple(k, 0, h + 1, w)); + constexpr index_t c2_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w)); - constexpr index_t c3_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( - c_origin_idx + make_tuple(k, 0, h + 1, w + 1)); + constexpr index_t c3_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w + 1)); - amd_assembly_outer_product_1x4(a_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{})); - }); + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); }); }); }); @@ -136,29 +132,24 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 else { - static_for<0, K, 1>{}([&](auto k) { - static_for<0, Ho, 1>{}([&](auto h) { - static_for<0, Wo, 1>{}([&](auto w) { - static_for<0, E1, 1>{}([&](auto e1) { - static_for<0, E2, 1>{}([&](auto e2) { - constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( - a_origin_idx + make_tuple(e1, k, e2)); + static_ford>{}([&](auto khwe) { + constexpr auto k = Number{}]>{}; + constexpr auto h = Number{}]>{}; + constexpr auto w = Number{}]>{}; + constexpr auto e1 = Number{}]>{}; + constexpr auto e2 = Number{}]>{}; + constexpr index_t a_offset = + AThreadDesc_E1_K_E2{}.CalculateOffset(a_origin_idx + make_tuple(e1, k, e2)); - constexpr index_t b_offset = - BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( - b_origin_idx + make_tuple(e1, 0, h, w, e2)); + constexpr index_t b_offset = BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); - constexpr index_t c_offset = - CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + - make_tuple(k, 0, h, w)); + constexpr index_t c_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w)); - inner_product(a_buf[Number{}], - b_buf[Number{}], - c_buf(Number{})); - }); - }); - }); - }); + inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); }); } } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index afd1e67bd0..91bf2d5832 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1827,25 +1827,24 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic } else { - static_for<0, num_access, 1>{}([&](auto idx_1d) { + static_ford>{}([&](auto access_idx) { + constexpr auto idx_1d = Number{}]>{}; + constexpr auto i = Number{}]>{}; constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - DstData v; + DstData v; - // apply element-wise operation - element_op_(v, src_buf[Number{}]); + // apply element-wise operation + element_op_(v, src_buf[Number{}]); - // apply type convert - dst_buf(Number{}) = v; - }); + // apply type convert + dst_buf(Number{}) = v; }); } } @@ -1933,57 +1932,56 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - static_for<0, num_access, 1>{}([&](auto idx_1d) { + static_ford>{}([&](auto access_idx) { + constexpr auto idx_1d = Number{}]>{}; + constexpr auto i = Number{}]>{}; constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - // src_desc error, non constexpr, caused by merge transform - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md + + i * dst_scalar_step_in_vector); - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + constexpr index_t dst_offset = dst_desc.CalculateOffset(dst_slice_origin_idx + idx_md + + i * dst_scalar_step_in_vector); - SrcData v_this_row, v_theother_row; - // int type temp value due to intrinsic requirement - int temp = 0; + SrcData v_this_row, v_theother_row; + // int type temp value due to intrinsic requirement + int temp = 0; - // apply element-wise operation - element_op_(v_this_row, src_buf[Number{}]); + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); - // apply intra-row permute. - if constexpr(IntraRowSwizzlePerm) - { - temp = __builtin_amdgcn_permlane16( - temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); - v_this_row = type_convert_sp(temp); - } + // apply intra-row permute. + if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } - // apply inter-row permute. - temp = __builtin_amdgcn_permlanex16(temp, - type_convert_sp(v_this_row), - LowEightRowlaneIdx, - HighEightRowLaneIdx, - 1, - 0); - v_theother_row = type_convert_sp(temp); + // apply inter-row permute. + temp = __builtin_amdgcn_permlanex16(temp, + type_convert_sp(v_this_row), + LowEightRowlaneIdx, + HighEightRowLaneIdx, + 1, + 0); + v_theother_row = type_convert_sp(temp); - if(get_thread_local_1d_id() % 32 < 16) - { - // apply type convert - dst_buf(Number{}) = type_convert_sp(v_this_row); - dst_buf(Number{}) = - type_convert_sp(v_theother_row); - } - else - { - // apply type convert - dst_buf(Number{}) = - type_convert_sp(v_this_row); - dst_buf(Number{}) = type_convert_sp(v_theother_row); - } - }); + if(get_thread_local_1d_id() % 32 < 16) + { + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + dst_buf(Number{}) = + type_convert_sp(v_theother_row); + } + else + { + // apply type convert + dst_buf(Number{}) = + type_convert_sp(v_this_row); + dst_buf(Number{}) = type_convert_sp(v_theother_row); + } }); } }; @@ -2060,36 +2058,35 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - static_for<0, num_access, 1>{}([&](auto idx_1d) { + static_ford>{}([&](auto access_idx) { + constexpr auto idx_1d = Number{}]>{}; + constexpr auto i = Number{}]>{}; constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - // src_desc error, non constexpr, caused by merge transform - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md + + i * dst_scalar_step_in_vector); - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + constexpr index_t dst_offset = dst_desc.CalculateOffset(dst_slice_origin_idx + idx_md + + i * dst_scalar_step_in_vector); - SrcData v_this_row; - // int type temp value due to intrinsic requirement - int temp = 0; + SrcData v_this_row; + // int type temp value due to intrinsic requirement + int temp = 0; - // apply element-wise operation - element_op_(v_this_row, src_buf[Number{}]); + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); - // apply intra-row permute. - if constexpr(IntraRowSwizzlePerm) - { - temp = __builtin_amdgcn_permlane16( - temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); - v_this_row = type_convert_sp(temp); - } + // apply intra-row permute. + if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } - // apply type convert - dst_buf(Number{}) = type_convert_sp(v_this_row); - }); + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); }); } ElementwiseOperation element_op_{};