use scale (#363)

This commit is contained in:
Chao Liu
2022-08-17 10:38:00 -05:00
committed by GitHub
parent c961ce9226
commit bac7df8faf
5 changed files with 18 additions and 10 deletions

View File

@@ -561,11 +561,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatAB,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
decltype(acc_element_op),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{acc_element_op};
n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
@@ -717,6 +717,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
blockwise_gemm,
acc_thread_buf,
num_k_block_main_loop);
// Acc0 elementwise Op
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;