From 7d995ced079fdb3b6e35532d6d557bd488ae35a0 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Thu, 30 Nov 2023 15:09:27 -0600 Subject: [PATCH] Fixed GroupedGemmFixedNK with hipGraph (#1065) * fixed examples; add async_mem_set * add stream to all deviceOp using SetWorkspace --------- Co-authored-by: Jing Zhang [ROCm/composable_kernel commit: 49df1dc595734d20ecdf9dfe11933e527fea84f1] --- example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp | 4 ++-- example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp | 4 ++-- include/ck/tensor_operation/gpu/device/device_base.hpp | 4 +++- .../gpu/device/impl/device_batchnorm_backward_impl.hpp | 4 +++- .../gpu/device/impl/device_batchnorm_forward_impl.hpp | 4 +++- .../device/impl/device_batchnorm_forward_impl_obsolete.hpp | 4 +++- .../impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 4 +++- .../gpu/device/impl/device_gemm_xdl_streamk.hpp | 4 +++- .../gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp | 7 +++++-- .../device/impl/device_normalization_fwd_splitk_impl.hpp | 4 +++- 10 files changed, 30 insertions(+), 13 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 95b8526094..2c1feafce3 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -299,8 +299,8 @@ int main(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); - problem_size.Ns.push_back(128 + 128 * i); - problem_size.Ks.push_back(128 + 64 * i); + problem_size.Ns.push_back(256); + problem_size.Ks.push_back(128); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp index 84abe1d1db..9fd63cba77 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp @@ -300,8 +300,8 @@ int main(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); - problem_size.Ns.push_back(128 + 128 * i); - problem_size.Ks.push_back(128 + 64 * i); + problem_size.Ns.push_back(256); + problem_size.Ks.push_back(128); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 1981690111..908ada016d 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -59,7 +59,9 @@ struct BaseOperator virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } - virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const + virtual void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const { assert(p_arg); p_arg->p_workspace_ = p_workspace; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp index f46237e005..3b62cf10a3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp @@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp index ad8e795603..e7e4668d92 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp @@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp index b826793c27..c3e0837722 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp @@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp index b0efa9d4e4..f7319226a9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle return (workspace_size); }; - void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override { Argument* pArg_ = dynamic_cast(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp index 8de42ba9ef..c8799e5154 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp @@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 56132f7a0f..0a0cb59063 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK); } - void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override { auto p_arg_ = dynamic_cast(p_arg); p_arg_->p_workspace_ = p_workspace; - hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg))); + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); } static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp index 58db34c9f2..6a117920f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp @@ -577,7 +577,9 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd(pArg);