From d4e81a76ee5719a53d7f30b3a7bbc6b5c429babf Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Wed, 12 Apr 2023 04:42:47 +0800 Subject: [PATCH] Add memory index guard in wmma device ops (#667) [ROCm/composable_kernel commit: e85178b4ca892a78344271ae64103c9d4d1bfc40] --- .../gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 9 +++++++++ .../ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 38edace197..d3f81566e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -505,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + return true; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 1fee302c3c..2694aaf6f0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -264,6 +264,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + return false; + } return true; }