[CK_TILE] support gfx950 matrix core in 01_fmha fwd (#2110)

* gfx950 01_fmha fwd

* fix comment

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
carlushuang
2025-04-24 03:40:18 +08:00
committed by GitHub
parent 854159fd00
commit 5487289fc4
5 changed files with 286 additions and 2 deletions

View File

@@ -49,10 +49,16 @@ using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
@@ -65,10 +71,16 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
2>>;
#endif
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
@@ -123,10 +135,16 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
@@ -139,10 +157,16 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
2>>;
#endif
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,

View File

@@ -356,7 +356,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
}
};
template <typename WarpGemmAttributeMfmaImpl_>
template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
@@ -373,6 +373,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
@@ -386,7 +387,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
#if 0
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
@@ -407,7 +408,29 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
#else
// TODO: more test not only 32x32
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
Impl::kCMLane,
SFactor,
Impl::kCM1PerLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1, 1, 1, 1>>,
tuple<sequence<0, 0, 2, 1, 3>>,
sequence<2>,
sequence<1>>;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
#endif
template <bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void operator()(CVecType& c_vec,

View File

@@ -748,6 +748,235 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
}
};
// gfx950
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 8>;
using BVecType = ext_vector_t<fp16_t, 8>;
using CVecType = ext_vector_t<float, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 32;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl)
else
{
#if defined(__gfx950__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx90a__) || defined(__gfx94__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#elif defined(__gfx908__)
static_for<0, 4, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx950__)
return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
#elif defined(__gfx90a__) || defined(__gfx94__)
CVecType c_vec{0.f};
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
c_vec,
0,
0,
0);
});
return c_vec;
#elif defined(__gfx908__)
CVecType c_vec{0.f};
static_for<0, 4, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
return c_vec;
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = ext_vector_t<bf16_t, 8>;
using BVecType = ext_vector_t<bf16_t, 8>;
using CVecType = ext_vector_t<float, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 32;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl)
else
{
#if defined(__gfx950__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx90a__) || defined(__gfx94__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#elif defined(__gfx908__)
static_for<0, 4, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx950__)
return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
#elif defined(__gfx90a__) || defined(__gfx94__)
CVecType c_vec{0.f};
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
c_vec,
0,
0,
0);
});
return c_vec;
#elif defined(__gfx908__)
CVecType c_vec{0.f};
static_for<0, 4, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
return c_vec;
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
// FP8
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base

View File

@@ -1,3 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"