wip - make a start on defining warp level classes which are templated on GfxId

This commit is contained in:
Philip Maybank
2025-07-29 11:54:09 -04:00
parent 7c2f6ad40d
commit 1c098fd37a
8 changed files with 647 additions and 23 deletions

View File

@@ -234,4 +234,44 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
#endif
}
// ============================================================================
// Architecture-specific parameter definitions
// We'll define all parameters for all supported architectures here.
// ============================================================================
// Parameters for gfx120x (using a namespace for organization or just global constexpr)
namespace Gfx120x {
// Warp tile size used for gfx120x.
// NOTE(review): the unit (elements per tile edge?) is not visible here — confirm at the use sites.
constexpr ck_tile::index_t WarpTile = 32;
// Vector length for gfx120x; this is the value GfxConfig<1200>::get_vec_len() returns.
constexpr ck_tile::index_t VecLen = 8;
}
// Parameters for gfx90x (example values, adjust as needed)
namespace Gfx90x {
// Warp tile size used for gfx90x.
// NOTE(review): the unit (elements per tile edge?) is not visible here — confirm at the use sites.
constexpr ck_tile::index_t WarpTile = 64;
// Vector length for gfx90x; this is the value GfxConfig<900>::get_vec_len() returns.
constexpr ck_tile::index_t VecLen = 4;
}
// Generic Parameters - should never be used in this example
// templated run function should only be instantiated for Gfx120x and Gfx90x
namespace Generic {
// Placeholder value (-1) for the generic case, which per the note above is never
// meant to be used. Unlike Gfx120x/Gfx90x, this namespace defines no VecLen.
constexpr ck_tile::index_t WarpTile = -1;
}
// Compile-time lookup of architecture parameters keyed by a numeric GfxId.
// Supported ids are 1200 (parameters from namespace Gfx120x) and 900 (Gfx90x).
template <int GfxId>
struct GfxConfig
{
    // Returns the VecLen constant of the architecture namespace selected by GfxId.
    //
    // Any other GfxId is rejected with a readable diagnostic. Without the final
    // branch the function would fall off the end of a non-void function for
    // unsupported ids, which is not a valid constant expression.
    static constexpr int get_vec_len()
    {
        if constexpr(GfxId == 1200)
        {
            return Gfx120x::VecLen;
        }
        else if constexpr(GfxId == 900)
        {
            return Gfx90x::VecLen;
        }
        else
        {
            // The condition depends on GfxId, so it only fires when this branch is
            // actually instantiated, i.e. for an unsupported id.
            static_assert(GfxId == 1200 || GfxId == 900,
                          "GfxConfig::get_vec_len: unsupported GfxId (expected 1200 or 900)");
        }
    }
};
} // namespace ck_tile

View File

@@ -1,20 +0,0 @@
// ============================================================================
// Architecture-specific parameter definitions
// We'll define all parameters for all supported architectures here.
// ============================================================================
// Parameters for gfx120x (using a namespace for organization or just global constexpr)
namespace Gfx120x {
// Warp tile size used for gfx120x.
// NOTE(review): the unit is not visible here — confirm at the use sites.
constexpr ck_tile::index_t WarpTile = 32;
}
// Parameters for gfx90x (example values, adjust as needed)
namespace Gfx90x {
// Warp tile size used for gfx90x.
// NOTE(review): the unit is not visible here — confirm at the use sites.
constexpr ck_tile::index_t WarpTile = 64;
}
// Generic Parameters - should never be used in this example
// templated run function should only be instantiated for Gfx120x and Gfx90x
namespace Generic {
// Placeholder value (-1) for the generic case, which per the note above is never
// meant to be used.
constexpr ck_tile::index_t WarpTile = -1;
}

View File

@@ -51,8 +51,12 @@
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#if 0
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#endif
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"

View File

