diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 794f7f21f2..35f5170179 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -44,8 +44,11 @@ #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index bd7a0566a2..e6350a8827 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -49,10 +49,16 @@ using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl, 2>>; +#endif #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = @@ -65,10 +71,16 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = 2>>; #endif +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl, @@ -123,10 +135,16 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl, 2>>; +#endif #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = @@ -139,10 +157,16 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = 2>>; #endif +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index e7d4c37966..93ccdb5f57 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -356,7 +356,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution } }; -template +template struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB { using Impl = remove_cvref_t; @@ -373,6 +373,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB static constexpr index_t kN = Impl::kM; static constexpr index_t kK = Impl::kK; static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } @@ -386,7 +387,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB tuple>, sequence<2>, sequence<1>>; - +#if 0 using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>, sequence<2, 2>, sequence<0, 2>>; +#else + // TODO: more test not only 32x32 + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>; +#endif template // c_vec += a_vec * b_vec CK_TILE_DEVICE void operator()(CVecType& c_vec, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index f937899ffd..08f813a1e3 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -748,6 +748,235 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 } }; +// gfx950 +template +struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#elif defined(__gfx908__) + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#elif defined(__gfx908__) + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + // FP8 template struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp index adf548aaca..84cdf17d66 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp @@ -1,3 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"