From 19718224476790858dc7fefcc11e4597885b0563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 2 Feb 2024 20:25:54 +0100 Subject: [PATCH] Extend gemm traits number for ck wrapper (#1153) [ROCm/composable_kernel commit: 171ca260b506b32e53c899bdc580accb3469937c] --- .../traits/blockwise_gemm_xdl_traits.hpp | 21 +++++++++++++++++++ test/wrapper/test_gemm.cpp | 8 +++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp index 24d863f5b1..8301636a9f 100644 --- a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp +++ b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp @@ -34,6 +34,7 @@ struct BlockwisGemmXdlTraits static constexpr index_t K1 = K1Value; }; +// K1 = 4 struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4> { }; @@ -43,6 +44,26 @@ struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits< struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4> { }; +// K1 = 8 +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8> +{ +}; +// K1 = 16 +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16> +{ +}; } // namespace wrapper } // namespace ck diff --git a/test/wrapper/test_gemm.cpp b/test/wrapper/test_gemm.cpp index b26cd5fed6..12245490d1 100644 --- a/test/wrapper/test_gemm.cpp +++ b/test/wrapper/test_gemm.cpp @@ -225,10 +225,10 @@ TEST(TestGemm, Int8) using DataType = int8_t; const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}); const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( + PerformGemm( 512, 512, 128, tile_shape, thread_layout); // Irregular case - PerformGemm( + PerformGemm( 129, 129, 67, tile_shape, thread_layout); } @@ -237,10 +237,10 @@ TEST(TestGemm, Half) using DataType = ck::half_t; const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{}); const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( + PerformGemm( 512, 512, 128, tile_shape, thread_layout); // Irregular case - PerformGemm( + PerformGemm( 129, 129, 67, tile_shape, thread_layout); }