[rocm-libraries] ROCm/rocm-libraries#4821 (commit 9456e0f)

[CK TILE] Refactor MX FLATMM example

Refactor the MX FLATMM example to support more pipelines
across different architectures. This work supports the NPI team's
roadmap.
Andriy Roshchenko
2026-02-27 23:21:39 +00:00
committed by assistant-librarian[bot]
parent 711374fcab
commit b661eab573
7 changed files with 199 additions and 273 deletions


@@ -20,7 +20,7 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
-template <typename FlatmmConfig,
+template <typename MXFlatmmArchTraits,
typename ADataType,
typename BDataType,
typename DsDatatype,
@@ -49,6 +49,8 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
int n_warmup,
int n_repeat)
{
+using FlatmmConfig = typename MXFlatmmArchTraits::Config;
ck_tile::ScaleFlatmmHostArgs<ScaleA, ScaleB> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
@@ -99,7 +101,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
constexpr auto has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_num_v = tail_num_.value;
auto invoke_splitk_path = [&](auto split_k_) {
-return mx_flatmm_calc<FlatmmConfig,
+return mx_flatmm_calc<MXFlatmmArchTraits,
ADataType,
BDataType,
DsDatatype,
@@ -157,22 +159,22 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-.insert(
-"mx_prec", "fp4xfp4", "data type for activation and weight, support: fp4xfp4, fp8xfp8")
+.insert("mx_prec",
+"fp4xfp4",
+"data type for activation and weight, support: fp4xfp4, fp6xfp6, fp8xfp8, fp8xfp4 "
+"and fp4xfp8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:constant(1)")
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
.insert("warp_tile", "0", "0: 16x16x128 on gfx950.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
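
For orientation: the hunks above replace the FlatmmConfig template parameter with an arch-traits type, and invoke_mx_flatmm recovers the config via the nested alias (using FlatmmConfig = typename MXFlatmmArchTraits::Config;). A minimal sketch of how such a traits bundle could be shaped; the struct and every member other than Config are illustrative assumptions, not the actual CK TILE definitions:

// Sketch only: a per-architecture traits bundle in the spirit of the
// MXFlatmm_GFX950_*_Traits types used below. Only the nested Config
// alias is confirmed by this diff; everything else is assumed.
struct MXFlatmmExampleTraits
{
    using Config = MXfp4_FlatmmConfig16;          // hypothetical config type
    static constexpr ck_tile::index_t NLane = 16; // assumed N lanes per warp tile
};

// invoke_mx_flatmm then needs only the traits type:
//   template <typename MXFlatmmArchTraits, ...>
//   float invoke_mx_flatmm(...)
//   {
//       using FlatmmConfig = typename MXFlatmmArchTraits::Config;
//       ...
//   }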
-template <ck_tile::index_t N_Warp_Tile, typename dtype>
+template <ck_tile::index_t NLane, typename dtype>
auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
{
auto src_lengths = src.get_lengths();
@@ -181,8 +183,8 @@ auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
int KPack =
std::is_same_v<dtype, ck_tile::pk_fp6x16_t> ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
-int NLane = N_Warp_Tile;
-int KLane = 64 / NLane;
+int KLane = ck_tile::get_warp_size() / NLane;
int K0 = K / (KLane * KPack);
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
@@ -211,68 +213,10 @@ auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
return shuffled;
}
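
The KLane change in the hunk above derives the K-direction lane count from the wave size rather than hardcoding 64. A self-contained sketch of the arithmetic, assuming wave64 and the pack sizes named in the hunk's comment (fp4/fp6 pack to 32 elements, fp8 to 16):

#include <cstdio>

// Sketch: lane/pack arithmetic from preShuffleWeight, assuming wave64.
int main()
{
    constexpr int warp_size = 64;                // ck_tile::get_warp_size() on gfx9
    constexpr int NLane     = 16;                // warp-tile N extent (template parameter)
    constexpr int KLane     = warp_size / NLane; // -> 4 lanes along K

    constexpr int KPack_fp4 = 32; // 16 * PackedSize(2) for ck_tile::pk_fp4_t
    constexpr int KPack_fp8 = 16; // 16 * PackedSize(1) for ck_tile::fp8_t

    constexpr int K = 512; // arbitrary example depth
    std::printf("fp4: K0 = %d\n", K / (KLane * KPack_fp4)); // 512 / 128 = 4
    std::printf("fp8: K0 = %d\n", K / (KLane * KPack_fp8)); // 512 /  64 = 8
}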
-template <class FlatmmConfig, bool KLast, typename dtype>
-auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
-{
-auto src_lengths = src.get_lengths();
-const auto MN = KLast ? src_lengths[0] : src_lengths[1];
-const auto K = KLast ? src_lengths[1] : src_lengths[0];
-size_t MNXdlPack = 2;
-size_t KXdlPack = 2;
-size_t XdlMNThread = FlatmmConfig::N_Warp_Tile; // 16
-size_t XdlKThread = 64 / XdlMNThread;
-const auto MN_Paded = ck_tile::integer_least_multiple(MN, XdlMNThread * MNXdlPack);
-ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({MN_Paded * K}, {1}));
-size_t K0 = K / KXdlPack / XdlKThread; // KRepeat
-// The 4 16x128 building blocks will be packed into 1 32x256 for F4
-// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
-// unfold the MN32xK(256/32) scale buffer
-//    4            16             2           2
-// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
-// Then, MNRepeat->KRepeat
-for(size_t n = 0; n < MN_Paded; ++n)
-{
-for(size_t k = 0; k < K; ++k)
-{
-auto n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
-auto tempn = n % (XdlMNThread * MNXdlPack);
-auto n1 = tempn % XdlMNThread; // i XdlMNThread
-auto n2 = tempn / XdlMNThread; // i MNXdlPack
-auto k0 = k / (XdlKThread * KXdlPack); // i KRepeat
-auto tempk = k % (XdlKThread * KXdlPack);
-auto k1 = tempk % XdlKThread; // i XdlKThread
-auto k2 = tempk / XdlKThread; // i KXdlPack
-auto outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
-k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
-k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
-k2 * MNXdlPack + n2;
-if constexpr(KLast)
-shuffled(outputIndex) = n < MN ? src(n, k) : dtype{};
-else
-shuffled(outputIndex) = n < MN ? src(k, n) : dtype{};
-}
-}
-return shuffled;
-}
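
The removed preShuffleScale above unfolds each scale tile in XdlKThread -> XdlMNThread -> KXdlPack -> MNXdlPack order, then MNRepeat -> KRepeat. A standalone sketch of that index mapping, using the constants the deleted code fixed (XdlMNThread = 16, XdlKThread = 64/16 = 4, MNXdlPack = KXdlPack = 2); the sample (n, k) pair is arbitrary:

#include <cstdio>

// Sketch: output-index mapping from the removed preShuffleScale,
// evaluated for a single (n, k) coordinate.
int main()
{
    constexpr size_t XdlMNThread = 16, XdlKThread = 4;
    constexpr size_t MNXdlPack = 2, KXdlPack = 2;
    constexpr size_t K = 16, K0 = K / KXdlPack / XdlKThread; // KRepeat = 2

    size_t n = 5, k = 3; // arbitrary sample coordinates
    size_t n0 = n / (XdlMNThread * MNXdlPack), tempn = n % (XdlMNThread * MNXdlPack);
    size_t n1 = tempn % XdlMNThread, n2 = tempn / XdlMNThread;
    size_t k0 = k / (XdlKThread * KXdlPack), tempk = k % (XdlKThread * KXdlPack);
    size_t k1 = tempk % XdlKThread, k2 = tempk / XdlKThread;

    size_t idx = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
                 k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
                 k1 * MNXdlPack * KXdlPack * XdlMNThread +
                 n1 * MNXdlPack * KXdlPack + k2 * MNXdlPack + n2;
    std::printf("(n=%zu, k=%zu) -> %zu\n", n, k, idx); // 3*64 + 5*4 = 212
}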
#include "run_mx_flatmm.inc"
-int run_mx_flatmm_example(int argc, char* argv[])
+int run_mx_flatmm_example(const ck_tile::ArgParser& arg_parser)
{
-auto [result, arg_parser] = create_args(argc, argv);
-if(!result)
-return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
@@ -281,6 +225,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
std::string b_layout = arg_parser.get_str("b_layout");
int persistent_opt = arg_parser.get_int("persistent");
+std::cout << "Using default warptile of 16x16x128." << std::endl;
if(a_layout == "R" && b_layout == "C")
{
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
@@ -289,8 +235,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
-MXfp4_FlatmmConfig16,
-false>(argc, argv, Row{}, Col{}, Row{});
+MXFlatmm_GFX950_FP4FP4_Traits,
+false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only non-persistent kernels are supported currently!");
}
@@ -300,8 +246,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
return run_mx_flatmm_with_layouts<ck_tile::pk_fp6x16_t,
ck_tile::pk_fp6x16_t,
ck_tile::fp16_t,
-MXfp6_FlatmmConfig16,
-false>(argc, argv, Row{}, Col{}, Row{});
+MXFlatmm_GFX950_FP6FP6_Traits,
+false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
@@ -311,8 +257,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
-MXfp8_FlatmmConfig16,
-false>(argc, argv, Row{}, Col{}, Row{});
+MXFlatmm_GFX950_FP8FP8_Traits,
+false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
@@ -322,8 +268,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
-MXf8f4_FlatmmConfig16,
-false>(argc, argv, Row{}, Col{}, Row{});
+MXFlatmm_GFX950_FP8FP4_Traits,
+false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
@@ -333,8 +279,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
-MXf4f8_FlatmmConfig16,
-false>(argc, argv, Row{}, Col{}, Row{});
+MXFlatmm_GFX950_FP4FP8_Traits,
+false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
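
Every mx_prec branch above has the same shape: choose the A/B data types and a per-arch traits bundle, then forward the already-parsed arg_parser with Row/Col/Row layouts. A condensed sketch of that pattern; dispatch_mx_prec is a hypothetical helper, not part of the example, and the trailing false presumably selects the non-persistent path:

// Sketch only: the repeated dispatch above, factored into one helper.
template <typename ADataType, typename BDataType, typename Traits>
int dispatch_mx_prec(const ck_tile::ArgParser& arg_parser)
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
    return run_mx_flatmm_with_layouts<ADataType, BDataType, ck_tile::fp16_t,
                                      Traits, false>(arg_parser, Row{}, Col{}, Row{});
}

// e.g. the fp8xfp4 branch then reduces to:
//   return dispatch_mx_prec<ck_tile::fp8_t, ck_tile::pk_fp4_t,
//                           MXFlatmm_GFX950_FP8FP4_Traits>(arg_parser);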
@@ -359,7 +305,7 @@ int main(int argc, char* argv[])
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
-return run_mx_flatmm_example(argc, argv);
+return run_mx_flatmm_example(arg_parser);
}
else if(warp_tile == 1)
{