[CK_TILE] Refine warp_gemm_attribute_mfma (#3272)

This commit is contained in:
Yi DING
2025-11-26 10:57:15 +08:00
committed by GitHub
parent c7dce2ac29
commit 8fa90025d0

View File

@@ -47,8 +47,9 @@ struct WarpGemmAttributeMfma
template <index_t kMNLane>
static constexpr auto get_warp_dstr_encoding()
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
if constexpr(AttrNumAccessV == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -56,11 +57,7 @@ struct WarpGemmAttributeMfma
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
@@ -69,7 +66,6 @@ struct WarpGemmAttributeMfma
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
@@ -155,28 +151,25 @@ struct WarpGemmAttributeMfmaIterateK
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
"Multi-block on both M & N directions is not supported");
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock>
CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
if constexpr(kMNBlock == 1 && kNMBlock == 1)
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
if constexpr(AttrNumAccessV == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
@@ -184,30 +177,28 @@ struct WarpGemmAttributeMfmaIterateK
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
else if constexpr(kMNBlock == 1 && 1 < kNMBlock)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// each M blocks share the same data
// each M/N blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kBNBlock>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
sequence<kNMBlock>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
else if constexpr(1 < kMNBlock && kNMBlock == 1)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
tuple<sequence<kMNBlock, kMNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
@@ -216,68 +207,6 @@ struct WarpGemmAttributeMfmaIterateK
}
}
// Selects, at compile time, the tile_distribution_encoding type used for the
// B operand and returns a value-initialized instance of it.
//
// The choice depends on Impl::kAMBlock, Impl::kBNBlock and AttrNumAccessV:
//   * kAMBlock == 1 && kBNBlock == 1 : single-block layout; the K dimension is
//     additionally split by AttrNumAccessV when AttrNumAccessV != 1.
//   * kAMBlock == 1 && kBNBlock  > 1 : N dimension is <kBNBlock, kBNLane>.
//   * kAMBlock  > 1 && kBNBlock == 1 : sequence<kAMBlock> is passed as the
//     first template argument of the encoding.
// No branch handles kAMBlock > 1 && kBNBlock > 1, so in that case no return
// statement is instantiated.
// NOTE(review): the meaning of the individual index sequences is defined by
// tile_distribution_encoding, which is not visible here; they are documented
// below only by what is passed, not by what they mean.
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
if constexpr(AttrNumAccessV == 1)
{
// Single access: lengths are <kBNLane> and <kABKLane, kABKPerLane * kKIter>.
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
// The divisibility check is only instantiated on this multi-access path.
// NOTE(review): the assert tests kKPerThread while the division below is on
// Impl::kABKPerLane * kKIter; presumably these are the same quantity - confirm.
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
// Multi access: the second length sequence gains a leading AttrNumAccessV
// entry and its last entry is divided by AttrNumAccessV.
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>,
sequence<AttrNumAccessV,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// Multi-block layouts are restricted to a single access.
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// single block to multi-block thread mapping
// (the first length sequence becomes <kBNBlock, kBNLane>)
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// Multi-block layouts are restricted to a single access.
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// each N blocks share the same data
// (sequence<kAMBlock> is supplied as the first template argument)
return tile_distribution_encoding<
sequence<Impl::kAMBlock>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
}
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
@@ -316,10 +245,10 @@ struct WarpGemmAttributeMfmaIterateK
}
}
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
using AWarpDstrEncoding =
decltype(get_warp_dstr_encoding<Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock>());
using BWarpDstrEncoding =
decltype(get_warp_dstr_encoding<Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock>());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// c_vec += a_vec * b_vec
@@ -329,17 +258,7 @@ struct WarpGemmAttributeMfmaIterateK
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
});
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
}
template <index_t iKIter, bool post_nop_ = false>
@@ -354,14 +273,12 @@ struct WarpGemmAttributeMfmaIterateK
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
@@ -377,13 +294,7 @@ struct WarpGemmAttributeMfmaIterateK
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
// c += a * b
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
return c_vec;
}
@@ -416,35 +327,10 @@ struct WarpGemmAttributeMfmaTransposedCDistribution
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
// Returns a value-initialized tile_distribution_encoding whose first length
// sequence is <kMNLane> and whose second length sequence is built from
// Impl::kABKLane / Impl::kABKPerLane, optionally split by AttrNumAccessV.
//
// Template parameter:
//   kMNLane - length placed in the first sequence of the encoding; the alias
//             declarations following this function instantiate it with
//             Impl::kBNLane and Impl::kAMLane.
// NOTE(review): the meaning of the trailing index sequences is defined by
// tile_distribution_encoding, which is not visible here.
template <index_t kMNLane>
static constexpr auto get_warp_dstr_encoding()
{
if constexpr(AttrNumAccessV == 1)
{
// Single access: second length sequence is <kABKLane, kABKPerLane>.
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
// The divisibility check is only instantiated on this multi-access path.
// NOTE(review): the assert tests kKPerThread while the division below is on
// Impl::kABKPerLane; presumably these are the same quantity - confirm.
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
// Multi access: second length sequence is
// <AttrNumAccessV, kABKLane, kABKPerLane / AttrNumAccessV>.
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
using AWarpDstrEncoding =
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::BWarpDstrEncoding;
using BWarpDstrEncoding =
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::AWarpDstrEncoding;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
@@ -597,18 +483,6 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
"Multi-block on both M & N directions is not supported");
// Returns the A-operand encoding for this attribute by forwarding to the
// B-operand encoding of WarpGemmAttributeMfmaIterateK instantiated with the
// same <Impl, kKIter, AttrNumAccess>, i.e. the A and B roles are swapped
// relative to that struct.
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
{
return WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
get_bwarp_dstr_encoding();
}
// Returns the B-operand encoding for this attribute by forwarding to the
// A-operand encoding of WarpGemmAttributeMfmaIterateK instantiated with the
// same <Impl, kKIter, AttrNumAccess>, i.e. the A and B roles are swapped
// relative to that struct.
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
{
return WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
get_awarp_dstr_encoding();
}
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
@@ -647,30 +521,20 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
}
}
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
using AWarpDstrEncoding =
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::BWarpDstrEncoding;
using BWarpDstrEncoding =
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::AWarpDstrEncoding;
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
template <bool post_nop_ = false>
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
});
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
}
template <index_t iKIter, bool post_nop_ = false>
@@ -686,14 +550,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
static_assert(iKIter < kKIter);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
@@ -708,13 +570,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
});
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
return c_vec;
}
@@ -805,17 +661,8 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
});
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
}
template <index_t iKIter, bool post_nop_ = false>
@@ -830,14 +677,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
static_assert(iKIter < kKIter);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
@@ -852,13 +697,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
});
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
return c_vec;
}
@@ -926,17 +765,7 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
});
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
}
template <index_t iKIter, bool post_nop_ = false>
@@ -951,14 +780,12 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
@@ -972,13 +799,7 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
return c_vec;
}