diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 54760b0556..2c42ac9345 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -945,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } }(); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) { return transform_tensor_descriptor(c_grid_desc_m_n, make_tuple(make_right_pad_transform(M, MPad - M), diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index d169c135ca..9482821b68 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance) - endif() - add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance) - endif() - add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance) - endif() - add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance) - endif() + add_gtest_executable(test_batched_gemm test_batched_gemm.cpp) + target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance) set(target 1) endif() endforeach() \ No newline at end of file diff --git a/test/batched_gemm/batched_gemm_bf16.cpp b/test/batched_gemm/batched_gemm_bf16.cpp deleted file mode 100644 index 5d12a1e956..0000000000 --- a/test/batched_gemm/batched_gemm_bf16.cpp +++ /dev/null @@ -1,114 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "profiler/profile_batched_gemm_impl.hpp" - -#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" - -namespace { -using ADataType = ck::bhalf_t; -using BDataType = ck::bhalf_t; -using CDataType = ck::bhalf_t; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -} // namespace - -int main() -{ - int M = 256; - int N = 256; - int K = 128; - int BatchCount = 3; - - bool pass = true; - - using namespace ck::tensor_operation::device; - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); - - std::cout << "test BatchedGEMM bf16: " << (pass ? "Pass" : "Fail") << std::endl; - return pass ? 0 : 1; -} diff --git a/test/batched_gemm/batched_gemm_fp16.cpp b/test/batched_gemm/batched_gemm_fp16.cpp deleted file mode 100644 index a2b61d951a..0000000000 --- a/test/batched_gemm/batched_gemm_fp16.cpp +++ /dev/null @@ -1,114 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "profiler/profile_batched_gemm_impl.hpp" - -#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" - -namespace { -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -} // namespace - -int main() -{ - int M = 512; - int N = 256; - int K = 128; - int BatchCount = 3; - - bool pass = true; - - using namespace ck::tensor_operation::device; - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); - - std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl; - return pass ? 0 : 1; -} diff --git a/test/batched_gemm/batched_gemm_fp32.cpp b/test/batched_gemm/batched_gemm_fp32.cpp deleted file mode 100644 index 2b18d166e6..0000000000 --- a/test/batched_gemm/batched_gemm_fp32.cpp +++ /dev/null @@ -1,114 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "profiler/profile_batched_gemm_impl.hpp" - -#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" - -namespace { -using ADataType = float; -using BDataType = float; -using CDataType = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -} // namespace - -int main() -{ - int M = 256; - int N = 256; - int K = 128; - int BatchCount = 3; - - bool pass = true; - - using namespace ck::tensor_operation::device; - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); - - std::cout << "test BatchedGEMM fp32: " << (pass ? "Pass" : "Fail") << std::endl; - return pass ? 0 : 1; -} diff --git a/test/batched_gemm/batched_gemm_int8.cpp b/test/batched_gemm/batched_gemm_int8.cpp deleted file mode 100644 index f607eaa84b..0000000000 --- a/test/batched_gemm/batched_gemm_int8.cpp +++ /dev/null @@ -1,114 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "profiler/profile_batched_gemm_impl.hpp" - -#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" - -namespace { -using ADataType = int8_t; -using BDataType = int8_t; -using CDataType = int8_t; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -} // namespace - -int main() -{ - int M = 256; - int N = 256; - int K = 128; - int BatchCount = 3; - - bool pass = true; - - using namespace ck::tensor_operation::device; - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); - - std::cout << "test BatchedGEMM int8: " << (pass ? "Pass" : "Fail") << std::endl; - return pass ? 0 : 1; -} diff --git a/test/batched_gemm/test_batched_gemm.cpp b/test/batched_gemm/test_batched_gemm.cpp new file mode 100644 index 0000000000..f9bb626ce5 --- /dev/null +++ b/test/batched_gemm/test_batched_gemm.cpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include + +#include "profiler/profile_batched_gemm_impl.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + +struct GemmParams +{ + ck::index_t M; + ck::index_t N; + ck::index_t K; + ck::index_t BatchCount; +}; + +class TestBatchedGemm : public ::testing::Test +{ + protected: + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + std::vector params; + + template + void Run() + { + using namespace ck::tensor_operation::device; + + bool pass = true; + for(auto& param : params) + { + const auto M = param.M; + const auto N = param.N; + const auto K = param.K; + const auto BatchCount = param.BatchCount; + + pass = + pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + + pass = + pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + + pass = + pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + + pass = + pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + } + EXPECT_TRUE(pass); + } +}; + +#ifdef CK_ENABLE_INT8 +TEST_F(TestBatchedGemm, i8) +{ + this->params.push_back({64, 64, 64, 2}); + this->params.push_back({64, 64, 64, 1}); + this->params.push_back({60, 60, 60, 2}); + this->params.push_back({68, 68, 68, 2}); + this->params.push_back({40, 40, 40, 2}); + this->params.push_back({256, 256, 128, 3}); + this->template Run(); +} +#endif + +#ifdef CK_ENABLE_BF16 +TEST_F(TestBatchedGemm, bf16) +{ + this->params.push_back({64, 64, 64, 2}); + this->params.push_back({64, 64, 64, 1}); + this->params.push_back({60, 60, 60, 2}); + this->params.push_back({68, 68, 68, 2}); + this->params.push_back({40, 40, 40, 2}); + this->params.push_back({256, 256, 128, 3}); + this->template Run(); +} +#endif + +#ifdef CK_ENABLE_FP16 +TEST_F(TestBatchedGemm, fp16) +{ + this->params.push_back({64, 64, 64, 2}); + this->params.push_back({64, 64, 64, 1}); + this->params.push_back({60, 60, 60, 2}); + this->params.push_back({68, 68, 68, 2}); + this->params.push_back({40, 40, 40, 2}); + this->params.push_back({256, 256, 128, 3}); + this->template Run(); +} +#endif + +#ifdef CK_ENABLE_FP32 +TEST_F(TestBatchedGemm, fp32) +{ + this->params.push_back({64, 64, 64, 2}); + this->params.push_back({64, 64, 64, 1}); + this->params.push_back({60, 60, 60, 2}); + this->params.push_back({68, 68, 68, 2}); + this->params.push_back({40, 40, 40, 2}); + this->params.push_back({256, 256, 128, 3}); + this->template Run(); +} +#endif