mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
use scale (#363)
This commit is contained in:
@@ -561,11 +561,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
FloatAB,
|
||||
decltype(acc_thread_desc_k0_m_k1),
|
||||
decltype(a1_thread_desc_k0_m_k1),
|
||||
decltype(acc_element_op),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
n4>{acc_element_op};
|
||||
n4>{tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// B1 matrix blockwise copy
|
||||
auto b1_blockwise_copy =
|
||||
@@ -717,6 +717,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
blockwise_gemm,
|
||||
acc_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// Acc0 elementwise Op
|
||||
static_for<0, acc_thread_buf.Size(), 1>{}(
|
||||
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
|
||||
|
||||
// softmax
|
||||
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
|
||||
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
|
||||
|
||||
Reference in New Issue
Block a user