mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
[CK] Add rocm_ck foundation types: DataType, Layout, Args, Ops (#7114)

MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

- Add the vocabulary types that all rocm_ck schema headers build on
- 9 new headers under `include/rocm_ck/`, 6 unit test files
- Pure C++20, host-only — no CK Tile dependencies

**Headers:**

| Header | Purpose |
|--------|---------|
| `index_t.hpp` | `index_t`, `long_index_t` (matches ck_tile) |
| `gpu_target.hpp` | `GpuTarget` enum (ISA targets) |
| `datatype.hpp` | `DataType` enum (17 variants) |
| `layout.hpp` | `Layout` enum (Row, Col, Auto) + stride helpers |
| `fixed_string.hpp` | `FixedString<N>` — structural string for NTTPs |
| `args.hpp` | Generic kernel argument buffer (ABI) |
| `ops.hpp` | Operator structs (`GemmOp`, `AddOp`, ...) + `Op` variant |
| `physical_tensor.hpp` | `PhysicalTensor` — maps names to Args slots |
| `resolved_tensor.hpp` | `ResolvedTensor` — output of `Signature::resolve()` |

**Stack**: This is PR 1 of 3 porting the rocm_ck constexpr schema from experimental to production, #7143.

1. **This PR** — Foundation types (vocabulary)
2. Schema engine — `Signature`, `resolve()`, `ArchProperties`
3. Spec factories — `GemmSpec`, `ElementwiseSpec`, `makeSpec()`

## Test plan

- [ ] `ninja build-smoke-rocm-ck` builds all tests
- [ ] `ctest -L ROCM_CK_SMOKE --output-on-failure` — 6 unit tests pass (86 test cases)
- [ ] Default CK build (`CK_ENABLE_ROCM_CK=OFF`) unaffected

🤖 Generated with [Claude Code](https://claude.com/claude-code)
89 lines
2.6 KiB
C++
89 lines
2.6 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include <rocm_ck/layout.hpp>
|
|
|
|
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <array>
|
|
|
|
using ::rocm_ck::isValidLayoutForRank;
|
|
using ::rocm_ck::Layout;
|
|
using ::rocm_ck::layoutName;
|
|
using ::rocm_ck::layoutStrides;
|
|
using ::rocm_ck::leadingDimStride;
|
|
using ::testing::ElementsAre;
|
|
|
|
namespace {
|
|
|
|
// ============================================================================
|
|
// layoutName
|
|
// ============================================================================
|
|
|
|
TEST(Layout, MapsEnumValuesToExpectedStrings)
{
    // Expected display name for each Layout enumerator checked here. Kept as a
    // table so that covering another enumerator is a one-row change.
    struct NameCase
    {
        Layout layout;
        const char* name;
    };
    constexpr std::array<NameCase, 3> kCases{{
        {Layout::Row, "Row"},
        {Layout::Col, "Col"},
        {Layout::Auto, "Auto"},
    }};

    for(const auto& [layout, expected] : kCases)
    {
        // EXPECT_STREQ compares C-string contents, not pointer identity.
        EXPECT_STREQ(layoutName(layout), expected);
    }
}
|
|
|
|
// ============================================================================
|
|
// isValidLayoutForRank
|
|
// ============================================================================
|
|
|
|
TEST(Layout, AllowsRowAndColOnlyForRank2)
{
    // Row and Col must be accepted at rank 2 and rejected at every lower rank.
    // Rank 0 is included so the lower boundary is pinned as well. Ranks above
    // 2 are covered by RejectsRowAndColForRankGreaterThan2.
    EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 0));
    EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 1));
    EXPECT_TRUE(isValidLayoutForRank(Layout::Row, 2));

    EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 0));
    EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 1));
    EXPECT_TRUE(isValidLayoutForRank(Layout::Col, 2));
}
|
|
|
|
TEST(Layout, RejectsAutoForAllRanks)
{
    // Auto must be rejected whatever rank it is paired with. Ranks 0, 1 and 2
    // are sampled here.
    constexpr Layout kAuto = Layout::Auto;

    EXPECT_FALSE(isValidLayoutForRank(kAuto, 0));
    EXPECT_FALSE(isValidLayoutForRank(kAuto, 1));
    EXPECT_FALSE(isValidLayoutForRank(kAuto, 2));
}
|
|
|
|
TEST(Layout, RejectsRowAndColForRankGreaterThan2)
{
    // Row and Col are accepted only at rank 2, so each sampled rank above 2
    // (3, 4 and 6) must be rejected for both layouts. Row is checked before
    // Col, and the ranks run in ascending order.
    for(const Layout layout : std::array{Layout::Row, Layout::Col})
    {
        // Stream the layout name so a failure says which layout misbehaved.
        EXPECT_FALSE(isValidLayoutForRank(layout, 3)) << layoutName(layout);
        EXPECT_FALSE(isValidLayoutForRank(layout, 4)) << layoutName(layout);
        EXPECT_FALSE(isValidLayoutForRank(layout, 6)) << layoutName(layout);
    }
}
|
|
|
|
// ============================================================================
|
|
// leadingDimStride
|
|
// ============================================================================
|
|
|
|
TEST(Layout, LeadingDimStrideReturnsFirstForRow)
{
    // With Row, the leading-dimension stride is the first element of the pair.
    constexpr std::array kRowStrides{128, 1};

    EXPECT_EQ(leadingDimStride(Layout::Row, kRowStrides), 128);
}
|
|
|
|
TEST(Layout, LeadingDimStrideReturnsSecondForCol)
{
    // With Col, the leading-dimension stride is the second element of the pair.
    constexpr std::array kColStrides{1, 64};

    EXPECT_EQ(leadingDimStride(Layout::Col, kColStrides), 64);
}
|
|
|
|
// ============================================================================
|
|
// layoutStrides
|
|
// ============================================================================
|
|
|
|
TEST(Layout, LayoutStridesRowMajor)
{
    // Row with extents (32, 64): the first dimension advances by the second
    // extent (64) and the second dimension has unit stride.
    const auto strides = layoutStrides(Layout::Row, 32, 64);

    EXPECT_THAT(strides, ElementsAre(64, 1));
}
|
|
|
|
TEST(Layout, LayoutStridesColMajor)
{
    // Col with extents (32, 64): the first dimension has unit stride and the
    // second dimension advances by the first extent (32).
    const auto strides = layoutStrides(Layout::Col, 32, 64);

    EXPECT_THAT(strides, ElementsAre(1, 32));
}
|
|
|
|
} // namespace
|