[CK_TILE] Fix flatmm on gfx11 and gfx12 (#2790)

1. Correct shuffle_b and MakeBFlatDramTileDistribution according to WMMA warp layout
2. Add FlatmmConfig16_Wmma for gfx11 and gfx12

[ROCm/composable_kernel commit: df4ee556d6]
This commit is contained in:
linqunAMD
2025-09-10 08:28:00 +08:00
committed by GitHub
parent f7ffd111ee
commit 7022f756e6
14 changed files with 224 additions and 67 deletions

View File

@@ -304,6 +304,14 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename PrecType>
struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;

View File

@@ -91,7 +91,11 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigPreshufflePrefill_Wmma>(arg_parser);
#else
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -176,16 +176,43 @@ template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig, typename T>

View File

@@ -190,6 +190,30 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr bool kPadK = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 32 / sizeof(PrecType);
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool kPadK = true;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
@@ -266,16 +290,43 @@ template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig,

View File

@@ -230,5 +230,9 @@ int run_grouped_gemm_example(int argc, char* argv[])
}
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_gemm_example<GemmConfigPreshuffleDecode_Wmma>(argc, argv);
#else
return !run_grouped_gemm_example<GemmConfigPreshuffleDecode>(argc, argv);
#endif
}

View File

@@ -268,6 +268,9 @@ int main(int argc, char* argv[])
try
{
#if defined(CK_TILE_USE_WMMA)
return !run_flatmm_example<FlatmmConfig16_Wmma>(argc, argv);
#else
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
@@ -285,6 +288,7 @@ int main(int argc, char* argv[])
{
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
}
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -86,6 +86,14 @@ struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
};
template <typename DataType>
struct FlatmmConfig16_Wmma : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ADataType>
struct GemmBasicTypeConfig;
@@ -183,8 +191,10 @@ auto create_args(int argc, char* argv[])
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
#if !defined(CK_TILE_USE_WMMA)
.insert(
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
#endif
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "flatmm_basic.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);

View File

@@ -43,15 +43,40 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2)
: (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4);
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
FlatmmConfig::K_Warp_Tile / divisor / kABK0PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>