@@ -5,16 +5,22 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#if 0
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
#endif
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp"
namespace ck_tile {
// fp16
#if 0
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
@@ -23,7 +29,15 @@ using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
using WarpGemmWmmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAttributeWmma<WarpGemmAttributeWmmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#endif
// fp16 16x16x16 warp GEMM expressed through the generic attribute with GfxId = 900.
// The original alias name is kept; the implementation is now chosen by the GfxId argument.
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAttributeGeneric<WarpGemmAttributeGenericImplF16F16F32M16N16K16<900, WGAttrCtlEnum::Default_>>>;
// Same generic implementation instantiated with GfxId = 1200.
using WarpGemmWmmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAttributeGeneric<WarpGemmAttributeGenericImplF16F16F32M16N16K16<1200, WGAttrCtlEnum::Default_>>>;
#if 0
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
@@ -343,4 +357,6 @@ using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
#endif
} // namespace ck_tile

View File

@@ -0,0 +1,452 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp"
namespace ck_tile {
// Number of groups of consecutive elements to fill in a ABKLane
// The enumerator values are the numeric group counts themselves: the attribute
// structs in this file cast the enumerator to index_t (AttrNumAccessV) and use
// that number as a factor of the per-lane K extent in the A/B encodings.
enum class WGAttrNumAccessEnum
{
Single = 1,
Double = 2,
Quad = 4,
Invalid = -1
};
// Attribute wrapper around a single-block warp GEMM implementation.
//
// Re-exports the implementation's data/vector types and tile extents, builds the
// tile-distribution encodings for the A, B and C warp tensors from the
// implementation's lane constants, and forwards operator() to the implementation.
template <typename WarpGemmAttributeGenericImpl_,
          WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAttributeGeneric
{
    using Impl = remove_cvref_t<WarpGemmAttributeGenericImpl_>;

    // Only implementations with exactly one block along M and along N are accepted.
    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
                  "Multi-block WarpGemmAttributeGenericImpl is not supported");

