Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding

This commit is contained in:
Sami Aario
2025-11-28 09:20:19 +00:00
parent 933e09f6c3
commit f2fcc4a461
5 changed files with 121 additions and 36 deletions

View File

@@ -139,6 +139,7 @@ struct BlockUniversalGemmAsBsCr
using I0 = number<0>;
using I1 = number<1>;
template <bool convert = false>
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
@@ -158,12 +159,18 @@ struct BlockUniversalGemmAsBsCr
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
using Attr = typename WarpGemm::WarpGemmAttribute;
constexpr auto NumAccessA =
convert ? Attr::AttrNumAccessV * sizeof(ADataType) / sizeof(ComputeDataType)
: Attr::AttrNumAccessV;
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
a_block_outer_dstr_encoding,
WarpGemm::WarpGemmAttribute::template get_awarp_dstr_encoding<NumAccessA>());
return a_block_dstr_encode;
}
template <bool convert = false>
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
@@ -183,8 +190,13 @@ struct BlockUniversalGemmAsBsCr
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
using Attr = typename WarpGemm::WarpGemmAttribute;
constexpr auto NumAccessB =
convert ? Attr::AttrNumAccessV * sizeof(BDataType) / sizeof(ComputeDataType)
: Attr::AttrNumAccessV;
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
b_block_outer_dstr_encoding,
WarpGemm::WarpGemmAttribute::template get_bwarp_dstr_encoding<NumAccessB>());
return b_block_dstr_encode;
}

View File

@@ -440,10 +440,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
constexpr bool is_load_tr = is_a_load_tr_v || is_b_load_tr_v;
constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(
BlockGemm::template MakeABlockDistributionEncode<is_load_tr>());
constexpr auto b_lds_load_tile_distr = make_static_tile_distribution(
BlockGemm::template MakeBBlockDistributionEncode<is_load_tr>());
// A DRAM tile window for load
// A LDS tile window for store

View File

