mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
[CK] Add rocm_ck foundation types: DataType, Layout, Args, Ops (#7114) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Add the vocabulary types that all rocm_ck schema headers build on - 9 new headers under `include/rocm_ck/`, 6 unit test files - Pure C++20, host-only — no CK Tile dependencies **Headers:** | Header | Purpose | |--------|---------| | `index_t.hpp` | `index_t`, `long_index_t` (matches ck_tile) | | `gpu_target.hpp` | `GpuTarget` enum (ISA targets) | | `datatype.hpp` | `DataType` enum (17 variants) | | `layout.hpp` | `Layout` enum (Row, Col, Auto) + stride helpers | | `fixed_string.hpp` | `FixedString<N>` — structural string for NTTPs | | `args.hpp` | Generic kernel argument buffer (ABI) | | `ops.hpp` | Operator structs (`GemmOp`, `AddOp`, ...) + `Op` variant | | `physical_tensor.hpp` | `PhysicalTensor` — maps names to Args slots | | `resolved_tensor.hpp` | `ResolvedTensor` — output of `Signature::resolve()` | **Stack**: This is PR 1 of 3 porting the rocm_ck constexpr schema from experimental to production, #7143. 1. **This PR** — Foundation types (vocabulary) 2. Schema engine — `Signature`, `resolve()`, `ArchProperties` 3. Spec factories — `GemmSpec`, `ElementwiseSpec`, `makeSpec()` ## Test plan - [ ] `ninja build-smoke-rocm-ck` builds all tests - [ ] `ctest -L ROCM_CK_SMOKE --output-on-failure` — 6 unit tests pass (86 test cases) - [ ] Default CK build (`CK_ENABLE_ROCM_CK=OFF`) unaffected 🤖 Generated with [Claude Code](https://claude.com/claude-code)
217 lines
6.0 KiB
C++
217 lines
6.0 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include <rocm_ck/args.hpp>
|
|
|
|
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <array>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <type_traits>
|
|
|
|
using ::rocm_ck::Args;
|
|
using ::rocm_ck::kMaxRank;
|
|
using ::rocm_ck::kMaxScalars;
|
|
using ::rocm_ck::kMaxTensors;
|
|
using ::rocm_ck::makeShape;
|
|
using ::rocm_ck::makeStrides;
|
|
using ::rocm_ck::ScalarValue;
|
|
using ::rocm_ck::TensorArg;
|
|
using ::testing::ElementsAre;
|
|
|
|
namespace {
|
|
|
|
// ============================================================================
|
|
// TensorArg ABI
|
|
// ============================================================================
|
|
|
|
TEST(TensorArg, IsTriviallyCopyable) { EXPECT_TRUE(std::is_trivially_copyable_v<TensorArg>); }
|
|
|
|
TEST(TensorArg, HasStandardLayout) { EXPECT_TRUE(std::is_standard_layout_v<TensorArg>); }
|
|
|
|
TEST(TensorArg, Occupies80Bytes)
|
|
{
|
|
// ptr(8) + lengths(6*4=24) + strides(6*8=48) = 80
|
|
EXPECT_EQ(sizeof(TensorArg), 80);
|
|
}
|
|
|
|
TEST(TensorArg, AlignsTo8Bytes) { EXPECT_EQ(alignof(TensorArg), 8); }
|
|
|
|
TEST(TensorArg, PlacesFieldsAtExpectedOffsets)
|
|
{
|
|
EXPECT_EQ(offsetof(TensorArg, ptr), 0);
|
|
EXPECT_EQ(offsetof(TensorArg, lengths), 8);
|
|
EXPECT_EQ(offsetof(TensorArg, strides), 32);
|
|
}
|
|
|
|
// ============================================================================
|
|
// ScalarValue ABI
|
|
// ============================================================================
|
|
|
|
TEST(ScalarValue, IsTriviallyCopyable) { EXPECT_TRUE(std::is_trivially_copyable_v<ScalarValue>); }
|
|
|
|
TEST(ScalarValue, Occupies8Bytes)
|
|
{
|
|
// Union of float(4), int32(4), uint32(4), double(8) -> 8 bytes
|
|
EXPECT_EQ(sizeof(ScalarValue), 8);
|
|
}
|
|
|
|
// ============================================================================
|
|
// Args ABI
|
|
// ============================================================================
|
|
|
|
TEST(Args, IsTriviallyCopyable) { EXPECT_TRUE(std::is_trivially_copyable_v<Args>); }
|
|
|
|
TEST(Args, HasStandardLayout) { EXPECT_TRUE(std::is_standard_layout_v<Args>); }
|
|
|
|
TEST(Args, Occupies1552Bytes)
|
|
{
|
|
// 16 tensors * 80 + 16 scalars * 8 + batch_count(4) + pad(4)
|
|
// + 16 batch_strides * 8 + workspace_ptr(8) = 1280 + 128 + 8 + 128 + 8 = 1552
|
|
EXPECT_EQ(sizeof(Args), 1552);
|
|
}
|
|
|
|
TEST(Args, AlignsTo8Bytes) { EXPECT_EQ(alignof(Args), 8); }
|
|
|
|
TEST(Args, FitsWithin4KBKernargBudget)
|
|
{
|
|
// HSA minimum kernarg size is 4096 bytes
|
|
EXPECT_LE(sizeof(Args), 4096);
|
|
}
|
|
|
|
// ============================================================================
|
|
// Capacity constants
|
|
// ============================================================================
|
|
|
|
TEST(Args, DefinesExpectedCapacityLimits)
|
|
{
|
|
EXPECT_EQ(kMaxRank, 6);
|
|
EXPECT_EQ(kMaxTensors, 16);
|
|
EXPECT_EQ(kMaxScalars, 16);
|
|
}
|
|
|
|
// ============================================================================
|
|
// ScalarValue union access
|
|
// ============================================================================
|
|
|
|
TEST(ScalarValue, StoresAndRetrievesFloat)
|
|
{
|
|
ScalarValue sv{};
|
|
sv.f32 = 3.14f;
|
|
EXPECT_FLOAT_EQ(sv.f32, 3.14f);
|
|
}
|
|
|
|
TEST(ScalarValue, StoresAndRetrievesInt32)
|
|
{
|
|
ScalarValue sv{};
|
|
sv.i32 = -42;
|
|
EXPECT_EQ(sv.i32, -42);
|
|
}
|
|
|
|
TEST(ScalarValue, StoresAndRetrievesDouble)
|
|
{
|
|
ScalarValue sv{};
|
|
sv.f64 = 2.718281828;
|
|
EXPECT_DOUBLE_EQ(sv.f64, 2.718281828);
|
|
}
|
|
|
|
TEST(ScalarValue, StoresAndRetrievesUInt32)
|
|
{
|
|
ScalarValue sv{};
|
|
sv.u32 = 0xDEADBEEF;
|
|
EXPECT_EQ(sv.u32, 0xDEADBEEF);
|
|
}
|
|
|
|
// ============================================================================
|
|
// Args field coverage — batch_strides and workspace_ptr
|
|
// ============================================================================
|
|
|
|
TEST(Args, BatchStridesFieldExists)
|
|
{
|
|
Args args{};
|
|
args.batch_strides[0] = 12345;
|
|
args.batch_strides[kMaxTensors - 1] = -99;
|
|
EXPECT_EQ(args.batch_strides[0], 12345);
|
|
EXPECT_EQ(args.batch_strides[kMaxTensors - 1], -99);
|
|
}
|
|
|
|
TEST(Args, WorkspacePtrFieldExists)
|
|
{
|
|
Args args{};
|
|
int dummy = 42;
|
|
args.workspace_ptr = &dummy;
|
|
EXPECT_EQ(args.workspace_ptr, &dummy);
|
|
}
|
|
|
|
TEST(Args, BatchCountFieldExists)
|
|
{
|
|
Args args{};
|
|
args.batch_count = 8;
|
|
EXPECT_EQ(args.batch_count, 8);
|
|
}
|
|
|
|
// ============================================================================
|
|
// Boundary access tests
|
|
// ============================================================================
|
|
|
|
TEST(Args, BoundaryAccessToTensors)
|
|
{
|
|
Args args{};
|
|
// Access last tensor slot (kMaxTensors - 1 = 15)
|
|
args.tensors[kMaxTensors - 1].ptr = nullptr;
|
|
EXPECT_EQ(args.tensors[kMaxTensors - 1].ptr, nullptr);
|
|
}
|
|
|
|
TEST(Args, BoundaryAccessToScalars)
|
|
{
|
|
Args args{};
|
|
// Access last scalar slot (kMaxScalars - 1 = 15)
|
|
args.scalars[kMaxScalars - 1].f32 = 1.0f;
|
|
EXPECT_FLOAT_EQ(args.scalars[kMaxScalars - 1].f32, 1.0f);
|
|
}
|
|
|
|
TEST(TensorArg, BoundaryAccessToLengthsAndStrides)
|
|
{
|
|
TensorArg ta{};
|
|
// Access last rank dimension (kMaxRank - 1 = 5)
|
|
ta.lengths[kMaxRank - 1] = 42;
|
|
ta.strides[kMaxRank - 1] = 99;
|
|
EXPECT_EQ(ta.lengths[kMaxRank - 1], 42);
|
|
EXPECT_EQ(ta.strides[kMaxRank - 1], 99);
|
|
}
|
|
|
|
// ============================================================================
|
|
// makeShape
|
|
// ============================================================================
|
|
|
|
TEST(MakeShape, ZeroFillsUnusedDimensions)
|
|
{
|
|
EXPECT_THAT(makeShape(128, 64), ElementsAre(128, 64, 0, 0, 0, 0));
|
|
}
|
|
|
|
TEST(MakeShape, FillsAllSixDimensions)
|
|
{
|
|
EXPECT_THAT(makeShape(2, 3, 4, 5, 6, 7), ElementsAre(2, 3, 4, 5, 6, 7));
|
|
}
|
|
|
|
TEST(MakeShape, SingleDimension) { EXPECT_THAT(makeShape(1024), ElementsAre(1024, 0, 0, 0, 0, 0)); }
|
|
|
|
// ============================================================================
|
|
// makeStrides
|
|
// ============================================================================
|
|
|
|
TEST(MakeStrides, ZeroFillsUnusedDimensions)
|
|
{
|
|
EXPECT_THAT(makeStrides(256, 1), ElementsAre(256, 1, 0, 0, 0, 0));
|
|
}
|
|
|
|
TEST(MakeStrides, HandlesLargeInt64Values)
|
|
{
|
|
constexpr int64_t large = 1LL << 40;
|
|
EXPECT_THAT(makeStrides(large, 1), ElementsAre(large, 1, 0, 0, 0, 0));
|
|
}
|
|
|
|
} // namespace
|