    static constexpr auto AttrNumAccess  = AttrNumAccess_;
    static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);

    // Element types forwarded from the implementation.
    using ADataType = typename Impl::ADataType;
    using BDataType = typename Impl::BDataType;
    using CDataType = typename Impl::CDataType;

    // Per-lane vector types forwarded from the implementation.
    using AVecType = typename Impl::AVecType;
    using BVecType = typename Impl::BVecType;
    using CVecType = typename Impl::CVecType;

    // Tile extents and lane constants forwarded from the implementation.
    static constexpr index_t kM          = Impl::kM;
    static constexpr index_t kN          = Impl::kN;
    static constexpr index_t kK          = Impl::kK;
    static constexpr index_t kKPerThread = Impl::kABKPerLane;
    static constexpr index_t kCMLane     = Impl::kCMLane;

    // This wrapper always reports a single access.
    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

    // Builds the encoding shared by A and B; kMNLane is the lane count along M (for A)
    // or N (for B). When AttrNumAccessV is not 1, the per-lane K extent is split into
    // AttrNumAccessV groups.
    template <index_t kMNLane>
    static constexpr auto get_warp_dstr_encoding()
    {
        if constexpr(AttrNumAccessV != 1)
        {
            static_assert(kKPerThread % AttrNumAccessV == 0,
                          "kKPerThread must be divisible by NumAccess");

            using SplitKEncoding = tile_distribution_encoding<
                sequence<>,
                tuple<sequence<kMNLane>,
                      sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<1, 0>>,
                sequence<2, 2>,
                sequence<0, 2>>;
            return SplitKEncoding{};
        }
        else
        {
            using SingleAccessEncoding = tile_distribution_encoding<
                sequence<>,
                tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<0, 0>>,
                sequence<2>,
                sequence<1>>;
            return SingleAccessEncoding{};
        }
    }

    using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
    using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());

    using CWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
              sequence<Impl::kCNLane>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 0>>,
        sequence<1, 1>,
        sequence<0, 2>>;

    // Accumulating form: c_vec += a_vec * b_vec, delegated to the implementation.
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        const Impl warp_op{};
        warp_op(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
    }

    // Non-accumulating form: returns a_vec * b_vec, delegated to the implementation.
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
        const Impl warp_op{};
        return warp_op(a_vec, b_vec);
    }
};
// Attribute wrapper that applies the wrapped implementation kKIter times along K.
//
// The A/B vector types are kKIter times longer than the implementation's, and
// kK / kKPerThread are scaled by kKIter. Each operator() reinterprets the long
// vectors as kKIter implementation-sized slices and invokes the implementation
// on the slices. Multi-block implementations are accepted as long as only one of
// the M / N directions has more than one block.
template <typename WarpGemmAttributeGenericImpl_,
index_t kKIter,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAttributeGenericIterateK
{
static_assert(kKIter > 0, "wrong!");
using Impl = remove_cvref_t<WarpGemmAttributeGenericImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
// A/B vectors hold kKIter implementation-sized vectors back to back.
using AVecType =
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
using BVecType =
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
// The accumulator is unchanged: every K iteration accumulates into the same C vector.
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
// Reports one access per K iteration.
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
// At most one of the M / N directions may be multi-block.
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
"Multi-block on both M & N directions is not supported");
// Encoding for the A warp tensor. Three cases are handled: single block in both
// directions (optionally split into AttrNumAccessV groups), multi-block along N,
// and multi-block along M. The multi-block cases require AttrNumAccessV == 1.
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
if constexpr(AttrNumAccessV == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>,
sequence<AttrNumAccessV,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// each M blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kBNBlock>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
}
// Encoding for the B warp tensor; mirrors get_awarp_dstr_encoding() with the roles
// of the M and N directions exchanged.
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
if constexpr(AttrNumAccessV == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>,
sequence<AttrNumAccessV,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// each N blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kAMBlock>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
}
// Encoding for the C warp tensor. In the multi-block cases the block count is
// folded into the lane extent of the multi-block direction (kBNBlock * kCNLane
// or kAMBlock * kCMLane).
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kBNBlock * Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// c_vec += a_vec * b_vec
// Runs the implementation once per K iteration, accumulating into c_vec.
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// View the long A/B vectors as kKIter implementation-sized vectors.
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
});
}
// c_vec += a_vec * b_vec for one K iteration only; the slice index iKIter is a
// compile-time value and must be less than kKIter.
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
// The first K slice initialises the result through the non-accumulating form;
// the remaining slices are accumulated into it.
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
constexpr auto I0 = number<0>{};
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
// c = a * b
auto c_vec = Impl{}(
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
// c += a * b
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});
return c_vec;
}
};
// Attribute wrapper that presents a single-block implementation with A and B
// exchanged: the A-side types and lane counts come from the implementation's B side
// and vice versa, kM/kN are swapped, and operator() passes its operands to the
// implementation in swapped order. The C encoding is laid out with the N lanes first.
template <typename WarpGemmAttributeGenericImpl_,
          WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAttributeGenericTransposedCDistribution
{
    using Impl = remove_cvref_t<WarpGemmAttributeGenericImpl_>;

    // Only implementations with exactly one block along M and along N are accepted.
    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
                  "Multi-block WarpGemmAttributeGenericImpl is not supported");

    static constexpr auto AttrNumAccess  = AttrNumAccess_;
    static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);

    // A-side types are taken from the implementation's B side and vice versa.
    using ADataType = typename Impl::BDataType;
    using BDataType = typename Impl::ADataType;
    using CDataType = typename Impl::CDataType;

    using AVecType = typename Impl::BVecType;
    using BVecType = typename Impl::AVecType;
    using CVecType = typename Impl::CVecType;

    // M and N extents are exchanged; K is unchanged.
    static constexpr index_t kM          = Impl::kN;
    static constexpr index_t kN          = Impl::kM;
    static constexpr index_t kK          = Impl::kK;
    static constexpr index_t kKPerThread = Impl::kABKPerLane;

    // This wrapper always reports a single access.
    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

    // Builds the encoding shared by A and B for a given lane count. When
    // AttrNumAccessV is not 1, the per-lane K extent is split into AttrNumAccessV groups.
    template <index_t kMNLane>
    static constexpr auto get_warp_dstr_encoding()
    {
        if constexpr(AttrNumAccessV != 1)
        {
            static_assert(kKPerThread % AttrNumAccessV == 0,
                          "kKPerThread must be divisible by NumAccess");

            using SplitKEncoding = tile_distribution_encoding<
                sequence<>,
                tuple<sequence<kMNLane>,
                      sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<1, 0>>,
                sequence<2, 2>,
                sequence<0, 2>>;
            return SplitKEncoding{};
        }
        else
        {
            using SingleAccessEncoding = tile_distribution_encoding<
                sequence<>,
                tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<0, 0>>,
                sequence<2>,
                sequence<1>>;
            return SingleAccessEncoding{};
        }
    }

    // Lane counts are exchanged to match the swapped operands.
    using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
    using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());

    using CWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kCNLane>,
              sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
        tuple<sequence<2, 1>>,
        tuple<sequence<1, 0>>,
        sequence<2, 2>,
        sequence<0, 2>>;

    // Accumulating form: c_vec += a_vec * b_vec, computed by handing the operands
    // to the implementation in swapped order.
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        const Impl warp_op{};
        warp_op(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
    }

    // Non-accumulating form: returns the product, again with A and B swapped
    // on the way into the implementation.
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
        const Impl warp_op{};
        return warp_op(b_vec, a_vec);
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,132 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// TODO: refactor warp-gemm
// currently there is a discrepancy for vav/vva if we need transpose C/D
// e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrEnum
// because we swap the A/B pointer in _impl code (but not known this info here)
// Selects how the warp GEMM implementation issues its matrix instruction.
// The Raw_* values pick an inline-asm path whose register constraints are listed
// per enumerator below (see DISPATCH_MFMA_CTRL_). Default_ matches none of the
// inline-asm branches, so the implementation's own fallback path is used.
enum class WGAttrCtlEnum
{
Default_ = 0,
Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr
Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
// Emits one inline-asm instruction "<mfma_> %0, %1, %2, %3".
//   %0 : c_vec as output, with constraint dmod_
//   %1 : a_vec, with constraint amod_
//   %2 : b_vec, with constraint bmod_
//   %3 : c_vec as accumulator input, with constraint cmod_
// When post_nop_ is true, an "s_nop 3" is appended after the instruction.
// The expansion site must have post_nop_, c_vec, a_vec and b_vec in scope.
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
"s_nop 3" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
} \
else \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
}
// Chooses the register constraints for DISPATCH_MFMA_ from a WGAttrCtlEnum value.
// Each Raw_* value maps to one (d, a, b, c) constraint tuple. There is no branch
// for WGAttrCtlEnum::Default_: the expansion ends in an open if / else-if chain,
// so the caller may follow the macro with an `else { ... }` block for that case.
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
{ \
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
}
template <int GfxId, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeGenericImplF16F16F32M16N16K16
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
static constexpr int VecLen = GfxConfig<GfxId>::get_vec_len();
using AVecType = ext_vector_t<fp16_t, VecLen>;
using BVecType = ext_vector_t<fp16_t, VecLen>;
using CVecType = ext_vector_t<float, VecLen>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx12__)
c_vec = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, c_vec);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#elif defined(__gfx12__)
return bit_cast<CVecType>(
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, fp32x8_t{0.f}));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};

View File

@@ -334,8 +334,8 @@ struct WarpGemmAttributeWmmaTransposedCDistribution
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
sequence<2, 4>,
sequence<0, 4>>;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>