@@ -44,12 +44,12 @@ struct WarpGemmAttributeMfma
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
template <index_t kMNLane>
template <index_t kMNLane, index_t NumAccess>
static constexpr auto get_warp_dstr_encoding()
{
static_assert(kKPerThread % AttrNumAccessV == 0,
static_assert(NumAccess != 0 && kKPerThread % NumAccess == 0,
"kKPerThread must be divisible by NumAccess");
if constexpr(AttrNumAccessV == 1)
if constexpr(NumAccess == 1)
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -61,14 +61,30 @@ struct WarpGemmAttributeMfma
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
sequence<NumAccess, Impl::kABKLane, Impl::kABKPerLane / NumAccess>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
// A-side warp distribution encoding: forwards to get_warp_dstr_encoding using
// Impl::kAMLane as the lane count. NumAccess defaults to AttrNumAccessV and may be
// overridden by the caller.
template <index_t NumAccess = AttrNumAccessV>
static constexpr auto get_awarp_dstr_encoding()
{
    constexpr auto a_warp_dstr_encoding = get_warp_dstr_encoding<Impl::kAMLane, NumAccess>();
    return a_warp_dstr_encoding;
}
// B-side warp distribution encoding: forwards to get_warp_dstr_encoding using
// Impl::kBNLane as the lane count. NumAccess defaults to AttrNumAccessV and may be
// overridden by the caller.
template <index_t NumAccess = AttrNumAccessV>
static constexpr auto get_bwarp_dstr_encoding()
{
    constexpr auto b_warp_dstr_encoding = get_warp_dstr_encoding<Impl::kBNLane, NumAccess>();
    return b_warp_dstr_encoding;
}
// Type of the A-side warp distribution encoding for a given NumAccess; with no
// template argument (AWarpDstrEncoding<>) the attribute's own AttrNumAccessV is used.
template <index_t NumAccess = AttrNumAccessV>
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding<NumAccess>());
// Type of the B-side warp distribution encoding for a given NumAccess; with no
// template argument (BWarpDstrEncoding<>) the attribute's own AttrNumAccessV is used.
template <index_t NumAccess = AttrNumAccessV>
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding<NumAccess>());
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
@@ -151,14 +167,14 @@ struct WarpGemmAttributeMfmaIterateK
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
"Multi-block on both M & N directions is not supported");
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock>
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock, index_t NumAccess = AttrNumAccessV>
CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding()
{
if constexpr(kMNBlock == 1 && kNMBlock == 1)
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
if constexpr(AttrNumAccessV == 1)
if constexpr(NumAccess == 1)
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
@@ -170,9 +186,9 @@ struct WarpGemmAttributeMfmaIterateK
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV,
sequence<NumAccess,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
Impl::kABKPerLane * kKIter / NumAccess>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
@@ -180,7 +196,7 @@ struct WarpGemmAttributeMfmaIterateK
}
else if constexpr(kMNBlock == 1 && 1 < kNMBlock)
{
static_assert(AttrNumAccessV == 1,
static_assert(NumAccess == 1,
"Multiple access is not supported when using multi-block");
// each M/N blocks share the same data
return tile_distribution_encoding<
@@ -193,7 +209,7 @@ struct WarpGemmAttributeMfmaIterateK
}
else if constexpr(1 < kMNBlock && kNMBlock == 1)
{
static_assert(AttrNumAccessV == 1,
static_assert(NumAccess == 1,
"Multiple access is not supported when using multi-block");
// single block to multi-block thread mapping
return tile_distribution_encoding<
@@ -207,6 +223,18 @@ struct WarpGemmAttributeMfmaIterateK
}
}
// A-side warp distribution encoding: forwards to get_warp_dstr_encoding with the
// A lane/block counts first (kAMLane, kAMBlock) and the B block count (kBNBlock) last.
// NumAccess defaults to AttrNumAccessV and may be overridden by the caller.
template <index_t NumAccess = AttrNumAccessV>
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
{
    constexpr auto a_warp_dstr_encoding =
        get_warp_dstr_encoding<Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock, NumAccess>();
    return a_warp_dstr_encoding;
}
// B-side warp distribution encoding: forwards to get_warp_dstr_encoding with the
// B lane/block counts first (kBNLane, kBNBlock) and the A block count (kAMBlock) last.
// NumAccess defaults to AttrNumAccessV and may be overridden by the caller.
template <index_t NumAccess = AttrNumAccessV>
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
{
    constexpr auto b_warp_dstr_encoding =
        get_warp_dstr_encoding<Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock, NumAccess>();
    return b_warp_dstr_encoding;
}
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
@@ -245,10 +273,12 @@ struct WarpGemmAttributeMfmaIterateK
}
}
using AWarpDstrEncoding =
decltype(get_warp_dstr_encoding<Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock>());
using BWarpDstrEncoding =
decltype(get_warp_dstr_encoding<Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock>());
// Type of the A-side warp distribution encoding for a given NumAccess; with no
// template argument (AWarpDstrEncoding<>) the attribute's own AttrNumAccessV is used.
template <index_t NumAccess = AttrNumAccessV>
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding<NumAccess>());
// Type of the B-side warp distribution encoding for a given NumAccess; with no
// template argument (BWarpDstrEncoding<>) the attribute's own AttrNumAccessV is used.
template <index_t NumAccess = AttrNumAccessV>
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding<NumAccess>());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// c_vec += a_vec * b_vec
@@ -327,10 +357,23 @@ struct WarpGemmAttributeMfmaTransposedCDistribution
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding =
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::BWarpDstrEncoding;
using BWarpDstrEncoding =
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::AWarpDstrEncoding;
// A-side encoding of this attribute: taken from the B-side getter of
// WarpGemmAttributeMfma<Impl, AttrNumAccess>, i.e. the two operands are exchanged.
template <index_t NumAccess = AttrNumAccessV>
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
{
    using UnderlyingAttr = WarpGemmAttributeMfma<Impl, AttrNumAccess>;
    return UnderlyingAttr::template get_bwarp_dstr_encoding<NumAccess>();
}
// B-side encoding of this attribute: taken from the A-side getter of
// WarpGemmAttributeMfma<Impl, AttrNumAccess>, i.e. the two operands are exchanged.
template <index_t NumAccess = AttrNumAccessV>
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
{
    using UnderlyingAttr = WarpGemmAttributeMfma<Impl, AttrNumAccess>;
    return UnderlyingAttr::template get_awarp_dstr_encoding<NumAccess>();
}
// Type of the A-side warp distribution encoding for a given NumAccess; with no
// template argument (AWarpDstrEncoding<>) the attribute's own AttrNumAccessV is used.
template <index_t NumAccess = AttrNumAccessV>
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding<NumAccess>());
// Type of the B-side warp distribution encoding for a given NumAccess; with no
// template argument (BWarpDstrEncoding<>) the attribute's own AttrNumAccessV is used.
template <index_t NumAccess = AttrNumAccessV>
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding<NumAccess>());
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
@@ -384,6 +427,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
template <index_t NumAccess = 1>
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -392,6 +436,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
sequence<2>,
sequence<1>>;
#if 0
template <index_t NumAccess = 1>
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
@@ -414,6 +459,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
sequence<0, 2>>;
#else
// TODO: add more tests, not only for 32x32
template <index_t NumAccess = 1>
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
@@ -459,8 +505,9 @@ template <typename WarpGemmAttributeMfmaImpl_,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
// swap A and B
using ADataType = typename Impl::BDataType;
@@ -521,10 +568,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
}
}
using AWarpDstrEncoding =
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::BWarpDstrEncoding;
using BWarpDstrEncoding =
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::AWarpDstrEncoding;
template <index_t NumAccess = AttrNumAccessV>
using AWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
template BWarpDstrEncoding<NumAccess>;
template <index_t NumAccess = AttrNumAccessV>
using BWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
template AWarpDstrEncoding<NumAccess>;
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// c_vec += a_vec * b_vec
@@ -603,6 +652,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
template <index_t NumAccess = 1>
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
@@ -611,6 +661,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
sequence<2>,
sequence<1>>;
#if 0
template <index_t NumAccess = 1>
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
@@ -633,6 +684,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
sequence<0, 2>>;
#else
// TODO: add more tests, not only for 32x32
template <index_t NumAccess = 1>
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
@@ -729,6 +781,7 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
template <index_t NumAccess = 1>
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
@@ -741,6 +794,7 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA
sequence<2>,
sequence<1>>;
template <index_t NumAccess = 1>
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,

