mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
LWPCK-2429: Device grouped GEMM uses Async Memcpy (#1695)
* LWPCK-2429: Device grouped GEMM uses Async Memcpy Resolving merge conflicts * reverting changes to profile_grouped_gemm * revert date change --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -603,11 +603,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
|
||||
hipGetErrorString(
|
||||
hipMemcpyWithStream(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop,
|
||||
auto has_double_tail_k_block_loop) {
|
||||
|
||||
@@ -761,11 +761,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
float time{0.f};
|
||||
|
||||
hip_check_error(
|
||||
hipMemcpyWithStream(dev_gemm_kargs,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
hipMemcpyAsync(dev_gemm_kargs,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
|
||||
auto preprocess = [&]() {
|
||||
hip_check_error(hipMemsetAsync(
|
||||
|
||||
@@ -940,10 +940,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const void* p_host_kernel_args) const
|
||||
{
|
||||
arg.p_dev_gemm_args_ = p_dev_kernel_args;
|
||||
hip_check_error(hipMemcpy(p_dev_kernel_args,
|
||||
p_host_kernel_args,
|
||||
GetDeviceKernelArgSize(&arg),
|
||||
hipMemcpyHostToDevice));
|
||||
hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
|
||||
p_host_kernel_args,
|
||||
GetDeviceKernelArgSize(&arg),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
|
||||
|
||||
@@ -557,12 +557,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() *
|
||||
sizeof(GemmBiasTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
hipGetErrorString(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
|
||||
@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
}
|
||||
|
||||
hip_check_error(
|
||||
hipMemcpyWithStream(arg.p_workspace_,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user