From ffac25fffa40187ef52531dd0f5fa73d6a1ff9f2 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 3 May 2023 08:25:25 -0700 Subject: [PATCH] Fix grouped_gemm_splitk kernels on MI300. (#694) * replace amd_buffer_atomic_add with hip_atomic_add * fix grouped_gemm_splitk kernels on mi300 * fix syntax * revert experimental atomic_add changes --------- Co-authored-by: Jing Zhang [ROCm/composable_kernel commit: 4a51d2da9de3524e96f2abed5f843ff9da535db3] --- example/15_grouped_gemm/run_grouped_gemm_example.inc | 2 +- .../device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index bceff29b63..320870e0de 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -147,7 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co #else a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); - c_tensors_device[i]->SetZero(); + c_tensors_device[i]->SetZero(); #endif p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 26a4319eaa..467a8429ab 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -34,7 +34,8 @@ __global__ void kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size];