From 96c07fc27d178cb01165420df10f2cc7b63cb61b Mon Sep 17 00:00:00 2001 From: Anthony Chang Date: Sat, 3 Dec 2022 01:43:34 +0800 Subject: [PATCH] Fix bug where scaling may not be applied in some code path (#526) * fix bug where scaling may not be applied in some code path * more test * revert accidental example code changes [ROCm/composable_kernel commit: d156709432b363a24e19dd33af632c3e328fdac5] --- .../gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 5 +++++ .../profiler/profile_batched_gemm_softmax_gemm_impl.hpp | 6 +++++- .../profile_batched_gemm_softmax_gemm_permute_impl.hpp | 6 +++++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index c8bc33afa3..fec360b7fa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle } }); } + else + { + static_for<0, acc_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); }); + } block_sync_lds(); // wait for lds read in gemm0 blockwise gemm diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp index fe76fcaf0b..f5ec235141 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp @@ -49,7 +49,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, int BatchStrideB0 = -1, int BatchStrideB1 = -1, int BatchStrideC = -1, - float alpha = 1.f) + float alpha = -1.f) { @@ -187,6 +187,10 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + if(alpha < 0) + { + alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim) + } auto a_element_op = AElementOp{}; auto b0_element_op = B0ElementOp{}; auto acc0_element_op = Acc0ElementOp{alpha}; diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp index 8012d6ea0a..91c28f25fc 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp @@ -45,7 +45,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, int O, int G0, int G1, - float alpha = 1.f) + float alpha = -1.f) { @@ -154,6 +154,10 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + if(alpha < 0) + { + alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim) + } auto a_element_op = AElementOp{}; auto b0_element_op = B0ElementOp{}; auto acc0_element_op = Acc0ElementOp{alpha};