Files
composable_kernel/rocm_ck/tests/unit/unit_layout.cpp
John Shumway 3e110e1718 [rocm-libraries] ROCm/rocm-libraries#7114 (commit ecef372)
[CK] Add rocm_ck foundation types: DataType, Layout, Args, Ops (#7114)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

- Add the vocabulary types that all rocm_ck schema headers build on
- 9 new headers under `include/rocm_ck/`, 6 unit test files
- Pure C++20, host-only — no CK Tile dependencies

**Headers:**

| Header | Purpose |
|--------|---------|
| `index_t.hpp` | `index_t`, `long_index_t` (matches ck_tile) |
| `gpu_target.hpp` | `GpuTarget` enum (ISA targets) |
| `datatype.hpp` | `DataType` enum (17 variants) |
| `layout.hpp` | `Layout` enum (Row, Col, Auto) + stride helpers |
| `fixed_string.hpp` | `FixedString<N>` — structural string for NTTPs |
| `args.hpp` | Generic kernel argument buffer (ABI) |
| `ops.hpp` | Operator structs (`GemmOp`, `AddOp`, ...) + `Op` variant |
| `physical_tensor.hpp` | `PhysicalTensor` — maps names to Args slots |
| `resolved_tensor.hpp` | `ResolvedTensor` — output of `Signature::resolve()` |

**Stack**: This is PR 1 of 3 porting the rocm_ck constexpr schema from
experimental to production (#7143).
1. **This PR** — Foundation types (vocabulary)
2. Schema engine — `Signature`, `resolve()`, `ArchProperties`
3. Spec factories — `GemmSpec`, `ElementwiseSpec`, `makeSpec()`

## Test plan

- [ ] `ninja build-smoke-rocm-ck` builds all tests
- [ ] `ctest -L ROCM_CK_SMOKE --output-on-failure` — 6 unit tests pass (86 test cases)
- [ ] Default CK build (`CK_ENABLE_ROCM_CK=OFF`) unaffected

🤖 Generated with [Claude Code](https://claude.com/claude-code)
2026-05-15 19:22:44 +00:00

89 lines
2.6 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <rocm_ck/layout.hpp>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <array>
using ::rocm_ck::isValidLayoutForRank;
using ::rocm_ck::Layout;
using ::rocm_ck::layoutName;
using ::rocm_ck::layoutStrides;
using ::rocm_ck::leadingDimStride;
using ::testing::ElementsAre;
namespace {
// ============================================================================
// layoutName
// ============================================================================
TEST(Layout, MapsEnumValuesToExpectedStrings)
{
EXPECT_STREQ(layoutName(Layout::Row), "Row");
EXPECT_STREQ(layoutName(Layout::Col), "Col");
EXPECT_STREQ(layoutName(Layout::Auto), "Auto");
}
// ============================================================================
// isValidLayoutForRank
// ============================================================================
TEST(Layout, AllowsRowAndColOnlyForRank2)
{
EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 1));
EXPECT_TRUE(isValidLayoutForRank(Layout::Row, 2));
EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 1));
EXPECT_TRUE(isValidLayoutForRank(Layout::Col, 2));
}
TEST(Layout, RejectsAutoForAllRanks)
{
EXPECT_FALSE(isValidLayoutForRank(Layout::Auto, 0));
EXPECT_FALSE(isValidLayoutForRank(Layout::Auto, 1));
EXPECT_FALSE(isValidLayoutForRank(Layout::Auto, 2));
}
TEST(Layout, RejectsRowAndColForRankGreaterThan2)
{
EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 3));
EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 4));
EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 6));
EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 3));
EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 4));
EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 6));
}
// ============================================================================
// leadingDimStride
// ============================================================================
TEST(Layout, LeadingDimStrideReturnsFirstForRow)
{
EXPECT_EQ(leadingDimStride(Layout::Row, std::array{128, 1}), 128);
}
TEST(Layout, LeadingDimStrideReturnsSecondForCol)
{
EXPECT_EQ(leadingDimStride(Layout::Col, std::array{1, 64}), 64);
}
// ============================================================================
// layoutStrides
// ============================================================================
TEST(Layout, LayoutStridesRowMajor)
{
EXPECT_THAT(layoutStrides(Layout::Row, 32, 64), ElementsAre(64, 1));
}
TEST(Layout, LayoutStridesColMajor)
{
EXPECT_THAT(layoutStrides(Layout::Col, 32, 64), ElementsAre(1, 32));
}
} // namespace