From 03b4c5322d52454bd6a408dc512f78aa10e624bc Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Thu, 3 Apr 2025 11:48:54 -0700 Subject: [PATCH] Add the MI355 support for CK TILE GEMM (#2046) * Get the root cause of the ck tile gemm failing on mi355 * Fix the ck tile gemm on MI355 * delete the debug info [ROCm/composable_kernel commit: 50d1f8ff905eeabc61123864d9a805d215676a53] --- example/ck_tile/03_gemm/CMakeLists.txt | 9 ++++++--- example/ck_tile/03_gemm/run_gemm_example.inc | 8 ++++---- test/ck_tile/gemm/CMakeLists.txt | 20 +++++++++++++++++++ .../gemm/test_gemm_pipeline_compv3.cpp | 2 +- .../gemm/test_gemm_pipeline_compv4.cpp | 2 +- 5 files changed, 32 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 30cfee22f6..61c3a57391 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,5 +1,8 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) -target_compile_options(tile_example_gemm_universal PRIVATE - -mllvm -enable-noalias-to-md-conversion=0 -) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 6cb40e45d1..c3b4ec609c 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -240,8 +240,8 @@ int run_gemm_example_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); } else if(init_method == 1) { @@ -250,8 +250,8 @@ int run_gemm_example_with_layouts(int argc, } else if(init_method == 2) { - ck_tile::FillConstant{static_cast(1)}(a_m_k); - ck_tile::FillConstant{static_cast(1)}(b_k_n); + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); } else { diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 7701e451ad..3e7296b1eb 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -1,8 +1,28 @@ # Currently ck_tile is only built on gfx94/gfx95 +set(EXAMPLE_GEMM_COMPILE_OPTIONS "") +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS "") +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS + -mllvm + -enable-noalias-to-md-conversion=0 +) + +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) + + target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) else() message("Skipping ck_tile_gemm tests for current target") endif() +endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp index d81e870ffc..8944e6865d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp @@ -9,7 +9,7 @@ class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline #define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3 -TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesMem); +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesCompV3); #include "test_gemm_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp index 1da0028f63..22e77fac41 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp @@ -9,7 +9,7 @@ class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline #define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4 -TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesMem); +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesCompV4); #include "test_gemm_pipeline_ut_cases.inc"