View File

@@ -67,6 +67,10 @@ template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
struct WarpGemmAttributeWmma
{
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
// AttrNumAccessV is required for compatibility with the block GEMM, and is currently ignored
// within WarpGemmAttributeWmma
static constexpr auto AttrNumAccess = WGAttrNumAccessEnum::Single;
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
// When kTransC is true and A/B types differ, we need an impl with swapped types
using TransposedImpl =
@@ -99,8 +103,22 @@ struct WarpGemmAttributeWmma
// 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
// 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
using AWarpDstrEncoding = typename AWarpDstrEncodingTrait<Impl>::type;
using BWarpDstrEncoding = typename BWarpDstrEncodingTrait<Impl>::type;
template <index_t NumAccess = AttrNumAccessV>
static constexpr auto get_awarp_dstr_encoding()
{
return typename AWarpDstrEncodingTrait<Impl>::type{};
}
template <index_t NumAccess = AttrNumAccessV>
static constexpr auto get_bwarp_dstr_encoding()
{
return typename BWarpDstrEncodingTrait<Impl>::type{};
}
template <index_t NumAccess = 1>
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding<NumAccess>());
template <index_t NumAccess = 1>
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding<NumAccess>());
// kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
using CWarpDstrEncoding =

View File

@@ -25,8 +25,8 @@ struct WarpGemmImpl
using BDataType = typename WarpGemmAttribute::BDataType;
using CDataType = typename WarpGemmAttribute::CDataType;
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
using AWarpDstrEncoding = typename WarpGemmAttribute::template AWarpDstrEncoding<>;
using BWarpDstrEncoding = typename WarpGemmAttribute::template BWarpDstrEncoding<>;
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
using AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>;