mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Fix bug where scaling may not be applied in some code path (#526)
* fix bug where scaling may not be applied in some code path
* more test
* revert accidental example code changes
[ROCm/composable_kernel commit: d156709432]
This commit is contained in:
@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, acc_thread_buf.Size(), 1>{}(
|
||||
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
|
||||
}
|
||||
|
||||
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
|
||||
int BatchStrideB0 = -1,
|
||||
int BatchStrideB1 = -1,
|
||||
int BatchStrideC = -1,
|
||||
float alpha = 1.f)
|
||||
float alpha = -1.f)
|
||||
|
||||
{
|
||||
|
||||
@@ -187,6 +187,10 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
|
||||
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
|
||||
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
|
||||
|
||||
if(alpha < 0)
|
||||
{
|
||||
alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim)
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{alpha};
|
||||
|
||||
@@ -45,7 +45,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
|
||||
int O,
|
||||
int G0,
|
||||
int G1,
|
||||
float alpha = 1.f)
|
||||
float alpha = -1.f)
|
||||
|
||||
{
|
||||
|
||||
@@ -154,6 +154,10 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
|
||||
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
|
||||
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
|
||||
|
||||
if(alpha < 0)
|
||||
{
|
||||
alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim)
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{alpha};
|
||||
|
||||
Reference in New Issue
Block a user