mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
wip - make a start on defining warp level classes which are templated on GfxId
This commit is contained in:
@@ -234,4 +234,44 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
|
||||
#endif
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Architecture-specific parameter definitions
|
||||
// We'll define all parameters for all supported architectures here.
|
||||
// ============================================================================
|
||||
|
||||
// Parameters for gfx120x (using a namespace for organization or just global constexpr)
|
||||
namespace Gfx120x {
|
||||
constexpr ck_tile::index_t WarpTile = 32;
|
||||
constexpr ck_tile::index_t VecLen = 8;
|
||||
}
|
||||
|
||||
// Parameters for gfx90x (example values, adjust as needed)
|
||||
namespace Gfx90x {
|
||||
constexpr ck_tile::index_t WarpTile = 64;
|
||||
constexpr ck_tile::index_t VecLen = 4;
|
||||
}
|
||||
|
||||
// Generic Parameters - should never be used in this example
|
||||
// templated run function should only be instantiated for Gfx120x and Gfx90x
|
||||
namespace Generic {
|
||||
constexpr ck_tile::index_t WarpTile = -1;
|
||||
}
|
||||
|
||||
// Helper to get VecLen based on GfxId
|
||||
template <int GfxId>
|
||||
struct GfxConfig
|
||||
{
|
||||
static constexpr int get_vec_len()
|
||||
{
|
||||
if constexpr (GfxId == 1200)
|
||||
{
|
||||
return Gfx120x::VecLen;
|
||||
}
|
||||
else if constexpr (GfxId == 900)
|
||||
{
|
||||
return Gfx90x::VecLen;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
// ============================================================================
|
||||
// Architecture-specific parameter definitions
|
||||
// We'll define all parameters for all supported architectures here.
|
||||
// ============================================================================
|
||||
|
||||
// Parameters for gfx120x (using a namespace for organization or just global constexpr)
|
||||
namespace Gfx120x {
|
||||
constexpr ck_tile::index_t WarpTile = 32;
|
||||
}
|
||||
|
||||
// Parameters for gfx90x (example values, adjust as needed)
|
||||
namespace Gfx90x {
|
||||
constexpr ck_tile::index_t WarpTile = 64;
|
||||
}
|
||||
|
||||
// Generic Parameters - should never be used in this example
|
||||
// templated run function should only be instantiated for Gfx120x and Gfx90x
|
||||
namespace Generic {
|
||||
constexpr ck_tile::index_t WarpTile = -1;
|
||||
}
|
||||
@@ -51,8 +51,12 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#if 0
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
#endif
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
@@ -5,16 +5,22 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
|
||||
#if 0
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
|
||||
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
|
||||
#endif
|
||||
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp"
|
||||
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// fp16
|
||||
|
||||
#if 0
|
||||
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
@@ -23,7 +29,15 @@ using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
|
||||
|
||||
using WarpGemmWmmaF16F16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAttributeWmma<WarpGemmAttributeWmmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
#endif
|
||||
|
||||
// fp16 x fp16 -> fp32, 16x16x16 warp GEMM, expressed through the GfxId-templated
// generic attribute. The "Mfma" alias selects GfxId 900.
using WarpGemmMfmaF16F16F32M16N16K16 =
    WarpGemmImpl<WarpGemmAttributeGeneric<
        WarpGemmAttributeGenericImplF16F16F32M16N16K16<900, WGAttrCtlEnum::Default_>>>;

// Same shape and data types; the "Wmma" alias selects GfxId 1200.
using WarpGemmWmmaF16F16F32M16N16K16 =
    WarpGemmImpl<WarpGemmAttributeGeneric<
        WarpGemmAttributeGenericImplF16F16F32M16N16K16<1200, WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if 0
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
|
||||
@@ -343,4 +357,6 @@ using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
452
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp
Normal file
452
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp
Normal file
@@ -0,0 +1,452 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// How many groups of consecutive elements are used to fill one ABKLane.
// The numeric value of each valid enumerator is the group count itself;
// Invalid is a negative sentinel.
enum class WGAttrNumAccessEnum
{
    Invalid = -1,
    Single  = 1,
    Double  = 2,
    Quad    = 4
};
|
||||
|
||||
template <typename WarpGemmAttributeGenericImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAttributeGeneric
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeGenericImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::AVecType;
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
static constexpr index_t kCMLane = Impl::kCMLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeGenericImpl is not supported");
|
||||
|
||||
template <index_t kMNLane>
|
||||
static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return Impl{}(a_vec, b_vec);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeGenericImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAttributeGenericIterateK
|
||||
{
|
||||
static_assert(kKIter > 0, "wrong!");
|
||||
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeGenericImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
|
||||
"Multi-block on both M & N directions is not supported");
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<AttrNumAccessV,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// each M blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kBNBlock>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<AttrNumAccessV,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// each N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kAMBlock>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kBNBlock * Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<
|
||||
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
|
||||
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
|
||||
|
||||
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
|
||||
|
||||
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t iKIter, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iKIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_assert(iKIter < kKIter);
|
||||
|
||||
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
//});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
// c = a * b
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
// c += a * b
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeGenericImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAttributeGenericTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeGenericImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::BVecType;
|
||||
using BVecType = typename Impl::AVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeGenericImpl is not supported");
|
||||
|
||||
template <index_t kMNLane>
|
||||
static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
// swap A and B
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
// swap A and B
|
||||
return Impl{}(b_vec, a_vec);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,132 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: refactor warp-gemm
// There is currently a discrepancy for vav/vva when C/D has to be transposed:
// e.g. to get A in agpr and B in vgpr, vva must be selected here, because the
// _impl code swaps the A/B pointers (and that information is not known here).
//
// Register-class control for the warp-GEMM instruction. The three letters of each
// Raw_* name give the register class of C, A and B, in that order.
enum class WGAttrCtlEnum
{
    Default_ = 0,
    Raw_vvv  = 1, // C: vgpr, A: vgpr, B: vgpr
    Raw_vaa  = 2, // C: vgpr, A: agpr, B: agpr
    Raw_vav  = 3, // C: vgpr, A: agpr, B: vgpr
    Raw_vva  = 4, // C: vgpr, A: vgpr, B: agpr
    Raw_avv  = 5, // C: agpr, A: vgpr, B: vgpr
    // raw_a_a_a = 3, // C: agpr, A: agpr, B: agpr
};
|
||||
|
||||
// Emits one instruction via inline assembly.
//
//   mfma_                     : instruction mnemonic as a string literal
//   dmod_/amod_/bmod_/cmod_   : asm operand constraint strings for the destination,
//                               A, B and accumulator-input operands
//
// The expansion refers to the names post_nop_, c_vec, a_vec and b_vec, which must be
// in scope at the point of use. When post_nop_ is true an "s_nop 3" is appended
// after the instruction.
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)       \
    if constexpr(!post_nop_)                                    \
    {                                                           \
        asm volatile(mfma_ " %0, %1, %2, %3\n"                  \
                     : dmod_(c_vec)                             \
                     : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
                     :);                                        \
    }                                                           \
    else                                                        \
    {                                                           \
        asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n"            \
                     "s_nop 3"                                  \
                     : dmod_(c_vec)                             \
                     : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
                     :);                                        \
    }
|
||||
|
||||
// Selects the inline-asm operand constraints for a WGAttrCtlEnum::Raw_* control value
// and emits the instruction through DISPATCH_MFMA_.
//
// The expansion is an if-constexpr chain with NO final else: the use site is expected
// to append its own `else { ... }` to handle every value not listed here
// (e.g. WGAttrCtlEnum::Default_).
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)                                                        \
    if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) { DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") } \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa)                                           \
    {                                                                                            \
        DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v")                                               \
    }                                                                                            \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav)                                           \
    {                                                                                            \
        DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v")                                               \
    }                                                                                            \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva)                                           \
    {                                                                                            \
        DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v")                                               \
    }                                                                                            \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv)                                           \
    {                                                                                            \
        DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a")                                               \
    }
|
||||
|
||||
|
||||
template <int GfxId, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeGenericImplF16F16F32M16N16K16
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
|
||||
static constexpr int VecLen = GfxConfig<GfxId>::get_vec_len();
|
||||
using AVecType = ext_vector_t<fp16_t, VecLen>;
|
||||
using BVecType = ext_vector_t<fp16_t, VecLen>;
|
||||
using CVecType = ext_vector_t<float, VecLen>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx12__)
|
||||
c_vec = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx12__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, fp32x8_t{0.f}));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -334,8 +334,8 @@ struct WarpGemmAttributeWmmaTransposedCDistribution
|
||||
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
sequence<2, 4>,
|
||||
sequence<0, 4>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
Reference in New Issue
Block a user