From 146a41f0d8670f9c482bbe7f46f973648577438d Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Thu, 28 Aug 2025 10:47:16 +0200 Subject: [PATCH] Ck tile gemm low prec data types int4 int8 unit tests (#2718) * add gemm unit tests for int4, int8 datatypes * minor changes based on reviews --------- Co-authored-by: msaffari-amd [ROCm/composable_kernel commit: b951416cdb8dd394a511595bbe241d7cd09ae7cc] --- test/ck_tile/gemm/CMakeLists.txt | 6 ++++++ .../gemm/test_gemm_pipeline_universal_int8.cpp | 16 ++++++++++++++++ .../test_gemm_pipeline_universal_pk_int4.cpp | 16 ++++++++++++++++ 3 files changed, 38 insertions(+) create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_int8.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_pk_int4.cpp diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index a982e30a4c..5d34943e0d 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -30,6 +30,12 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") # On Radeon devices, build the WMMA version instead add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_int8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_int8.cpp new file mode 100644 index 0000000000..e8a089d8ff --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_int8.cpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "test_gemm_pipeline_smoke_util.hpp" +#include "test_gemm_pipeline_smoke_run_test.inc" +#include "test_gemm_pipeline_universal_run_test.inc" + +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_pk_int4.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_pk_int4.cpp new file mode 100644 index 0000000000..043db10fb0 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_pk_int4.cpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "test_gemm_pipeline_smoke_util.hpp" +#include "test_gemm_pipeline_smoke_run_test.inc" +#include "test_gemm_pipeline_universal_run_test.inc" + +int main() { return run_gemm_combinations(); }