[CK_TILE] Fix flatmm on gfx11 and gfx12 (#2790)

1. Correct shuffle_b and MakeBFlatDramTileDistribution according to WMMA warp layout
2. Add FlatmmConfig16_Wmma for gfx11 and gfx12
This commit is contained in:
linqunAMD
2025-09-10 08:28:00 +08:00
committed by GitHub
parent 82890192dd
commit df4ee556d6
14 changed files with 224 additions and 67 deletions

View File

@@ -86,6 +86,14 @@ struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
};
// Flatmm configuration for WMMA builds (introduced for gfx11 and gfx12).
// Derives from FlatmmConfig16<DataType> and declares its own M_Tile, K_Tile
// and K_Warp_Tile; anything not declared here is inherited from the base.
// NOTE(review): these presumably shadow same-named members of FlatmmConfig16 -- confirm.
template <class DataType>
struct FlatmmConfig16_Wmma : FlatmmConfig16<DataType>
{
    // Fixed at 16; unlike FlatmmConfig16_950 it does not vary with sizeof(DataType).
    static constexpr ck_tile::index_t K_Warp_Tile = 16;

    // M_Tile and K_Tile are both 64 for this configuration.
    static constexpr ck_tile::index_t M_Tile = 64;
    static constexpr ck_tile::index_t K_Tile = 64;
};
// Declaration only: the GemmBasicTypeConfig template is introduced here without a body.
// NOTE(review): presumably defined via specializations on ADataType elsewhere in the file -- confirm.
template <typename ADataType>
struct GemmBasicTypeConfig;
@@ -183,8 +191,10 @@ auto create_args(int argc, char* argv[])
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
#if !defined(CK_TILE_USE_WMMA)
.insert(
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
#endif
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "flatmm_basic.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);