diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index f5434e7016..e31c96caaa 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -15,7 +15,10 @@ add_custom_command( ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") -add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp) +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding tile_example ${EXAMPLE_NAME}") +add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 49e286c156..65ce774531 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -6,7 +6,7 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck_tile-dev.sh ../ # you can replace this to gfx90a, gfx942... +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... make tile_example_fmha_fwd -j ``` This will result in an executable `build/bin/tile_example_fmha_fwd` diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 6b1c11fa27..2767ee05b2 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -49,6 +49,7 @@ #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/to_sequence.hpp" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 7b7dc880fe..a614cd3a70 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -20,6 +20,8 @@ #define CK_TILE_DEVICE_EXTERN #endif +// minimal arch + #ifndef CK_TILE_USE_CUSTOM_DATA_TYPE #define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code #endif diff --git a/include/ck_tile/core/utility/ignore.hpp b/include/ck_tile/core/utility/ignore.hpp new file mode 100644 index 0000000000..eead914954 --- /dev/null +++ b/include/ck_tile/core/utility/ignore.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// https://en.cppreference.com/w/cpp/utility/tuple/ignore + +namespace ck_tile { + +namespace detail { +struct ignore_t +{ + template + constexpr void operator=(T&&) const noexcept + { + } +}; +} // namespace detail + +inline constexpr detail::ignore_t ignore; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index e618d66a75..cb250516f4 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -36,14 +36,28 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif } }; @@ -75,14 +89,28 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif } }; @@ -115,14 +143,52 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { +#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx908__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { +#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif } }; @@ -154,14 +220,52 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { +#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx908__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { +#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif } }; @@ -208,7 +312,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); -#else +#elif defined(__gfx908__) || defined(__gfx90a__) static_for<0, 8, 1>{}([&](auto k) { float a_f32 = type_convert(reinterpret_cast&>(a_vec) @@ -219,12 +323,17 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; #endif } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); @@ -237,6 +346,24 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); +#elif defined(__gfx908__) || defined(__gfx90a__) + CVecType c_vec{0.f}; + static_for<0, 8, 1>{}([&](auto k) { + float a_f32 = + type_convert(reinterpret_cast&>(a_vec) + .template get_as()[number{}]); + float b_f32 = + type_convert(reinterpret_cast&>(b_vec) + .template get_as()[number{}]); + + c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); + }); + return c_vec; +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif } };