[CK_TILE] Fix flatmm on gfx11 and gfx12 (#2790)

1. Correct shuffle_b and MakeBFlatDramTileDistribution according to WMMA warp layout
2. Add FlatmmConfig16_Wmma for gfx11 and gfx12
This commit is contained in:
linqunAMD
2025-09-10 08:28:00 +08:00
committed by GitHub
parent 82890192dd
commit df4ee556d6
14 changed files with 224 additions and 67 deletions

View File

@@ -304,6 +304,14 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename PrecType>
struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;

View File

@@ -91,7 +91,11 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigPreshufflePrefill_Wmma>(arg_parser);
#else
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -176,16 +176,43 @@ template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig, typename T>