From 6002d3277429ada442cda775337728e89950d45a Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 21 Sep 2022 10:15:43 -0500 Subject: [PATCH] fixed G offset calc for long_index (#428) [ROCm/composable_kernel commit: 01876afafe1c09028dc4d513b5d040cec798fae6] --- ...ce_batched_contraction_multiple_d_xdl_cshuffle.hpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 9152e8d85a..bb3c09b427 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -506,12 +506,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const { - return g_idx * static_cast<long_index_t>(batch_stride_A_); + return static_cast<long_index_t>(g_idx) * batch_stride_A_; } __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const { - return g_idx * static_cast<long_index_t>(batch_stride_B_); + return static_cast<long_index_t>(g_idx) * batch_stride_B_; } __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const @@ -519,8 +519,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle std::array<long_index_t, NumDTensor> ds_offset; static_for<0, NumDTensor, 1>{}([&](auto i) { - ds_offset[i] = - ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0)); + ds_offset[i] = static_cast<long_index_t>(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); }); return ds_offset; @@ -528,7 +528,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const { - return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + return static_cast<long_index_t>(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 
0)); } private: