Mirror of
https://github.com/ROCm/composable_kernel.git
Synced 2026-04-19 14:29:05 +00:00
Commit: use scale (#363)
This commit is contained in:
@@ -0,0 +1 @@
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
|
||||
@@ -51,7 +51,7 @@ using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
@@ -122,7 +122,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
CElementOp>;
|
||||
Acc0ElementOp>;
|
||||
|
||||
// Ref Softmax: fp32 in, fp16 out
|
||||
using ReferenceSoftmaxInstance =
|
||||
@@ -157,6 +157,7 @@ int main(int argc, char* argv[])
|
||||
ck::index_t BatchStrideB0 = -1;
|
||||
ck::index_t BatchStrideB1 = -1;
|
||||
ck::index_t BatchStrideC = -1;
|
||||
float alpha = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -181,7 +182,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
BatchCount = std::stoi(argv[8]);
|
||||
}
|
||||
else if(argc == 17)
|
||||
else if(argc == 18)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
@@ -203,6 +204,8 @@ int main(int argc, char* argv[])
|
||||
BatchStrideB0 = std::stoi(argv[14]);
|
||||
BatchStrideB1 = std::stoi(argv[15]);
|
||||
BatchStrideC = std::stoi(argv[16]);
|
||||
|
||||
alpha = std::stof(argv[17]);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -211,6 +214,7 @@ int main(int argc, char* argv[])
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
|
||||
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
|
||||
printf("arg18: alpha\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
@@ -304,7 +308,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{alpha};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
@@ -368,7 +372,7 @@ int main(int argc, char* argv[])
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{});
|
||||
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# TODO: add example batched_gemm_gemm_xdl_fp16
|
||||
add_example_executable(example_batched_gemm_softmax_gemm_xdl_fp16 batched_gemm_softmax_gemm_xdl_fp16.cpp)
|
||||
@@ -46,7 +46,7 @@ add_subdirectory(28_grouped_gemm_bias_e_permute)
|
||||
add_subdirectory(29_batched_gemm_bias_e_permute)
|
||||
add_subdirectory(30_grouped_convnd_fwd_bias_relu_add)
|
||||
add_subdirectory(31_batched_gemm_gemm)
|
||||
add_subdirectory(32_batched_gemm_softmax_gemm)
|
||||
add_subdirectory(32_batched_gemm_scale_softmax_gemm)
|
||||
add_subdirectory(33_multiple_reduce)
|
||||
add_subdirectory(34_batchnorm)
|
||||
|
||||
|
||||
@@ -561,11 +561,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
FloatAB,
|
||||
decltype(acc_thread_desc_k0_m_k1),
|
||||
decltype(a1_thread_desc_k0_m_k1),
|
||||
decltype(acc_element_op),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
n4>{acc_element_op};
|
||||
n4>{tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// B1 matrix blockwise copy
|
||||
auto b1_blockwise_copy =
|
||||
@@ -717,6 +717,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
blockwise_gemm,
|
||||
acc_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// Acc0 elementwise Op
|
||||
static_for<0, acc_thread_buf.Size(), 1>{}(
|
||||
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
|
||||
|
||||
// softmax
|
||||
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
|
||||
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
|
||||
|
||||
Reference in New Issue
Block a user