diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt new file mode 100644 index 0000000000..2ff590b9d2 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) diff --git a/example/32_batched_gemm_softmax_gemm/batched_gemm_softmax_gemm_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp similarity index 98% rename from example/32_batched_gemm_softmax_gemm/batched_gemm_softmax_gemm_xdl_fp16.cpp rename to example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp index 18b0ea79a6..b3530d7aaf 100644 --- a/example/32_batched_gemm_softmax_gemm/batched_gemm_softmax_gemm_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp @@ -51,7 +51,7 @@ using CLayout = Row; using AElementOp = PassThrough; using B0ElementOp = PassThrough; -using Acc0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; using B1ElementOp = PassThrough; using CElementOp = PassThrough; @@ -122,7 +122,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm< AccDataType, AElementOp, B0ElementOp, - CElementOp>; + Acc0ElementOp>; // Ref Softmax: fp32 in, fp16 out using ReferenceSoftmaxInstance = @@ -157,6 +157,7 @@ int main(int argc, char* argv[]) ck::index_t BatchStrideB0 = -1; ck::index_t BatchStrideB1 = -1; ck::index_t BatchStrideC = -1; + float alpha = 1; if(argc == 1) { @@ -181,7 +182,7 @@ int main(int argc, char* argv[]) BatchCount = std::stoi(argv[8]); } - else if(argc == 17) + else if(argc == 18) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -203,6 +204,8 @@ int main(int argc, char* argv[]) BatchStrideB0 = std::stoi(argv[14]); BatchStrideB1 = 
std::stoi(argv[15]); BatchStrideC = std::stoi(argv[16]); + + alpha = std::stof(argv[17]); } else { @@ -211,6 +214,7 @@ printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, " "BatchStrideB0, BatchStrideB1, BatchStrideC\n"); + printf("arg17: alpha\n"); exit(0); } @@ -304,7 +308,7 @@ int main(int argc, char* argv[]) auto a_element_op = AElementOp{}; auto b0_element_op = B0ElementOp{}; - auto acc0_element_op = Acc0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; auto b1_element_op = B1ElementOp{}; auto c_element_op = CElementOp{}; @@ -368,7 +372,7 @@ int main(int argc, char* argv[]) auto ref_gemm0 = ReferenceGemm0Instance{}; auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); auto ref_gemm0_argument = ref_gemm0.MakeArgument( - a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{}); + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); ref_gemm0_invoker.Run(ref_gemm0_argument); diff --git a/example/32_batched_gemm_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_softmax_gemm/CMakeLists.txt deleted file mode 100644 index ca4fb026cb..0000000000 --- a/example/32_batched_gemm_softmax_gemm/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -# TODO: add example batched_gemm_gemm_xdl_fp16 -add_example_executable(example_batched_gemm_softmax_gemm_xdl_fp16 batched_gemm_softmax_gemm_xdl_fp16.cpp) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 57cacecd26..1845d46c05 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -46,7 +46,7 @@ add_subdirectory(28_grouped_gemm_bias_e_permute) add_subdirectory(29_batched_gemm_bias_e_permute) add_subdirectory(30_grouped_convnd_fwd_bias_relu_add) add_subdirectory(31_batched_gemm_gemm) -add_subdirectory(32_batched_gemm_softmax_gemm) +add_subdirectory(32_batched_gemm_scale_softmax_gemm) add_subdirectory(33_multiple_reduce) 
add_subdirectory(34_batchnorm) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index db6f7cbb50..098056044a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -561,11 +561,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle FloatAB, decltype(acc_thread_desc_k0_m_k1), decltype(a1_thread_desc_k0_m_k1), - decltype(acc_element_op), + tensor_operation::element_wise::PassThrough, Sequence, Sequence<1, 0, 2>, 2, - n4>{acc_element_op}; + n4>{tensor_operation::element_wise::PassThrough{}}; // B1 matrix blockwise copy auto b1_blockwise_copy = @@ -717,6 +717,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle blockwise_gemm, acc_thread_buf, num_k_block_main_loop); + + // Acc0 elementwise Op + static_for<0, acc_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); }); + // softmax SoftmaxBuf& max = blockwise_softmax.max_value_buf; SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;