[CK_TILE] Fix flatmm on gfx11 and gfx12 (#2790)

1. Correct shuffle_b and MakeBFlatDramTileDistribution according to WMMA warp layout
2. Add FlatmmConfig16_Wmma for gfx11 and gfx12

[ROCm/composable_kernel commit: df4ee556d6]
This commit is contained in:
linqunAMD
2025-09-10 08:28:00 +08:00
committed by GitHub
parent b42e7a85b2
commit 9f852da7c8
14 changed files with 224 additions and 67 deletions

View File

@@ -304,6 +304,14 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename PrecType>
struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;

View File

@@ -91,7 +91,11 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigPreshufflePrefill_Wmma>(arg_parser);
#else
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -176,16 +176,43 @@ template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig, typename T>

View File

@@ -190,6 +190,30 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr bool kPadK = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 32 / sizeof(PrecType);
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool kPadK = true;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
@@ -266,16 +290,43 @@ template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig,

View File

@@ -230,5 +230,9 @@ int run_grouped_gemm_example(int argc, char* argv[])
}
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_gemm_example<GemmConfigPreshuffleDecode_Wmma>(argc, argv);
#else
return !run_grouped_gemm_example<GemmConfigPreshuffleDecode>(argc, argv);
#endif
}

View File

@@ -268,6 +268,9 @@ int main(int argc, char* argv[])
try
{
#if defined(CK_TILE_USE_WMMA)
return !run_flatmm_example<FlatmmConfig16_Wmma>(argc, argv);
#else
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
@@ -285,6 +288,7 @@ int main(int argc, char* argv[])
{
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
}
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -86,6 +86,14 @@ struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
};
template <typename DataType>
struct FlatmmConfig16_Wmma : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ADataType>
struct GemmBasicTypeConfig;
@@ -183,8 +191,10 @@ auto create_args(int argc, char* argv[])
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
#if !defined(CK_TILE_USE_WMMA)
.insert(
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
#endif
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "flatmm_basic.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);

View File

@@ -43,15 +43,40 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2)
: (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4);
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
FlatmmConfig::K_Warp_Tile / divisor / kABK0PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>

View File

@@ -127,7 +127,10 @@ struct FlatmmKernel
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)

View File

@@ -185,11 +185,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&

View File

@@ -237,8 +237,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
using TileShape = typename Problem::BlockGemmShape;
#if defined(__gfx11__)
constexpr index_t scale = 4;
#else
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
#endif
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) * scale / 2;
@@ -342,7 +346,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -350,8 +354,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
@@ -362,16 +371,15 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});

View File

@@ -89,14 +89,19 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
#if defined(__gfx11__)
constexpr index_t scale = 4;
#else
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
#endif
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) / 2;
return TileShape::WarpTile::at(I2) * scale / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) / 4;
return TileShape::WarpTile::at(I2) * scale / 4;
}
}
@@ -192,7 +197,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
@@ -200,8 +205,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
@@ -212,16 +222,15 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});

View File

@@ -189,11 +189,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&

View File

@@ -146,10 +146,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t mfma_per_wg = 1;
#endif
static constexpr index_t dsread_per_wg =
WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
#if defined(__HIP_DEVICE_COMPILE__)
static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
Problem::VectorLoadSize ==
0);
#endif
static constexpr index_t dsread_num_perK =
WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize;
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
@@ -499,12 +503,12 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,