mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5082 (commit 9313659)
ck_tile: add gtest unit tests for MX flatmm (gfx950)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Summary
- Add correctness unit tests for the MX-format flatmm kernel
(`example/ck_tile/18_flatmm/mxgemm`) under `test/ck_tile/flatmm/`
- Tests cover all five dtype combinations: FP4×FP4, FP8×FP8, FP6×FP6,
FP8×FP4, FP4×FP8
- Tests cover all four kernel dispatch paths (the `has_hot_loop` ×
`tail_num` product):
- `has_hot_loop=false, tail=ODD` (K=256, num_loop=1)
- `has_hot_loop=false, tail=EVEN` (K=512, num_loop=2)
- `has_hot_loop=true, tail=ODD` (K=768, num_loop=3)
- `has_hot_loop=true, tail=EVEN` (K=1024, num_loop=4)
- Remove unsupported `-split_k` CLI option from
`tile_example_mx_flatmm`; the pre-shuffled B layout is incompatible with
K-splitting and the option silently produced wrong results
## Changes
**New files (`test/ck_tile/flatmm/`):**
- `CMakeLists.txt` — builds 40 kernel instances as a shared OBJECT
library, links into 5 per-dtype test executables; forwards
`-DCK_TILE_USE_OCP_FP8` when `CK_USE_OCP_FP8` is ON
- `test_mx_flatmm_base.hpp` — base test fixture with
`run_test_with_validation(M, N, K, kbatch=1)`
- `test_mx_flatmm_fixtures.hpp` — concrete `TestMXFlatmm` typed test
class and type aliases
- `test_mx_flatmm_fp{4fp4,8fp8,6fp6,8fp4,4fp8}.cpp` — per-dtype
`TYPED_TEST_SUITE` files
**Modified files:**
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp` — moved
`preShuffleWeight` here (was in `mx_flatmm.cpp`) so it is includeable by
both the example and the tests
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp` / `run_mx_flatmm.inc`
— removed `-split_k` CLI arg, hardcoded `k_batch=1`, fixed `k_split`
formula, updated call sites after `preShuffleWeight` move
- `test/ck_tile/CMakeLists.txt` — added `add_subdirectory(flatmm)`
This commit is contained in:
committed by
assistant-librarian[bot]
parent
2169367735
commit
1a4aa7fd89
@@ -43,7 +43,6 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
ScaleA scale_a,
|
||||
ScaleB scale_b,
|
||||
int n_warmup,
|
||||
@@ -55,7 +54,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@@ -90,8 +89,8 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
|
||||
using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t k_grain = FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * k_grain;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
|
||||
const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
@@ -100,29 +99,24 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
[&](auto has_hot_loop_, auto tail_num_) {
|
||||
constexpr auto has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_num_v = tail_num_.value;
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_flatmm_calc<MXFlatmmArchTraits,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ScaleA,
|
||||
ScaleB,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise,
|
||||
split_k_.value,
|
||||
has_hot_loop_v,
|
||||
tail_num_v>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
};
|
||||
return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
|
||||
: invoke_splitk_path(std::true_type{});
|
||||
return mx_flatmm_calc<MXFlatmmArchTraits,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ScaleA,
|
||||
ScaleB,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise,
|
||||
false,
|
||||
has_hot_loop_v,
|
||||
tail_num_v>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
},
|
||||
has_hot_loop,
|
||||
tail_num);
|
||||
@@ -166,7 +160,6 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile", "0", "0: 16x16x128 on gfx950.");
|
||||
@@ -174,45 +167,6 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NLane, typename dtype>
|
||||
auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
const int K = src_lengths[0];
|
||||
const int N = src_lengths[1];
|
||||
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
|
||||
int KPack =
|
||||
std::is_same_v<dtype, ck_tile::pk_fp6x16_t> ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
|
||||
|
||||
int KLane = ck_tile::get_warp_size() / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
|
||||
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; k += packed_size)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
int tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
shuffled(outputIndex) = src(k, n);
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
#include "run_mx_flatmm.inc"
|
||||
|
||||
int run_mx_flatmm_example(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
@@ -70,6 +70,47 @@ struct MXFlatmmArchTraits
|
||||
|
||||
static constexpr int GetNLane() { return Config::N_Warp_Tile; }
|
||||
|
||||
template <typename dtype>
|
||||
static auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
constexpr ck_tile::index_t NLane = Config::N_Warp_Tile;
|
||||
auto src_lengths = src.get_lengths();
|
||||
const int K = src_lengths[0];
|
||||
const int N = src_lengths[1];
|
||||
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
|
||||
int KPack = std::is_same_v<dtype, ck_tile::pk_fp6x16_t>
|
||||
? 32
|
||||
: 16 * packed_size; // fp4/fp6:32 or fp8:16
|
||||
|
||||
int KLane = ck_tile::get_warp_size() / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
|
||||
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; k += packed_size)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
int tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
shuffled(outputIndex) = src(k, n);
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
template <bool KLast, typename dtype>
|
||||
static auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
|
||||
@@ -32,7 +32,6 @@ int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
|
||||
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");
|
||||
@@ -106,7 +105,7 @@ int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
}
|
||||
|
||||
const auto b_shuffled_host = preShuffleWeight<MXFlatmmArchTraits::GetNLane()>(b_origin_host);
|
||||
const auto b_shuffled_host = MXFlatmmArchTraits::preShuffleWeight(b_origin_host);
|
||||
const auto scale_a_shuffled = MXFlatmmArchTraits::template preShuffleScale<true>(scale_a);
|
||||
const auto scale_b_shuffled = MXFlatmmArchTraits::template preShuffleScale<false>(scale_b);
|
||||
|
||||
@@ -151,7 +150,6 @@ int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
scale_a_dev_ptr,
|
||||
scale_b_dev_ptr,
|
||||
n_warmup,
|
||||
|
||||
Reference in New Issue
Block a user