diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 619c82df34..0574f98e87 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -2,6 +2,14 @@ add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) +add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) +target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) +target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance) + +add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) +target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) +target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance) + add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) target_link_libraries(test_batched_gemm_int8 PRIVATE utility) target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance) diff --git a/test/batched_gemm/batched_gemm_bf16.cpp b/test/batched_gemm/batched_gemm_bf16.cpp new file mode 100644 index 0000000000..698e9faada --- /dev/null +++ b/test/batched_gemm/batched_gemm_bf16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "profiler/include/profile_batched_gemm_impl.hpp" + +namespace { +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +} // namespace + +int main() +{ + int M = 256; + int N = 256; + int K = 128; + int BatchCount = 3; + + bool pass = true; + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + + std::cout << "test BatchedGEMM bf16: " << (pass ? "Pass" : "Fail") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/batched_gemm/batched_gemm_fp32.cpp b/test/batched_gemm/batched_gemm_fp32.cpp new file mode 100644 index 0000000000..59072acc50 --- /dev/null +++ b/test/batched_gemm/batched_gemm_fp32.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "profiler/include/profile_batched_gemm_impl.hpp" + +namespace { +using ADataType = float; +using BDataType = float; +using CDataType = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +} // namespace + +int main() +{ + int M = 256; + int N = 256; + int K = 128; + int BatchCount = 3; + + bool pass = true; + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + + std::cout << "test BatchedGEMM fp32: " << (pass ? "Pass" : "Fail") << std::endl; + return pass ? 0 : 1; +}