mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Use asynchronous version of hipMemset (#850)
This commit is contained in:
@@ -532,11 +532,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
hipGetErrorString(hipMemset(
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
|
||||
@@ -158,8 +158,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(kbatch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
|
||||
hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
|
||||
0,
|
||||
karg.M * karg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -147,7 +147,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
|
||||
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
|
||||
hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
|
||||
0,
|
||||
karg.M * karg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
grid_dims,
|
||||
|
||||
@@ -421,8 +421,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
for(const auto& trans_arg : arg.gemm_kernel_args_)
|
||||
{
|
||||
const auto& karg = trans_arg.karg_;
|
||||
hip_check_error(
|
||||
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(EDataType)));
|
||||
hip_check_error(hipMemsetAsync(karg.p_c_grid,
|
||||
0,
|
||||
karg.M * karg.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -886,11 +886,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
|
||||
typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
hipGetErrorString(hipMemset(
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
arg.p_e_grid_,
|
||||
0,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(EDataType)));
|
||||
sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
|
||||
Reference in New Issue
Block a user