mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding
This commit is contained in:
@@ -139,6 +139,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
template <bool convert = false>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
@@ -158,12 +159,18 @@ struct BlockUniversalGemmAsBsCr
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
using Attr = typename WarpGemm::WarpGemmAttribute;
|
||||
constexpr auto NumAccessA =
|
||||
convert ? Attr::AttrNumAccessV * sizeof(ADataType) / sizeof(ComputeDataType)
|
||||
: Attr::AttrNumAccessV;
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
a_block_outer_dstr_encoding,
|
||||
WarpGemm::WarpGemmAttribute::template get_awarp_dstr_encoding<NumAccessA>());
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
|
||||
template <bool convert = false>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
@@ -183,8 +190,13 @@ struct BlockUniversalGemmAsBsCr
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
using Attr = typename WarpGemm::WarpGemmAttribute;
|
||||
constexpr auto NumAccessB =
|
||||
convert ? Attr::AttrNumAccessV * sizeof(BDataType) / sizeof(ComputeDataType)
|
||||
: Attr::AttrNumAccessV;
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
b_block_outer_dstr_encoding,
|
||||
WarpGemm::WarpGemmAttribute::template get_bwarp_dstr_encoding<NumAccessB>());
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
@@ -440,10 +440,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
constexpr bool is_load_tr = is_a_load_tr_v || is_b_load_tr_v;
|
||||
constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(
|
||||
BlockGemm::template MakeABlockDistributionEncode<is_load_tr>());
|
||||
constexpr auto b_lds_load_tile_distr = make_static_tile_distribution(
|
||||
BlockGemm::template MakeBBlockDistributionEncode<is_load_tr>());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
|
||||
@@ -44,12 +44,12 @@ struct WarpGemmAttributeMfma
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
template <index_t kMNLane>
|
||||
template <index_t kMNLane, index_t NumAccess>
|
||||
static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
static_assert(NumAccess != 0 && kKPerThread % NumAccess == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
if constexpr(NumAccess == 1)
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
@@ -61,14 +61,30 @@ struct WarpGemmAttributeMfma
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
|
||||
sequence<NumAccess, Impl::kABKLane, Impl::kABKPerLane / NumAccess>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
|
||||
|
||||
// Builds the A-side warp distribution encoding.
// Forwards to get_warp_dstr_encoding with Impl::kAMLane as kMNLane.
// NumAccess defaults to AttrNumAccessV, but a caller may request another
// value; get_warp_dstr_encoding static_asserts that it is non-zero and
// divides kKPerThread.
template <index_t NumAccess = AttrNumAccessV>
|
||||
static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
return get_warp_dstr_encoding<Impl::kAMLane, NumAccess>();
|
||||
}
|
||||
|
||||
// Builds the B-side warp distribution encoding.
// Same as the A-side helper but passes Impl::kBNLane as kMNLane.
// NumAccess defaults to AttrNumAccessV; get_warp_dstr_encoding static_asserts
// that it is non-zero and divides kKPerThread.
template <index_t NumAccess = AttrNumAccessV>
|
||||
static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
return get_warp_dstr_encoding<Impl::kBNLane, NumAccess>();
|
||||
}
|
||||
|
||||
// Type of the A-side warp encoding for a given NumAccess.
// This is an alias template: users must write AWarpDstrEncoding<> (or pass an
// explicit NumAccess) rather than the bare name.
template <index_t NumAccess = AttrNumAccessV>
|
||||
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding<NumAccess>());
|
||||
|
||||
// Type of the B-side warp encoding for a given NumAccess.
// This is an alias template: users must write BWarpDstrEncoding<> (or pass an
// explicit NumAccess) rather than the bare name.
template <index_t NumAccess = AttrNumAccessV>
|
||||
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding<NumAccess>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -151,14 +167,14 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
|
||||
"Multi-block on both M & N directions is not supported");
|
||||
|
||||
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock>
|
||||
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock, index_t NumAccess = AttrNumAccessV>
|
||||
CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
if constexpr(kMNBlock == 1 && kNMBlock == 1)
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
if constexpr(NumAccess == 1)
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
@@ -170,9 +186,9 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV,
|
||||
sequence<NumAccess,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
|
||||
Impl::kABKPerLane * kKIter / NumAccess>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
@@ -180,7 +196,7 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
}
|
||||
else if constexpr(kMNBlock == 1 && 1 < kNMBlock)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
static_assert(NumAccess == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// each M/N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
@@ -193,7 +209,7 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
}
|
||||
else if constexpr(1 < kMNBlock && kNMBlock == 1)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
static_assert(NumAccess == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
@@ -207,6 +223,18 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the A-side warp distribution encoding for the iterate-K attribute.
// Forwards to get_warp_dstr_encoding with (kMNLane, kMNBlock, kNMBlock) =
// (Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock) and the requested NumAccess.
// NOTE(review): in the single-block branch of get_warp_dstr_encoding the
// divisibility static_assert appears to still test AttrNumAccessV rather than
// NumAccess - confirm that a non-default NumAccess is actually validated.
template <index_t NumAccess = AttrNumAccessV>
|
||||
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
return get_warp_dstr_encoding<Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock, NumAccess>();
|
||||
}
|
||||
|
||||
// Builds the B-side warp distribution encoding for the iterate-K attribute.
// Mirrors the A-side helper with the block arguments swapped:
// (kMNLane, kMNBlock, kNMBlock) = (Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock).
// NOTE(review): in the single-block branch of get_warp_dstr_encoding the
// divisibility static_assert appears to still test AttrNumAccessV rather than
// NumAccess - confirm that a non-default NumAccess is actually validated.
template <index_t NumAccess = AttrNumAccessV>
|
||||
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
return get_warp_dstr_encoding<Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock, NumAccess>();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
@@ -245,10 +273,12 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
}
|
||||
}
|
||||
|
||||
using AWarpDstrEncoding =
|
||||
decltype(get_warp_dstr_encoding<Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock>());
|
||||
using BWarpDstrEncoding =
|
||||
decltype(get_warp_dstr_encoding<Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock>());
|
||||
// Type of the A-side warp encoding for a given NumAccess (alias template;
// AWarpDstrEncoding<> selects the AttrNumAccessV default).
template <index_t NumAccess = AttrNumAccessV>
|
||||
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding<NumAccess>());
|
||||
|
||||
// Type of the B-side warp encoding for a given NumAccess (alias template;
// BWarpDstrEncoding<> selects the AttrNumAccessV default).
template <index_t NumAccess = AttrNumAccessV>
|
||||
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding<NumAccess>());
|
||||
|
||||
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
@@ -327,10 +357,23 @@ struct WarpGemmAttributeMfmaTransposedCDistribution
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding =
|
||||
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::BWarpDstrEncoding;
|
||||
using BWarpDstrEncoding =
|
||||
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::AWarpDstrEncoding;
|
||||
// A-side warp encoding for the transposed-C attribute.
// A and B are swapped here: this delegates to the *B*-side helper of
// WarpGemmAttributeMfma<Impl, AttrNumAccess>, forwarding NumAccess unchanged.
template <index_t NumAccess = AttrNumAccessV>
|
||||
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
return WarpGemmAttributeMfma<Impl, AttrNumAccess>::template get_bwarp_dstr_encoding<NumAccess>();
|
||||
}
|
||||
|
||||
// B-side warp encoding for the transposed-C attribute.
// A and B are swapped here: this delegates to the *A*-side helper of
// WarpGemmAttributeMfma<Impl, AttrNumAccess>, forwarding NumAccess unchanged.
template <index_t NumAccess = AttrNumAccessV>
|
||||
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
return WarpGemmAttributeMfma<Impl, AttrNumAccess>::template get_awarp_dstr_encoding<NumAccess>();
|
||||
}
|
||||
|
||||
// Type returned by this struct's get_awarp_dstr_encoding<NumAccess>()
// (alias template; AWarpDstrEncoding<> selects the AttrNumAccessV default).
template <index_t NumAccess = AttrNumAccessV>
|
||||
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding<NumAccess>());
|
||||
|
||||
// Type returned by this struct's get_bwarp_dstr_encoding<NumAccess>()
// (alias template; BWarpDstrEncoding<> selects the AttrNumAccessV default).
template <index_t NumAccess = AttrNumAccessV>
|
||||
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding<NumAccess>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -384,6 +427,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
template <index_t NumAccess = 1>
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
@@ -392,6 +436,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
#if 0
|
||||
template <index_t NumAccess = 1>
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
|
||||
@@ -414,6 +459,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
|
||||
sequence<0, 2>>;
|
||||
#else
|
||||
// TODO: more test not only 32x32
|
||||
template <index_t NumAccess = 1>
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
|
||||
@@ -459,8 +505,9 @@ template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
// swap A and B
|
||||
using ADataType = typename Impl::BDataType;
|
||||
@@ -521,10 +568,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
|
||||
}
|
||||
}
|
||||
|
||||
using AWarpDstrEncoding =
|
||||
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::BWarpDstrEncoding;
|
||||
using BWarpDstrEncoding =
|
||||
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::AWarpDstrEncoding;
|
||||
// A-side warp encoding type for the iterate-K + transposed-C attribute.
// A and B are swapped: this names the *B*-side alias template of
// WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>. The `template`
// keyword is needed because that alias is a dependent member template.
template <index_t NumAccess = AttrNumAccessV>
|
||||
using AWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
template BWarpDstrEncoding<NumAccess>;
|
||||
// B-side warp encoding type for the iterate-K + transposed-C attribute.
// A and B are swapped: this names the *A*-side alias template of
// WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>. The `template`
// keyword is needed because that alias is a dependent member template.
template <index_t NumAccess = AttrNumAccessV>
|
||||
using BWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
template AWarpDstrEncoding<NumAccess>;
|
||||
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
@@ -603,6 +652,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
template <index_t NumAccess = 1>
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
@@ -611,6 +661,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
#if 0
|
||||
template <index_t NumAccess = 1>
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
|
||||
@@ -633,6 +684,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
sequence<0, 2>>;
|
||||
#else
|
||||
// TODO: more test not only 32x32
|
||||
template <index_t NumAccess = 1>
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
|
||||
@@ -729,6 +781,7 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
template <index_t NumAccess = 1>
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
|
||||
@@ -741,6 +794,7 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
template <index_t NumAccess = 1>
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
|
||||
@@ -67,6 +67,10 @@ template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
|
||||
struct WarpGemmAttributeWmma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
|
||||
// AttrNumAccessV is required for compatibility with the block GEMM, and is currently ignored
|
||||
// within WarpGemmAttributeWmma
|
||||
static constexpr auto AttrNumAccess = WGAttrNumAccessEnum::Single;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
// When kTransC is true and A/B types differ, we need an impl with swapped types
|
||||
using TransposedImpl =
|
||||
@@ -99,8 +103,22 @@ struct WarpGemmAttributeWmma
|
||||
|
||||
// 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
|
||||
// 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
|
||||
using AWarpDstrEncoding = typename AWarpDstrEncodingTrait<Impl>::type;
|
||||
using BWarpDstrEncoding = typename BWarpDstrEncodingTrait<Impl>::type;
|
||||
// A-side warp encoding for the WMMA attribute.
// NumAccess exists only so the signature matches the other warp-GEMM
// attributes; it is not used - the result always comes from
// AWarpDstrEncodingTrait<Impl>.
template <index_t NumAccess = AttrNumAccessV>
|
||||
static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
return typename AWarpDstrEncodingTrait<Impl>::type{};
|
||||
}
|
||||
|
||||
// B-side warp encoding for the WMMA attribute.
// NumAccess exists only so the signature matches the other warp-GEMM
// attributes; it is not used - the result always comes from
// BWarpDstrEncodingTrait<Impl>.
template <index_t NumAccess = AttrNumAccessV>
|
||||
static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
return typename BWarpDstrEncodingTrait<Impl>::type{};
|
||||
}
|
||||
|
||||
// Type returned by get_awarp_dstr_encoding<NumAccess>(); since that function
// ignores NumAccess, every instantiation names the same type.
// NOTE(review): the default here is the literal 1 while the function defaults
// to AttrNumAccessV - presumably equal (AttrNumAccess is
// WGAttrNumAccessEnum::Single); confirm and consider unifying.
template <index_t NumAccess = 1>
|
||||
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding<NumAccess>());
|
||||
// Type returned by get_bwarp_dstr_encoding<NumAccess>(); since that function
// ignores NumAccess, every instantiation names the same type.
// NOTE(review): the default here is the literal 1 while the function defaults
// to AttrNumAccessV - presumably equal (AttrNumAccess is
// WGAttrNumAccessEnum::Single); confirm and consider unifying.
template <index_t NumAccess = 1>
|
||||
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding<NumAccess>());
|
||||
|
||||
// kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
|
||||
using CWarpDstrEncoding =
|
||||
|
||||
@@ -25,8 +25,8 @@ struct WarpGemmImpl
|
||||
using BDataType = typename WarpGemmAttribute::BDataType;
|
||||
using CDataType = typename WarpGemmAttribute::CDataType;
|
||||
|
||||
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
|
||||
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
|
||||
// The attribute exposes AWarpDstrEncoding as an alias template; the empty
// argument list `<>` picks the attribute's own default NumAccess.
using AWarpDstrEncoding = typename WarpGemmAttribute::template AWarpDstrEncoding<>;
|
||||
// The attribute exposes BWarpDstrEncoding as an alias template; the empty
// argument list `<>` picks the attribute's own default NumAccess.
using BWarpDstrEncoding = typename WarpGemmAttribute::template BWarpDstrEncoding<>;
|
||||
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
|
||||
|
||||
using AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>;
|
||||
|
||||
Reference in New Issue
Block a user