From abb90422b4d3b0e2d75770a476a3ca591b887380 Mon Sep 17 00:00:00 2001 From: Tianyuan Wu Date: Sat, 16 Aug 2025 07:22:27 +0800 Subject: [PATCH] [CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466) * WMMA GEMM F16 Implementation Signed-off-by: root * Self-review Signed-off-by: root * ASIC check minor tweak Signed-off-by: root * add missing include file * Set GPU_TARGETS to gfx11/12 generic Signed-off-by: root * INT8 GFX12 Signed-off-by: root * add int8x16 branch * Fix CI script Signed-off-by: root * Fix typo Signed-off-by: root * Add CK_Tile WMMA example Signed-off-by: Tianyuan Wu * Fix CI Signed-off-by: Tianyuan Wu * fix clang format * Set M/N_Warp Back to Constant Signed-off-by: Tianyuan Wu * Use GemmConfigComputeV3 by default Signed-off-by: TianyuanWu * Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12 Signed-off-by: TianyuanWu * Remove CK_Tile wmma gemm examples from the CI list Signed-off-by: TianyuanWu * Add atomic add fallback method for gfx11 Signed-off-by: TianyuanWu * Fix typo Signed-off-by: TianyuanWu * Omit copyright year Signed-off-by: TianyuanWu * Support non-square cases Signed-off-by: TianyuanWu * Fix CI Signed-off-by: TianyuanWu * Add get_device_ip() Signed-off-by: TianyuanWu * Revert "Add atomic add fallback method for gfx11" This reverts commit 4f664969c01b37976c8518c19833d9f1574cd746. Signed-off-by: Tianyuan Wu * Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12" This reverts commit 949129a3858a825b2a2c4d3ec01663df18a165a5. * Revise method name and typos Signed-off-by: Tianyuan Wu * clang-format Signed-off-by: TianyuanWu * Try fix CI Signed-off-by: TianyuanWu * Revert "Try fix CI" This reverts commit 084c683227e64ab6a8137db00c8165fb05bdc902. 
* clang-format Signed-off-by: TianyuanWu * Fix typo caused by merge Signed-off-by: Tianyuan Wu * Fix typo caused by merging Signed-off-by: Tianyuan Wu --------- Signed-off-by: root Signed-off-by: Tianyuan Wu Signed-off-by: TianyuanWu Signed-off-by: Tianyuan Wu Co-authored-by: joye Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng [ROCm/composable_kernel commit: 68134b60e45612b54f6c3165e39078676b41928d] --- CMakeLists.txt | 7 + Jenkinsfile | 4 +- .../gemm_bilinear_wmma_fp16.cpp | 2 +- .../gemm_bilinear_wmma_int8.cpp | 2 +- .../gemm_bilinear_xdl_fp16.cpp | 2 +- .../gemm_multi_ABD_xdl_fp16.cpp | 2 +- .../contraction_multi_ABD_xdl_fp16.cpp | 2 +- example/ck_tile/03_gemm/gemm_utils.hpp | 21 ++ example/ck_tile/03_gemm/universal_gemm.cpp | 1 - .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 0 include/ck_tile/core/arch/arch.hpp | 19 +- .../core/arch/generic_memory_space_atomic.hpp | 58 ++++++ include/ck_tile/core/config.hpp | 8 +- include/ck_tile/host/device_prop.hpp | 13 ++ .../ops/epilogue/cshuffle_epilogue.hpp | 14 +- .../ops/epilogue/default_2d_epilogue.hpp | 14 +- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 14 +- ...block_fmha_bwd_pipeline_default_policy.hpp | 67 ++++--- ...mha_bwd_pipeline_trload_default_policy.hpp | 6 +- ..._pipeline_qr_ks_vs_async_trload_policy.hpp | 49 +++-- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 17 +- .../fused_moegemm_pipeline_flatmm_policy.hpp | 8 +- include/ck_tile/ops/gemm.hpp | 6 + ...emm_asmem_bsmem_creg_v1_default_policy.hpp | 40 ++-- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 4 +- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 2 +- ...peline_ag_bg_cr_comp_v4_default_policy.hpp | 21 +- ...peline_ag_bg_cr_comp_v5_default_policy.hpp | 19 +- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 21 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 21 +- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 14 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 134 ++++++------- 
.../gemm/warp/warp_gemm_attribute_mfma.hpp | 18 +- .../gemm/warp/warp_gemm_attribute_wmma.hpp | 147 ++++++++++++++ .../warp/warp_gemm_attribute_wmma_impl.hpp | 132 +++++++++++++ ..._gemm_attribute_wmma_impl_16bit_traits.hpp | 87 ++++++++ ...p_gemm_attribute_wmma_impl_8bit_traits.hpp | 138 +++++++++++++ ...p_gemm_attribute_wmma_impl_base_traits.hpp | 86 ++++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 185 ++++++++++-------- .../ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp | 37 ++++ .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 28 +-- test/ck_tile/gemm/CMakeLists.txt | 11 ++ .../gemm/test_gemm_pipeline_compv3.cpp | 3 +- .../gemm/test_gemm_pipeline_compv3_wmma.cpp | 17 ++ .../gemm/test_gemm_pipeline_compv4.cpp | 3 +- .../gemm/test_gemm_pipeline_compv4_wmma.cpp | 17 ++ .../gemm/test_gemm_pipeline_kernel_types.hpp | 144 ++++++++++---- test/ck_tile/gemm/test_gemm_pipeline_mem.cpp | 2 +- .../gemm/test_gemm_pipeline_mem_wmma.cpp | 17 ++ .../gemm/test_gemm_pipeline_persistent.cpp | 3 +- .../test_gemm_pipeline_persistent_wmma.cpp | 17 ++ test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 63 ++++-- .../gemm/test_gemm_pipeline_wmma_base.hpp | 24 +++ .../test_gemm_pipeline_ut_cases.inc | 0 54 files changed, 1388 insertions(+), 403 deletions(-) mode change 100755 => 100644 example/ck_tile/17_grouped_gemm/grouped_gemm.cpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_compv3_wmma.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_compv4_wmma.cpp create mode 
100644 test/ck_tile/gemm/test_gemm_pipeline_mem_wmma.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_persistent_wmma.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_wmma_base.hpp mode change 100755 => 100644 test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc diff --git a/CMakeLists.txt b/CMakeLists.txt index 19c036e1a5..07d2e166bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -327,6 +327,7 @@ endif() if(USE_OPT_GFX11) add_compile_options(-mcumode) add_compile_options(-mno-wavefrontsize64) + add_compile_definitions(CK_TILE_WAVE32_ENABLED) message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") endif() @@ -336,6 +337,12 @@ if(ENABLE_ASM_DUMP) message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() +if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12")) + add_compile_options(-mno-wavefrontsize64) + add_compile_definitions(CK_TILE_WAVE32_ENABLED) + message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/Jenkinsfile b/Jenkinsfile index ed4c39126b..d1f1baf15f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1474,7 +1474,7 @@ pipeline { } agent{ label rocmnode("gfx1101") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx11-generic" \ @@ -1495,7 +1495,7 @@ pipeline { } agent{ label rocmnode("gfx1201") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install 
-DGPU_TARGETS="gfx12-generic" -DUSE_OPT_GFX12=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx12-generic" \ diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 18731e810e..03c531c1ad 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 87812369bd..5167097b6d 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index c3e6ef7d5d..abf7ef3905 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp index 93034a8b70..2582ea8a11 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index e7c1d6f0be..57e2feb084 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index e319e2d668..eb0a6de8aa 100755 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -172,6 +172,27 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr int kBlockPerCu = 2; }; +template +struct GemmConfigComputeV3_WMMA : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + template struct GemmConfigComputeV4 : public GemmConfigBase { diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 14c4905720..149a8c2f0c 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -346,5 +346,4 @@ int main(int argc, char* argv[]) // Return a non-zero code to indicate failure return EXIT_FAILURE; } - return EXIT_SUCCESS; } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp old mode 100755 new mode 100644 diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index f0e9518120..ec5f49108e 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/integer.hpp" #include 
"ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/ignore.hpp" #define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111 #define CK_TILE_VMCNT(cnt) \ @@ -59,7 +60,7 @@ enum struct memory_operation_enum : std::uint16_t CK_TILE_HOST_DEVICE constexpr index_t get_warp_size() { -#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) +#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED)) return 64; #else return 32; @@ -230,4 +231,20 @@ CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_ } } +// Architecture tags +struct gfx11_t +{ +}; +struct gfx12_t +{ +}; + +CK_TILE_DEVICE static constexpr auto get_device_arch() +{ +#if defined(__gfx11__) + return gfx11_t{}; +#else // if defined(__gfx12__) + return gfx12_t{}; +#endif +} } // namespace ck_tile diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp index 07c6aa0baf..c02c46958c 100644 --- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -6,6 +6,10 @@ #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/container/thread_buffer.hpp" +#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \ + __has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \ + __has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16) + namespace ck_tile { template @@ -32,6 +36,14 @@ CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b) return rtn; } +CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b) +{ + fp16x2_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + return rtn; +} + CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b) { fp8x4_t rtn; @@ -304,6 +316,44 @@ CK_TILE_DEVICE void atomic_add(bf8x8_t* p_dst, bf8x8_t const& x) } while(cur_v.u64 != old_v); } +// +// Atomic add for fp16x2_t +// +template <> 
+CK_TILE_DEVICE void atomic_add(fp16x2_t* p_dst, fp16x2_t const& x) +{ +#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN + __builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast(p_dst), x); +#else + union U32F162_ADDR + { + uint32_t* u32_a; + fp16x2_t* f162_a; + }; + + union U32F162 + { + uint32_t u32; + fp16x2_t f162; + }; + + U32F162_ADDR dword_addr; + U32F162 cur_v; + U32F162 new_; + uint32_t old_v, new_v; + dword_addr.f162_a = p_dst; + cur_v.u32 = *dword_addr.u32_a; + + do + { + old_v = cur_v.u32; + new_.f162 = add_f16x2_t(cur_v.f162, x); + new_v = new_.u32; + cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); + } while(cur_v.u32 != old_v); +#endif +} + template CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) { @@ -311,6 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) (std::is_same::value && (N == 1)) || (std::is_same::value && (N == 1 || N == 2)) || (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 4 || N == 8 || N == 16)), @@ -406,6 +457,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) atomic_add(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); } } + else if constexpr(std::is_same::value) + { + static_for<0, N / 2, 1>{}([&](auto i) { + atomic_add(c_style_pointer_cast(p_dst) + i, + x.template get_as()[i]); + }); + } } template diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index e472bd01e5..f94065da2b 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -152,7 +152,7 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx9__) // for GPU code +#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code #define 
CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 @@ -274,6 +274,12 @@ #define CK_TILE_WA_ISSUE_2028 0 #endif +#ifndef CK_TILE_WAVE32_ENABLED +#if defined(__gfx11__) || defined(__gfx12__) +#define CK_TILE_WAVE32_ENABLED +#endif +#endif + // Y pointed to R, we don't see a valuable use case. // Will enforce encoding to check Y not pointed to R if set to zero #ifndef CK_TILE_ENC_SUPPORT_Y_TO_R diff --git a/include/ck_tile/host/device_prop.hpp b/include/ck_tile/host/device_prop.hpp index 0d8f89ea31..f86e4b889a 100644 --- a/include/ck_tile/host/device_prop.hpp +++ b/include/ck_tile/host/device_prop.hpp @@ -52,6 +52,19 @@ inline std::string get_device_name() } } +inline bool is_gfx11_supported() +{ + return get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || + get_device_name() == "gfx1102" || get_device_name() == "gfx1103" || + get_device_name() == "gfx1150" || get_device_name() == "gfx1151" || + get_device_name() == "gfx1152"; +} + +inline bool is_gfx12_supported() +{ + return get_device_name() == "gfx1200" || get_device_name() == "gfx1201"; +} + inline bool is_load_tr_supported() { // Check if load transpose is supported. 
diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index d42f144baa..f773de9e7e 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -203,13 +203,13 @@ struct CShuffleEpilogue static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); - using WG = WarpGemmMfmaDispatcher; + using WG = WarpGemmDispatcher; using CWarpDstr = typename WG::CWarpDstr; using CWarpTensor = typename WG::CWarpTensor; diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index fdbe2e7a6d..8a0970f494 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -130,13 +130,13 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue static constexpr index_t kKPerXdl = Problem::kKPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; - using WG = WarpGemmMfmaDispatcher; + using WG = WarpGemmDispatcher; using CWarpDstr = typename WG::CWarpDstr; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index cc00000efc..20783ea8bf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -430,13 +430,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy // using AccDataType = float; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher; + using WarpGemm = WarpGemmDispatcher; using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< typename 
Problem::ADataType, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index aa2ec99590..68ead7c765 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -43,7 +43,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - using WarpGemm = WarpGemmMfmaDispatcher< + using WarpGemm = WarpGemmDispatcher< typename Problem::QDataType, typename Problem::KDataType, typename Problem::AccDataType, @@ -78,18 +78,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm1WarpTile>>; using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true, - false, // SwizzleAccess - false, // UseStructuredSparsity - (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) - ? WGAttrNumAccessEnum ::Double - : WGAttrNumAccessEnum ::Single>; + WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, // SwizzleAccess + false, // UseStructuredSparsity + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) + ? 
WGAttrNumAccessEnum ::Double + : WGAttrNumAccessEnum ::Single>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy>; - using WarpGemm = WarpGemmMfmaDispatcher< + using WarpGemm = WarpGemmDispatcher< typename Problem::OGradDataType, typename Problem::VDataType, typename Problem::AccDataType, @@ -150,18 +150,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm3WarpTile>>; using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), - true, - false, // SwizzleAccess - false, // UseStructuredSparsity - (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) - ? WGAttrNumAccessEnum ::Double - : WGAttrNumAccessEnum ::Single>; + WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), + true, + false, // SwizzleAccess + false, // UseStructuredSparsity + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) + ? 
WGAttrNumAccessEnum ::Double + : WGAttrNumAccessEnum ::Single>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy>; - using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), - false>; + using WarpGemm = WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), + false>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy>; constexpr auto SwizzleA = false; - using WarpGemm = WarpGemmMfmaDispatcher< // + using WarpGemm = WarpGemmDispatcher< // typename Problem::QDataType, typename Problem::KDataType, typename Problem::AccDataType, @@ -66,7 +66,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy typename Problem::BlockFmhaShape::Gemm2WarpTile>>; constexpr auto SwizzleA = false; - using WarpGemm = WarpGemmMfmaDispatcher< // + using WarpGemm = WarpGemmDispatcher< // typename Problem::OGradDataType, typename Problem::VDataType, typename Problem::AccDataType, @@ -106,7 +106,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy typename BlockFmhaShape::Gemm4BlockWarps, typename BlockFmhaShape::Gemm4WarpTile>>; - using WarpGemm = WarpGemmMfmaDispatcher< // + using WarpGemm = WarpGemmDispatcher< // typename Problem::GemmDataType, typename Problem::KDataType, typename Problem::AccDataType, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp index 6582991207..6d414ee851 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -512,14 +512,13 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), - true>; + using WarpGemm = WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>; using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy>; - using WarpGemm = WarpGemmMfmaDispatcher< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true, - false, - false, - ((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 && - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) || - (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 && - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16)) - ? WGAttrNumAccessEnum::Double - : WGAttrNumAccessEnum::Single>; + using WarpGemm = + WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + ((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) || + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16)) + ? 
WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single>; using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy{}; // return - // WarpGemmImpl>>{}; } else { - return WarpGemmMfmaDispatcher< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true>{}; + return WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>{}; } }(); diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp index 0c8baaf191..dbd6913cdb 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp @@ -568,7 +568,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy std::is_same_v && S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16) { - return WarpGemmImpl, 2>>{}; } @@ -576,7 +576,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy std::is_same_v && S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) { - return WarpGemmImpl, 2>>{}; } @@ -695,7 +695,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy std::is_same_v && S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16) { - return WarpGemmImpl, 2>>{}; } @@ -703,7 +703,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy std::is_same_v && S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) { - return WarpGemmImpl, 2>>{}; } diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index e792820466..7a01420c51 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -58,9 +58,15 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" 
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index cfbd78967f..d16651da93 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -54,16 +54,16 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy return make_tuple(WarpGemmMfmaF16F16F32M32N32K16<>{}, 2, 2); } #else - using WG = WarpGemmMfmaDispatcher; + using WG = WarpGemmDispatcher; return make_tuple(WG{}, 4, 1); #endif } @@ -71,16 +71,16 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy std::is_same_v && std::is_same_v) { - using WG = WarpGemmMfmaDispatcher; + using WG = WarpGemmDispatcher; return make_tuple(WG{}, 4, 1); } else diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 5b7903a9e7..2d439c6970 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -182,7 +182,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; - constexpr index_t WaveSize = 64; + constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); @@ -242,7 +242,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; - constexpr index_t WaveSize = 64; + constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index e6da00da95..b0cd93a661 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -182,7 +182,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); - constexpr index_t WaveSize = 64; + constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index 7d88c804f3..a80ed57be5 100644 --- 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -32,16 +32,17 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy ? WGAttrNumAccessEnum::Double : WGAttrNumAccessEnum::Single; - using WarpGemm = WarpGemmMfmaDispatcher; + using WarpGemm = WarpGemmDispatcher; + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher; + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { - using AccDataType = float; - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher; + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher; + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + using WarpGemm = WarpGemmDispatcher; + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + using WarpGemm = WarpGemmDispatcher; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy>>; + WarpGemmAttributeMfma>>; using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; #if defined(__gfx950__) template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; #else template -using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, 2, AttrNumAccess>>; @@ -36,42 +36,42 @@ using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, 
AttrNumAccess>>; #else template -using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl, 2, AttrNumAccess>>; #endif -using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl, 1>>; -using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl, 2>>; using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = - WarpGemmImpl>>; using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = - WarpGemmImpl>>; #if defined(__gfx950__) template using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = - WarpGemmImpl, AttrNumAccess>>; #else template using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = - WarpGemmImpl, 2, AttrNumAccess>>; @@ -80,13 +80,13 @@ using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = #if defined(__gfx950__) template using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = - WarpGemmImpl, AttrNumAccess>>; #else template using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = - WarpGemmImpl, 2, AttrNumAccess>>; @@ -94,36 +94,36 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = - WarpGemmImpl, 1>>; using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = - WarpGemmImpl, 1>>; #endif using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution = - WarpGemmImpl>>; #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = - WarpGemmImpl>>; #else using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = - WarpGemmImpl, 2>>; #endif -using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl, 4>>; -using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl, 4>>; @@ -136,19 +136,19 @@ using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; + WarpGemmAttributeMfma>>; using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; #if defined(__gfx950__) template using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl< - 
WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; #else template -using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, 2, AttrNumAccess>>; @@ -157,43 +157,43 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; #else template -using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2, AttrNumAccess>>; #endif -using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl, 1>>; using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = - WarpGemmImpl, 2>>; using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = - WarpGemmImpl>>; using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = - WarpGemmImpl>>; #if defined(__gfx950__) template using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = - WarpGemmImpl, AttrNumAccess>>; #else template using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = - WarpGemmImpl, 2, AttrNumAccess>>; @@ -202,153 +202,153 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = #if defined(__gfx950__) template using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = - WarpGemmImpl, AttrNumAccess>>; #else template using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = - WarpGemmImpl, 2, AttrNumAccess>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution = - WarpGemmImpl>>; #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = - WarpGemmImpl>>; #else using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = - WarpGemmImpl, 2>>; #endif -using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl, 4>>; -using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl, 4>>; // fp8 using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; using 
WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; -using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, 2>>; -using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, 2>>; using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; -using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, 2>>; -using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; template using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = - WarpGemmImpl>>; using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = - WarpGemmImpl>>; using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = - WarpGemmImpl>>; using 
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = - WarpGemmImpl>>; template using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution = - WarpGemmImpl, 2, swizzle_factor>>; // int8 using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed = - WarpGemmImpl>>; using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAttributeMfma>>; using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed = - WarpGemmImpl>>; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 97fab489ab..36a9955912 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -19,7 +19,7 @@ enum class WGAttrNumAccessEnum template -struct WarpGemmAtrributeMfma +struct WarpGemmAttributeMfma { using Impl = remove_cvref_t; static constexpr auto AttrNumAccess = AttrNumAccess_; @@ -103,7 +103,7 @@ struct WarpGemmAtrributeMfma template -struct WarpGemmAtrributeMfmaIterateK +struct WarpGemmAttributeMfmaIterateK { static_assert(kKIter > 0, "wrong!"); @@ -367,7 +367,7 @@ struct WarpGemmAtrributeMfmaIterateK template -struct WarpGemmAtrributeMfmaTransposedCDistribution +struct WarpGemmAttributeMfmaTransposedCDistribution { using Impl = remove_cvref_t; static constexpr auto AttrNumAccess = AttrNumAccess_; @@ -450,7 +450,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution }; template -struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB +struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB { using Impl = remove_cvref_t; @@ -546,7 +546,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB template -struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution +struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution { using Impl = remove_cvref_t; static 
constexpr auto AttrNumAccess = AttrNumAccess_; @@ -574,13 +574,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() { - return WarpGemmAtrributeMfmaIterateK:: + return WarpGemmAttributeMfmaIterateK:: get_bwarp_dstr_encoding(); } CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() { - return WarpGemmAtrributeMfmaIterateK:: + return WarpGemmAttributeMfmaIterateK:: get_awarp_dstr_encoding(); } @@ -696,7 +696,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution }; template -struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB +struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB { using Impl = remove_cvref_t; @@ -840,7 +840,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB }; template -struct WarpGemmAtrributeMfmaIterateK_SwizzleA +struct WarpGemmAttributeMfmaIterateK_SwizzleA { using Impl = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp new file mode 100644 index 0000000000..0f021c62f2 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp" + +namespace ck_tile { + +// TODO: currently only support 16 bit input, which means only support tr16_b128; will use ADataType +// to determine the layout in the future +template +struct AWarpDstrEncodingTrait +{ + using type = tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple, + tuple, + typename Impl::kABYs2RHsMajor, + typename Impl::kABYs2RHsMinor>; +}; + +template +struct BWarpDstrEncodingTrait +{ + using type = tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple, + tuple, + typename Impl::kABYs2RHsMajor, + typename Impl::kABYs2RHsMinor>; +}; + +template +struct CWarpDstrEncodingTrait +{ + using type = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, + tuple, + typename Impl::kCYs2RHsMajor, + typename Impl::kCYs2RHsMinor>; +}; + +template +struct WarpGemmAttributeWmma +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::AVecType; + using BVecType = typename Impl::BVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + + // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2 + // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4 + using AWarpDstrEncoding = typename AWarpDstrEncodingTrait::type; + using BWarpDstrEncoding = typename BWarpDstrEncodingTrait::type; + + // kCM0PerLane = 4, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 for 16 bit input + // kCM0PerLane 
= 2, kCMLane = 2, kCM1PerLane = 4, kCNLane = 16 for 8 bit input + using CWarpDstrEncoding = typename CWarpDstrEncodingTrait::type; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + if constexpr(kTransC) + { + Impl{}(c_vec, b_vec, a_vec, bool_constant{}); + } + else + { + Impl{}(c_vec, a_vec, b_vec, bool_constant{}); + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + if constexpr(kTransC) + { + return Impl{}(b_vec, a_vec); + } + else + { + return Impl{}(a_vec, b_vec); + } + } +}; + +template +CK_TILE_HOST bool check_wmma_supported() +{ + if(is_gfx12_supported()) + { + return has_wmma_traits_v; + } + else if(is_gfx11_supported()) + { + return has_wmma_traits_v; + } + else + { + return false; + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp new file mode 100644 index 0000000000..13727d41b1 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +// Base traits for WMMA operations +template +struct WmmaTraits; + +// Generic WMMA implementation using traits +template +struct WarpGemmAttributeWmmaImpl +{ + using ADataType = typename Traits::ADataType; + using BDataType = typename Traits::BDataType; + using CDataType = typename Traits::CDataType; + + using AVecType = typename Traits::AVecType; + using BVecType = typename Traits::BVecType; + using CVecType = typename Traits::CVecType; + + // Forward all static constants and type aliases + static constexpr index_t kM = Traits::kM; + static constexpr index_t kN = Traits::kN; + static constexpr index_t kK = Traits::kK; + + static constexpr index_t kRepeat = Traits::kRepeat; + static constexpr index_t kAMLane = Traits::kAMLane; + static constexpr index_t kBNLane = Traits::kBNLane; + static constexpr index_t kABK0PerLane = Traits::kABK0PerLane; + static constexpr index_t kABKLane = Traits::kABKLane; + static constexpr index_t kABK1PerLane = Traits::kABK1PerLane; + + static constexpr index_t kCMLane = Traits::kCMLane; + static constexpr index_t kCNLane = Traits::kCNLane; + static constexpr index_t kCM0PerLane = Traits::kCM0PerLane; + static constexpr index_t kCM1PerLane = Traits::kCM1PerLane; + + using kABPs2RHssMajor = typename Traits::kABPs2RHssMajor; + using kABPs2RHssMinor = typename Traits::kABPs2RHssMinor; + using kABYs2RHsMajor = typename Traits::kABYs2RHsMajor; + using kABYs2RHsMinor = typename Traits::kABYs2RHsMinor; + + using kCPs2RHssMajor = typename Traits::kCPs2RHssMajor; + using kCPs2RHssMinor = typename Traits::kCPs2RHssMinor; + using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor; + using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + c_vec = Traits::template 
wmma_intrinsic(a_vec, b_vec, c_vec); + } + + // c_vec = a_vec * b_vec + template + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return bit_cast( + Traits::template wmma_intrinsic(a_vec, b_vec, CVecType{0.f})); + } +}; + +using DeviceIp = remove_cvref_t; +using WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16 = + WarpGemmAttributeWmmaImpl>; + +using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16 = + WarpGemmAttributeWmmaImpl>; + +using WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8 = + WarpGemmAttributeWmmaImpl>; + +using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8 = + WarpGemmAttributeWmmaImpl>; + +using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8 = + WarpGemmAttributeWmmaImpl>; + +using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8 = + WarpGemmAttributeWmmaImpl>; + +using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8 = + WarpGemmAttributeWmmaImpl>; + +template +struct has_wmma_traits +{ + template + static auto + test(int) -> decltype(std::declval< + typename WmmaTraits:: + ADataType>(), + std::true_type{}); + + template + static std::false_type test(...); + + static constexpr bool value = decltype(test(0))::value; +}; + +template +constexpr bool has_wmma_traits_v = + has_wmma_traits::value; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp new file mode 100644 index 0000000000..7e834d9add --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "warp_gemm_attribute_wmma_impl_base_traits.hpp" +namespace ck_tile { +// fp16 specialization - GFX11 +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx11__ + return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_vec, b_vec, c_vec); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0.f}; +#endif + } +}; + +// bf16 specialization - GFX11 +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx11__ + return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_vec, b_vec, c_vec); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0.f}; +#endif + } +}; + +// fp16 specialization - GFX12 +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx12__ + return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, c_vec); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0.f}; +#endif + } +}; + +// bf16 specialization - GFX12 +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx12__ + return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_vec, b_vec, c_vec); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0.f}; +#endif + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp 
b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp new file mode 100644 index 0000000000..81ff5af2fe --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "warp_gemm_attribute_wmma_impl_base_traits.hpp" +namespace ck_tile { +// int8 specialization - GFX11 +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx11__ + return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // neg_a + bit_cast(a_vec), + true, // neg_b + bit_cast(b_vec), + bit_cast(c_vec), + clamp); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + +// int8 specialization - GFX12 +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx12__ + return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // neg_a + bit_cast(a_vec), + true, // neg_b + bit_cast(b_vec), + bit_cast(c_vec), + clamp); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + +// fp8/bf8 specialization - GFX12 +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx12__ + return __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + bit_cast(a_vec), bit_cast(b_vec), bit_cast(c_vec)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + +template <> +struct 
WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx12__ + return __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12( + bit_cast(a_vec), bit_cast(b_vec), bit_cast(c_vec)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx12__ + return __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12( + bit_cast(a_vec), bit_cast(b_vec), bit_cast(c_vec)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx12__ + return __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12( + bit_cast(a_vec), bit_cast(b_vec), bit_cast(c_vec)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp new file mode 100644 index 0000000000..7ea5507d09 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once +namespace ck_tile { +template +struct WmmaTraitsBase; + +// GFX11 specialization +template +struct WmmaTraitsBase +{ + using ADataType = ADType; + using BDataType = BDType; + using CDataType = CDType; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 16; + + static constexpr index_t kRepeat = 2; + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABK0PerLane = 1; + static constexpr index_t kABKLane = 1; + static constexpr index_t kABK1PerLane = 16; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 8; + static constexpr index_t kCM1PerLane = 1; + + using kABPs2RHssMajor = sequence<0, 2, 1>; + using kABPs2RHssMinor = sequence<0, 1, 0>; + using kABYs2RHsMajor = sequence<2, 2>; + using kABYs2RHsMinor = sequence<0, 2>; + + using kCPs2RHssMajor = sequence<1, 2>; + using kCPs2RHssMinor = sequence<1, 0>; + using kCYs2RHsMajor = sequence<1, 1>; + using kCYs2RHsMinor = sequence<0, 2>; +}; + +// GFX12 specialization +template +struct WmmaTraitsBase +{ + using ADataType = ADType; + using BDataType = BDType; + using CDataType = CDType; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 16; + + static constexpr index_t kRepeat = 1; + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABK0PerLane = 2; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABK1PerLane = 4; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 8; + + using 
kABPs2RHssMajor = sequence<2, 1>; + using kABPs2RHssMinor = sequence<1, 0>; + using kABYs2RHsMajor = sequence<2, 2>; + using kABYs2RHsMinor = sequence<0, 2>; + + using kCPs2RHssMajor = sequence<1, 2>; + using kCPs2RHssMinor = sequence<1, 0>; + using kCYs2RHsMajor = sequence<1, 1>; + using kCYs2RHsMinor = sequence<0, 2>; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 8c6f39e511..d50b208946 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" namespace ck_tile { @@ -19,115 +20,133 @@ template -struct WarpGemmMfmaDispatcher; +struct WarpGemmDispatcher; // clang-format off // fp16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; -template<> struct 
WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M4N64K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M64N4K16; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M4N64K16; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M64N4K16; }; +// WMMA cases +#if defined(__gfx11__) || defined(__gfx12__) +template struct WarpGemmDispatcher { using Type = WarpGemmWmma_f32_16x16x16_f16_f16;}; +#else +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; +#endif -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; -template<> struct 
WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp16 2:4 structural sparsity // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmSmfmacF16F16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmSmfmacF16F16F32M16N16K32; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmSmfmacF16F16F32M32N32K16; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmSmfmacF16F16F32M16N16K32; }; // bf16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmDispatcher { 
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; }; +// WMMA cases +#if defined(__gfx11__) || defined(__gfx12__) +template struct WarpGemmDispatcher { using Type = WarpGemmWmma_f32_16x16x16_bf16_bf16; }; +#else +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; +template<> struct WarpGemmDispatcher { using 
Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; +#endif -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; -template<> struct 
WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; }; -template<> 
struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; 
-template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; +//WMMA cases +template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_f8_f8; }; +template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8; }; +template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_f8_bf8; }; +template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_bf8_f8; }; + // int8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; }; +// WMMA cases +template struct WarpGemmDispatcher { using Type = WarpGemmWmma_i32_16x16x16_i8_i8;}; // clang-format on } // namespace impl @@ -142,15 +161,15 @@ template -using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher::Type; +using WarpGemmDispatcher = typename impl::WarpGemmDispatcher::Type; } // namespace 
ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp new file mode 100644 index 0000000000..cf477f7928 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp" + +namespace ck_tile { + +template +using WarpGemmWmma_f32_16x16x16_f16_f16 = + WarpGemmImpl>; + +template +using WarpGemmWmma_f32_16x16x16_bf16_bf16 = + WarpGemmImpl>; + +template +using WarpGemmWmma_i32_16x16x16_i8_i8 = + WarpGemmImpl>; + +template +using WarpGemmWmma_f32_16x16x16_f8_f8 = + WarpGemmImpl>; + +template +using WarpGemmWmma_f32_16x16x16_bf8_bf8 = + WarpGemmImpl>; + +template +using WarpGemmWmma_f32_16x16x16_f8_bf8 = + WarpGemmImpl>; + +template +using WarpGemmWmma_f32_16x16x16_bf8_f8 = + WarpGemmImpl>; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index f2d78d7ab5..1fb92ad14d 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -44,13 +44,13 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t VecLoadSize = GetVectorSizeAQ(); constexpr bool Preshuffle = Problem::Traits::Preshuffle; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher; + using WarpGemm = WarpGemmDispatcher; static_assert(std::is_same_v); if constexpr(Preshuffle) @@ -92,13 +92,13 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0, "KPerWarpGemm must be a multiple of 
kQuantGroupSize!"); - using WarpGemm = WarpGemmMfmaDispatcher; + using WarpGemm = WarpGemmDispatcher; static_assert(std::is_same_v || std::is_same_v); static_assert(std::is_same_v); diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 6cbdc1a24e..a982e30a4c 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -30,6 +30,14 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") + # On Radeon devices, build the WMMA version instead + add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() @@ -46,4 +54,7 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MAT target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +elseif(GPU_TARGETS MATCHES "gfx11" OR 
GPU_TARGETS MATCHES "gfx12") + add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp index 8944e6865d..370f4c16a8 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp @@ -3,7 +3,8 @@ #include "gtest/gtest.h" template -class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline +class TestCkTileGemmPipelineCompV3 + : public TestCkTileGemmPipeline> { }; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3_wmma.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3_wmma.cpp new file mode 100644 index 0000000000..6bd98d0bc7 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3_wmma.cpp @@ -0,0 +1,17 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_wmma_base.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineCompV3Wmma + : public TestCkTileGemmPipelineWmmaBase> +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3Wmma + +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3Wmma, KernelTypesCompV3Wmma); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp index 22e77fac41..6d5a5b93d6 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp @@ -3,7 +3,8 @@ #include "gtest/gtest.h" template -class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline +class TestCkTileGemmPipelineCompV4 + : public TestCkTileGemmPipeline> { }; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv4_wmma.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv4_wmma.cpp new file mode 100644 index 
0000000000..f73901e761 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv4_wmma.cpp @@ -0,0 +1,17 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_wmma_base.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineCompV4Wmma + : public TestCkTileGemmPipelineWmmaBase> +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4Wmma + +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4Wmma, KernelTypesCompV4Wmma); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index ae8899ba71..a55cd100c1 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -9,13 +9,16 @@ #include "ck_tile/host.hpp" #include "test_gemm_pipeline_util.hpp" -using I8 = ck_tile::int8_t; -using I32 = ck_tile::int32_t; +using INT8 = ck_tile::int8_t; +using INT32 = ck_tile::int32_t; using F16 = ck_tile::half_t; using F32 = float; using F8 = ck_tile::fp8_t; +using BF16 = ck_tile::bf16_t; +using BF8 = ck_tile::bf8_t; + using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Intrawave = ck_tile::integral_constant; +using I32 = ck_tile::number<32>; +using I64 = ck_tile::number<64>; +using I256 = ck_tile::number<256>; + // clang-format off using KernelTypesMem = ::testing::Types< - std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, - std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>, - std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, - std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>, - std::tuple< Row, Col, Row, F8, F8, F32, F16, 
Intrawave, Mem>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, - std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>, - std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>, - std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>, - std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem> + // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, PipelineType + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>, + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>, + std::tuple< Col, Col, 
Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem> +>; + +using KernelTypesMemWmma = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, Mem>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>, + std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, 
F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem> >; using KernelTypesCompV3 = ::testing::Types< - std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>, - std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>, - std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>, - std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>, - std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>, - std::tuple< Row, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>, - std::tuple< Row, Col, Row, I8, I8, I32, I32, Intrawave, CompV3>, - std::tuple< Col, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>, - std::tuple< Col, Col, Row, I8, I8, I32, I32, Intrawave, CompV3> - + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + 
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3> +>; + +using KernelTypesCompV3Wmma = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< 
Col, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3> >; using KernelTypesCompV4 = ::testing::Types< - std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4> + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> >; +using KernelTypesCompV4Wmma = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4> +>; + + using KernelTypesPersistent = ::testing::Types< - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent> + // ALayout, BLayout, 
CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, PipelineType, Persistent + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, Persistent>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, NonPersistent> +>; + +using KernelTypesPersistentWmma = ::testing::Types< + std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, Persistent>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, NonPersistent> >; // clang-format on diff --git a/test/ck_tile/gemm/test_gemm_pipeline_mem.cpp b/test/ck_tile/gemm/test_gemm_pipeline_mem.cpp index a7f4e68386..51fbebc915 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_mem.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_mem.cpp @@ -3,7 +3,7 @@ #include "gtest/gtest.h" template -class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline +class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline> { }; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_mem_wmma.cpp b/test/ck_tile/gemm/test_gemm_pipeline_mem_wmma.cpp new file mode 100644 index 0000000000..5af5e09b28 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_mem_wmma.cpp @@ -0,0 +1,17 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_wmma_base.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineMemWmma + : public TestCkTileGemmPipelineWmmaBase> +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineMemWmma + +TYPED_TEST_SUITE(TestCkTileGemmPipelineMemWmma, KernelTypesMemWmma); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp b/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp index 1dea1ab48c..54410acf70 100644 --- 
a/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp @@ -3,7 +3,8 @@ #include "gtest/gtest.h" template -class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline +class TestCkTileGemmPipelinePersistent + : public TestCkTileGemmPipeline> { }; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_persistent_wmma.cpp b/test/ck_tile/gemm/test_gemm_pipeline_persistent_wmma.cpp new file mode 100644 index 0000000000..45ab586aa9 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_persistent_wmma.cpp @@ -0,0 +1,17 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_wmma_base.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelinePersistentWmma + : public TestCkTileGemmPipelineWmmaBase> +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistentWmma + +TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistentWmma); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 70aa161881..26ff847841 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -69,7 +69,7 @@ struct GemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; } }; -template +template class TestCkTileGemmPipeline : public ::testing::Test { protected: @@ -80,32 +80,30 @@ class TestCkTileGemmPipeline : public ::testing::Test using BDataType = std::tuple_element_t<4, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>; - static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; - static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value; + static constexpr auto Scheduler = std::tuple_element_t<13, Tuple>::value; + static constexpr auto PipelineType = std::tuple_element_t<14, 
Tuple>::value; + + static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>{}; + static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>{}; + static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>{}; + + static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>{}; + static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>{}; + static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>{}; using DsLayout = ck_tile::tuple<>; using DsDataType = ck_tile::tuple<>; static constexpr bool Persistent = - ck_tile::tuple_element_or_default_t::value; - // TODO: expose tile size through test t-param ? + ck_tile::tuple_element_or_default_t::value; template void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { - // TODO: This should be parameterized in tests - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 
32 : 64; - constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - constexpr bool kPadM = PadM; constexpr bool kPadN = PadN; constexpr bool kPadK = PadK; @@ -247,11 +245,48 @@ class TestCkTileGemmPipeline : public ::testing::Test BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } + template + bool check_data_type() + { + return static_cast(this) + ->template check_data_type_impl(); + } + + template + bool check_data_type_impl() + { + return true; + } + public: std::vector k_batches_; void SetUp() override { + if(!check_data_type()) + { + GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test."; + } if constexpr(PipelineType == GemmPipelineType::CompV4) { // Only do k_batch = 1 when pipeline is CompV4 diff --git a/test/ck_tile/gemm/test_gemm_pipeline_wmma_base.hpp b/test/ck_tile/gemm/test_gemm_pipeline_wmma_base.hpp new file mode 100644 index 0000000000..8d8d245b6a --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_wmma_base.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "test_gemm_pipeline_util.hpp" + +template +class TestCkTileGemmPipelineWmmaBase : public TestCkTileGemmPipeline +{ + public: + template + bool check_data_type_impl() + { + return ck_tile::check_wmma_supported(); + } +}; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc old mode 100755 new mode 100644