From 6b3a060294b3b9ae912b58baa3d807ff69238fde Mon Sep 17 00:00:00 2001 From: guangzlu <87220526+guangzlu@users.noreply.github.com> Date: Fri, 1 Jul 2022 14:38:21 +0800 Subject: [PATCH] modified grouped gemm addressing method (#307) * modified grouped gemm addressing method * modified addressing method in device_grouped_gemm_xdl.hpp Co-authored-by: root Co-authored-by: Chao Liu [ROCm/composable_kernel commit: 8e374781d525393288b6bb9d8f6da0793fdb9902] --- .../gpu/device/device_grouped_gemm_xdl.hpp | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp index 8047cba885..999792807b 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -46,13 +46,22 @@ __global__ void const auto gemm_desc_ptr = reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - index_t group_id = 0; - for(index_t i = 0; i < group_count; i++) + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && + block_id < gemm_desc_ptr[group_id].BlockEnd_)) && + left <= right) { - group_id = - (block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_) - ? i - : group_id; + if(block_id < gemm_desc_ptr[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); } GridwiseGemm::template Run(