mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#5082 (commit 9313659)
ck_tile: add gtest unit tests for MX flatmm (gfx950)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Summary
- Add correctness unit tests for the MX-format flatmm kernel
(`example/ck_tile/18_flatmm/mxgemm`) under `test/ck_tile/flatmm/`
- Tests cover all five dtype combinations: FP4×FP4, FP8×FP8, FP6×FP6,
FP8×FP4, FP4×FP8
- Tests cover all four kernel dispatch paths (the `has_hot_loop` ×
`tail_num` product):
- `has_hot_loop=false, tail=ODD` (K=256, num_loop=1)
- `has_hot_loop=false, tail=EVEN` (K=512, num_loop=2)
- `has_hot_loop=true, tail=ODD` (K=768, num_loop=3)
- `has_hot_loop=true, tail=EVEN` (K=1024, num_loop=4)
- Remove unsupported `-split_k` CLI option from
`tile_example_mx_flatmm`; the pre-shuffled B layout is incompatible with
K-splitting and the option silently produced wrong results
## Changes
**New files (`test/ck_tile/flatmm/`):**
- `CMakeLists.txt` — builds 40 kernel instances as a shared OBJECT
library, links into 5 per-dtype test executables; forwards
`-DCK_TILE_USE_OCP_FP8` when `CK_USE_OCP_FP8` is ON
- `test_mx_flatmm_base.hpp` — base test fixture with
`run_test_with_validation(M, N, K, kbatch=1)`
- `test_mx_flatmm_fixtures.hpp` — concrete `TestMXFlatmm` typed test
class and type aliases
- `test_mx_flatmm_fp{4fp4,8fp8,6fp6,8fp4,4fp8}.cpp` — per-dtype
`TYPED_TEST_SUITE` files
**Modified files:**
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp` — moved
`preShuffleWeight` here (was in `mx_flatmm.cpp`) so it is includeable by
both the example and the tests
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp` / `run_mx_flatmm.inc`
— removed `-split_k` CLI arg, hardcoded `k_batch=1`, fixed `k_split`
formula, updated call sites after `preShuffleWeight` move
- `test/ck_tile/CMakeLists.txt` — added `add_subdirectory(flatmm)`
This commit is contained in:
committed by
assistant-librarian[bot]
parent
2169367735
commit
1a4aa7fd89
@@ -52,7 +52,12 @@ bool compare_results(std::string instanceName,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
std::abs(static_cast<float>(*std::max_element(c_m_n_host_result.mData.begin(),
|
||||
c_m_n_host_result.mData.end(),
|
||||
[](CDataType a, CDataType b) {
|
||||
return std::abs(static_cast<float>(a)) <
|
||||
std::abs(static_cast<float>(b));
|
||||
})));
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
|
||||
Reference in New Issue
Block a user