diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index 8e0210a37b..e241feb843 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -121,7 +121,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale; using Base::I0; - using Base::I1; using Base::KRepeat; using Base::xdlops_gemm; using typename Base::HotLoopInstList; @@ -144,10 +143,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}) == 1, "Pipeline v3 only support scaleblocksliceN=1"); // assume kperblock = scaleblockk - using AThreadBufType = decltype(make_static_buffer( - a_thread_desc_.GetElementSpaceSize())); - using BThreadBufType = decltype(make_static_buffer( - b_thread_desc_.GetElementSpaceSize())); - AThreadBufType a_thread_buf{}; - BThreadBufType b_thread_buf{}; - - using AScaleBufferType = decltype(make_static_buffer( - a_scale_thread_desc.GetElementSpaceSize())); - using BScaleBufferType = decltype(make_static_buffer( - b_scale_thread_desc.GetElementSpaceSize())); - StaticallyIndexedArray{}> a_scale_thread_bufs; - StaticallyIndexedArray{}> b_scale_thread_bufs; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); auto c_scale_thread_buf = make_static_buffer( c_scale_thread_desc.GetElementSpaceSize()); @@ -359,7 +352,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{})); }); @@ -379,12 +372,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; }); // Local prefill 1 @@ -403,7 +396,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{})); }); @@ -423,19 +416,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale c_thread_buf_per_scale; @@ -461,33 +452,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(I0).template AsType()( - Number{}) = 0; - }); - - // Fill first MFMA buffer with index I0, this output used in the first part of main loop for - // scale-FMA - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()(ik) = b_thread_buf - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - __builtin_amdgcn_sched_barrier(0); // main body @@ -507,72 +471,26 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(m0, I0), - a_scale_thread_bufs(I0)); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); - }); - - if constexpr(NumKBlockPerScale == 1) - { - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); - } - else - { - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); - } - - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(I0, I0), - b_scale_thread_bufs(I0)); - - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - vector_type c_scale_thread_vec; - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[m0]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[m0]; static_for<0, NRepeat, 1>{}([&](auto n0) { - // Calculate buffer offsets using future tile approach - constexpr auto mfma_buf_offset = - ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); - constexpr auto scale_buf_offset = + // Use double buffering with temporal offset to decouple MFMA and scaling + constexpr auto buffer_offset = ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); - // Calculate future tile data offsets - constexpr auto a_future_tile_offset = - ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; - constexpr auto b_future_tile_offset = - ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; - - // Clear the MFMA output buffer for future tile static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) .template AsType()(Number{}) = 0; }); - - // Compute MFMA for future tile static_for<0, KRepeat, 1>{}([&](auto k0) { vector_type a_thread_vec; vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - // Use future tile offsets for MFMA computation a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; + make_tuple(m0, I0, k0, ik))>{}]; b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; + make_tuple(n0, I0, k0, ik))>{}]; }); using mfma_input_type = @@ -583,24 +501,16 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale(), b_thread_vec.template AsType(), c_thread_buf_per_scale.GetVectorTypeReference( - Number{})); + Number{})); }); - - // Run the element-wise FMA with data from previous iteration buffer - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale // mfma output from previous iteration - .GetVectorTypeReference(Number{}) - .template AsType()[t], - c_scale_thread_vec - .template AsType()[Number<0>{}], // scales c=a*b - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()(Number{}) * + type_convert(c_scale_thread_buf[m0]); }); }); }); @@ -625,90 +535,89 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - c_scale_thread_buf(m0) = - a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; - }); - HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); + // Compute scale values early to allow better instruction scheduling + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + }); + + // Prefetch scale data early to overlap with MFMA computation + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + i += 1; } while(i < (num_loop - 1)); } - // __builtin_amdgcn_sched_barrier(0); + // tail if constexpr(TailNum == TailNumber::Full) { static_for<0, MRepeat, 1>{}([&](auto m0) { - vector_type c_scale_thread_vec; - c_scale_thread_vec.template AsType()(Number<0>{}) = - c_scale_thread_buf[m0]; - c_scale_thread_vec.template AsType()(Number<1>{}) = - c_scale_thread_buf[m0]; - static_for<0, NRepeat, 1>{}([&](auto n0) { - // Calculate buffer offsets using the same future tile approach - constexpr auto mfma_buf_offset = - ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); - constexpr auto scale_buf_offset = + // Use double buffering with temporal offset in tail section as well + constexpr auto buffer_offset = ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); - // Calculate future tile data offsets - constexpr auto a_future_tile_offset = - ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; - constexpr auto b_future_tile_offset = - ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - // Skip MFMA computation for the last tile to avoid out-of-bounds - if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1)))) - { - // Clear the MFMA buffer for future tile computation - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number{}) - .template AsType()(Number{}) = 0; + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; }); - // Compute MFMA for future tile - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + using mfma_input_type = + typename vector_type::type; - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference( - Number{})); - }); - } - - // Scale and accumulate the previous iteration's result - constexpr auto c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale - .GetVectorTypeReference(Number{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + .template AsType()(Number{}) * + type_convert(c_scale_thread_buf[m0]); }); }); }); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp index 36d0314518..65b87932db 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp @@ -231,24 +231,23 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 } }; - // constexpr index_t minimum_occupancy = [&]() { - // // if constexpr(is_same_v && - // // is_same_v) - // // { - // // // FIXME: many instances have many spills with occupancy > 1, a better - // solution - // // // needed to get best performance - // // return 1; - // // } - // // else - // { - // return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave && - // MPerBlock * NPerBlock / BlockSize > 64) - // ? 1 - // : 2; - // } - // }(); - constexpr index_t minimum_occupancy = 2; + constexpr index_t minimum_occupancy = [&]() { + if constexpr(is_same_v && + is_same_v) + { + // FIXME: many instances have many spills with occupancy > 1, a better solution + // needed to get best performance + return 1; + } + else + { + return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave && + MPerBlock * NPerBlock / BlockSize > 64) + ? 1 + : 2; + } + }(); + // constexpr index_t minimum_occupancy = 2; if(has_main_k_block_loop) {