From 94a88e2ecc2ceb7aada1ca6d5f27acb8329bb71b Mon Sep 17 00:00:00 2001
From: aledudek
Date: Mon, 17 Mar 2025 16:42:43 +0100
Subject: [PATCH] Async grouped gemm v3 (#1940)

* Fully async grouped gemm
* Remove commented code
* Remove maybe_unused
* host kernel args
* Checkpoint segfault debugging...
* Working part1
* Working part2
* Remove comments...
* Use void ptr for gemm kernel host args
* Fix device_grouped_gemm_multiple_d_dl build issue
* Fix device_grouped_gemm_xdl build issue

[ROCm/composable_kernel commit: 50959069750521e03b4a7ca5e9eeb51125338cbb]
---
 .../run_grouped_gemm_example.inc              | 20 +++++-
 .../device_grouped_gemm_multiple_d_dl.hpp     | 70 ++++++++++++++++---
 .../device/impl/device_grouped_gemm_xdl.hpp   | 57 ++++++++++++---
 ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 60 +++++++++++++---
 4 files changed, 179 insertions(+), 28 deletions(-)

diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc
index 64125cd1d0..86b3182a52 100644
--- a/example/15_grouped_gemm/run_grouped_gemm_example.inc
+++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -173,8 +173,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
     std::size_t kargs_size     = gemm.GetDeviceKernelArgSize(&argument);
+    std::size_t hargs_size     = gemm.GetHostKernelArgSize(&argument);
 
     DeviceMem gemm_workspace, gemm_kargs;
+    void* gemm_hargs;
 
     // The following is necessary since TwoStage kernel is using additional memory both
     // for Workspace and kernel arguments.
@@ -188,6 +190,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         gemm_workspace.Realloc(workspace_size);
         gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
     }
+    if(hargs_size > 0)
+    {
+        hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size));
+        gemm.SetHostKernelArgs(&argument, gemm_hargs);
+    }
 
     if(!gemm.IsSupportedArgument(argument))
     {
@@ -196,7 +203,16 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
                  "not support this GEMM problem");
     }
 
-    invoker.Run(argument, StreamConfig{nullptr, false});
+    hipStream_t stream0 = nullptr;
+    hip_check_error(hipStreamCreate(&stream0));
+
+    hipEvent_t event0 = nullptr;
+    hip_check_error(hipEventCreate(&event0));
+
+    invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
+
+    hip_check_error(hipEventSynchronize(event0));
+    hip_check_error(hipStreamSynchronize(stream0));
 
     bool pass = true;
     if(config.do_verification)
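Note on the example change above: the example now sizes a host-side kernel-argument buffer via GetHostKernelArgSize, allocates it as pinned memory with hipHostMalloc, registers it through SetHostKernelArgs, and passes a dedicated copy stream plus a completion event into Invoker::Run before synchronizing on both. The following is a minimal, self-contained sketch of that HIP pattern, reduced to plain runtime calls; the buffer names, the fixed 1024-byte size, and the HIP_CHECK macro are illustrative stand-ins, not code from the patch.

#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

#define HIP_CHECK(expr)                                                       \
    do                                                                        \
    {                                                                         \
        hipError_t err_ = (expr);                                             \
        if(err_ != hipSuccess)                                                \
        {                                                                     \
            std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(err_)); \
            std::exit(EXIT_FAILURE);                                          \
        }                                                                     \
    } while(0)

int main()
{
    const std::size_t kargs_size = 1024; // stand-in for gemm.GetHostKernelArgSize(&argument)

    // Pinned host staging buffer; pageable memory would make hipMemcpyAsync
    // fall back to a blocking copy and defeat the point of the patch.
    void* host_args = nullptr;
    HIP_CHECK(hipHostMalloc(&host_args, kargs_size));

    void* dev_args = nullptr;
    HIP_CHECK(hipMalloc(&dev_args, kargs_size));

    // Dedicated copy stream and event, mirroring stream0/event0 above.
    hipStream_t copy_stream = nullptr;
    HIP_CHECK(hipStreamCreate(&copy_stream));
    hipEvent_t copy_done = nullptr;
    HIP_CHECK(hipEventCreate(&copy_done));

    // ... fill host_args with per-group kernel arguments here ...

    // Asynchronous host-to-device upload of the arguments, then mark completion.
    HIP_CHECK(hipMemcpyAsync(
        dev_args, host_args, kargs_size, hipMemcpyHostToDevice, copy_stream));
    HIP_CHECK(hipEventRecord(copy_done, copy_stream));

    // The example waits on both the event and the stream before verification;
    // for a single copy either wait alone would be sufficient.
    HIP_CHECK(hipEventSynchronize(copy_done));
    HIP_CHECK(hipStreamSynchronize(copy_stream));

    HIP_CHECK(hipEventDestroy(copy_done));
    HIP_CHECK(hipStreamDestroy(copy_stream));
    HIP_CHECK(hipFree(dev_args));
    HIP_CHECK(hipHostFree(host_args));
    return 0;
}

Pinned (page-locked) allocation is the load-bearing detail here: HIP only overlaps host-to-device copies with other work when the source is page-locked, which is why the example switched to hipHostMalloc instead of an ordinary heap buffer.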
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
index 959fc890b8..c148d7dbb7 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
@@ -1,6 +1,6 @@
 #pragma once
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -420,7 +420,8 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<...>
         std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
 
         index_t grid_size_;
+        void* gemm_kernel_host_args_;
     };
 
     // Invoker
@@ -545,7 +547,10 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<...>
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        float Run(const Argument& arg,
+                  const StreamConfig& stream_config = StreamConfig{},
+                  hipStream_t cpy_stream = nullptr,
+                  hipEvent_t cpy_event   = nullptr)
@@ ... @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<...>
     size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
     {
         return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmKernelArg);
     }
+
+    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
+    {
+        return GetWorkSpaceSize(p_arg);
+    }
+
+    size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
+
+    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
+    {
+        return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
+    }
+
+    void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
+    {
+        Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
+        if(!pArg_)
+        {
+            throw std::runtime_error("Failed to cast argument pointer!");
+        }
+
+        pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
+        std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
+                  pArg_->gemm_desc_kernel_arg_.end(),
+                  static_cast<GemmKernelArg*>(pArg_->gemm_kernel_host_args_));
+    }
 };
 
 } // namespace device
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
index 8b40eea56c..2a6406aac3 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
@@ -1,6 +1,6 @@
 #pragma once
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -500,6 +500,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<...>
         std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
 
         index_t grid_size_;
+        void* gemm_kernel_host_args_;
     };
 
     // Invoker
@@ -507,7 +508,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<...>
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        float Run(const Argument& arg,
+                  const StreamConfig& stream_config = StreamConfig{},
+                  hipStream_t cpy_stream = nullptr,
+                  hipEvent_t cpy_event   = nullptr)
@@ ... @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<...>
     void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
     {
         return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
     }
+
+    size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
+
+    void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
+    {
+        Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
+        if(!pArg_)
+        {
+            throw std::runtime_error("Failed to cast argument pointer!");
+        }
+
+        pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
+        std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
+                  pArg_->gemm_desc_kernel_arg_.end(),
+                  static_cast<GemmBiasTransKernelArg*>(pArg_->gemm_kernel_host_args_));
+    }
 };
 
 } // namespace device
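Note on the accessors added above: GetHostKernelArgSize deliberately returns GetWorkSpaceSize, i.e. group_count_ times the size of the kernel-argument struct, so the caller's pinned buffer holds exactly one argument struct per GEMM group, and SetHostKernelArgs stages the already-built argument structs into that buffer with std::copy. A toy, self-contained sketch of that staging idiom follows; KernelArg, Argument, host_args_ and the malloc'd buffer are illustrative stand-ins for the CK types above, not the library's API.

#include <algorithm>
#include <cstdlib>
#include <stdexcept>
#include <vector>

struct BaseArgument
{
    virtual ~BaseArgument() = default; // polymorphic base so dynamic_cast works
};

struct KernelArg // stand-in for GemmKernelArg / GemmBiasTransKernelArg
{
    int M, N, K;
    const void* p_a;
    const void* p_b;
    void* p_e;
};

struct Argument : BaseArgument
{
    std::vector<KernelArg> kernel_args_; // one entry per GEMM group
    void* host_args_ = nullptr;          // caller-provided staging buffer
};

// Mirrors the patch's pattern: validate the downcast, remember the buffer,
// then lay the per-group argument structs out contiguously so that a later
// hipMemcpyAsync can ship them to the device workspace in one transfer.
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args)
{
    Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
    if(!pArg_)
    {
        throw std::runtime_error("Failed to cast argument pointer!");
    }

    pArg_->host_args_ = p_host_kernel_args;
    std::copy(pArg_->kernel_args_.begin(),
              pArg_->kernel_args_.end(),
              static_cast<KernelArg*>(pArg_->host_args_));
}

int main()
{
    Argument arg;
    arg.kernel_args_.resize(4); // e.g. four GEMM groups, value-initialized

    // With CK this buffer would come from hipHostMalloc(&buf, hargs_size).
    void* buf = std::malloc(arg.kernel_args_.size() * sizeof(KernelArg));
    SetHostKernelArgs(&arg, buf);
    std::free(buf);
    return 0;
}

Because the staged structs are later moved to the device as raw bytes, the kernel-argument type has to stay trivially copyable; for such a type, std::copy into the cast buffer is byte-for-byte equivalent to a memcpy.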
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
index 994c667fbc..03431d7156 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -244,7 +244,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<...>
                  std::vector<void*>& p_Es,
                  std::vector<GemmDesc>& gemm_descs,
                  index_t kbatch)
-            : K_BATCH{kbatch}
+            : K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
         {
             grid_size_   = 0;
             group_count_ = ck::type_convert<index_t>(gemm_descs.size());
@@ -365,13 +365,17 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<...>
         std::vector<GemmTransKernelArg> gemm_kernel_args_;
+        void* gemm_kernel_host_args_;
         index_t grid_size_;
     };
 
     // Invoker
     struct Invoker : public BaseInvoker
     {
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        float Run(const Argument& arg,
+                  const StreamConfig& stream_config = StreamConfig{},
+                  hipStream_t cpy_stream = nullptr,
+                  hipEvent_t cpy_event   = nullptr)
         {
             index_t K0                  = arg.gemm_kernel_args_[0].karg_.K0Padded;
             bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
@@ -419,12 +423,34 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<...>
[...]
@@ ... @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<...>
     void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
     {
         return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
     }
+
+    void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
+    {
+        Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
+        if(!pArg_)
+        {
+            throw std::runtime_error("Failed to cast argument pointer!");
+        }
+
+        pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
+        std::copy(pArg_->gemm_kernel_args_.begin(),
+                  pArg_->gemm_kernel_args_.end(),
+                  static_cast<GemmTransKernelArg*>(pArg_->gemm_kernel_host_args_));
+    }
 };
 
 } // namespace device
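The body of the "@@ -419,12 +423,34 @@" hunk, the step inside Invoker::Run that actually uploads the staged arguments, is elided above. Judging from the new Run(..., cpy_stream, cpy_event) parameters and the hipEventSynchronize/hipStreamSynchronize calls in the example, it plausibly issues an asynchronous host-to-device copy on cpy_stream and records cpy_event afterwards. The sketch below shows that pattern under those assumptions; upload_kernel_args and the stand-in types are hypothetical, and only hipMemcpyAsync, hipEventRecord and hipStreamWaitEvent are real HIP APIs.

#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdlib>

// Stand-ins (hypothetical) for the CK types referenced by the sketch.
struct GemmTransKernelArg { unsigned char bytes[64]; };
struct StreamConfig { hipStream_t stream_id_ = nullptr; };

inline void hip_check_error(hipError_t err)
{
    if(err != hipSuccess) { std::abort(); }
}

// Hypothetical upload step inside Invoker::Run.
void upload_kernel_args(void* p_dev_kernel_args,        // device-side argument buffer
                        const void* p_host_kernel_args, // pinned staging buffer
                        std::size_t group_count,
                        const StreamConfig& stream_config,
                        hipStream_t cpy_stream,
                        hipEvent_t cpy_event)
{
    // One contiguous async H2D transfer of all per-group argument structs,
    // issued on the dedicated copy stream.
    hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
                                   p_host_kernel_args,
                                   group_count * sizeof(GemmTransKernelArg),
                                   hipMemcpyHostToDevice,
                                   cpy_stream));

    // Publish completion: the host can wait via hipEventSynchronize (as the
    // example does) and other streams can order against the copy.
    hip_check_error(hipEventRecord(cpy_event, cpy_stream));

    // Let the compute stream consume the arguments without blocking the host.
    hip_check_error(hipStreamWaitEvent(stream_config.stream_id_, cpy_event, 0));
}

Whether the actual patch chains the streams with hipStreamWaitEvent or relies solely on the host-side synchronization shown in the example cannot be confirmed from the visible hunks; both orderings are safe for the example's usage.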