[rocm-libraries] ROCm/rocm-libraries#7114 (commit ecef372)

[CK] Add rocm_ck foundation types: DataType, Layout, Args, Ops (#7114)

## Summary

- Add the vocabulary types that all rocm_ck schema headers build on
- 9 new headers under `include/rocm_ck/`, 6 unit test files
- Pure C++20, host-only — no CK Tile dependencies

**Headers:**

| Header | Purpose |
|--------|---------|
| `index_t.hpp` | `index_t`, `long_index_t` (matches ck_tile) |
| `gpu_target.hpp` | `GpuTarget` enum (ISA targets) |
| `datatype.hpp` | `DataType` enum (17 variants) |
| `layout.hpp` | `Layout` enum (Row, Col, Auto) + stride helpers |
| `fixed_string.hpp` | `FixedString<N>` — structural string for NTTPs |
| `args.hpp` | Generic kernel argument buffer (ABI) |
| `ops.hpp` | Operator structs (`GemmOp`, `AddOp`, ...) + `Op` variant |
| `physical_tensor.hpp` | `PhysicalTensor` — maps names to Args slots |
| `resolved_tensor.hpp` | `ResolvedTensor` — output of `Signature::resolve()` |

**Stack**: This is PR 1 of 3 porting the rocm_ck constexpr schema from
experimental to production (see #7143).
1. **This PR** — Foundation types (vocabulary)
2. Schema engine — `Signature`, `resolve()`, `ArchProperties`
3. Spec factories — `GemmSpec`, `ElementwiseSpec`, `makeSpec()`

## Test plan

- [ ] `ninja build-smoke-rocm-ck` builds all tests
- [ ] `ctest -L ROCM_CK_SMOKE --output-on-failure` — 6 unit tests pass (86 test cases)
- [ ] Default CK build (`CK_ENABLE_ROCM_CK=OFF`) unaffected

🤖 Generated with [Claude Code](https://claude.com/claude-code)
John Shumway
2026-05-15 19:22:44 +00:00
committed by assistant-librarian[bot]
parent 187ef8ac94
commit 3e110e1718
16 changed files with 1072 additions and 24 deletions


@@ -3,11 +3,10 @@
A C++20 constexpr API for configuring and distributing
[CK Tile](../include/ck_tile/) GPU kernels across multiple architectures.
-> **Status**: Early development. The current code establishes the directory
-> structure, build integration, and CI pipeline. A single unit test verifies
-> that the build and test infrastructure works end-to-end in Jenkins.
-> The schema types, device bridge, and kernel tests described below are
-> under active development.
+> **Status**: Early development. Foundation types are in place (DataType,
+> Layout, Args, operators, FixedString, PhysicalTensor, ResolvedTensor).
+> The schema engine (Signature, resolve(), Algorithm) and device bridge
+> are under active development.
## Why rocm_ck exists


@@ -0,0 +1,89 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Role: abi — shared between host and device. Trivially copyable, no CK deps.
//
// Args is a hardware buffer for passing data between CPU and GPU during a
// kernel call. It carries raw pointers, shapes, strides, and scalar values —
// nothing more. All semantic meaning (which tensor is "A", which scalar is
// "alpha", input vs output) lives in the Signature, not here.
//
// This is deliberately one type for all operations. Per-operation structs
// (GemmArgs, FmhaArgs, ...) would make the dispatcher a closed set — adding
// an operation means adding a type, updating launch code, and changing the
// kpack format. A generic buffer keeps the dispatcher open.
//
// Capacity limits (kMaxRank=6, kMaxTensors=16, kMaxScalars=16) are sized to
// the most demanding current operations (FMHA backward: ~12 tensors; FMHA
// with masking+dropout: ~12 scalars; grouped 3D conv: rank 6). If a future
// operation exceeds these, bump the constants: the layout is not versioned,
// and the 4KB HSA kernarg budget has room. Don't over-provision speculatively.
//
// Key constraints:
// - Trivially copyable, standard layout — required for HSA kernarg passing.
// - Fixed-capacity arrays, no heap — sizeof fits the 4KB kernarg budget.
// - const void* for all tensor pointers — the entry kernel casts to the
// concrete type. Input vs output semantics live in the Signature.
// - No runtime type tags on scalars — the Signature declares types at
// compile time. The entry kernel reads the correct union member.
// - Slot ordering is the invariant: tensors[i] maps to Signature::tensors[i].
#pragma once
#include <rocm_ck/index_t.hpp>
#include <array>
#include <cstdint>
namespace rocm_ck {
// When changing these, update the byte-size comments on TensorArg and Args fields.
constexpr int kMaxRank = 6; // grouped 3D conv = GNCDHW = rank 6
constexpr int kMaxTensors = 16; // FMHA backward uses ~12
constexpr int kMaxScalars = 16; // FMHA with masking+dropout needs ~12
struct TensorArg
{
const void* ptr; // 8 bytes (offset 0)
std::array<index_t, kMaxRank> lengths; // 24 bytes (offset 8) — int32
std::array<long_index_t, kMaxRank> strides; // 48 bytes (offset 32) — int64
};
// FP16/BF16/FP8 scalars use f32 — scalar precision >= tensor precision.
union ScalarValue
{
float f32;
int32_t i32;
uint32_t u32;
double f64;
int64_t i64;
uint64_t u64;
};
// Slot ordering matches Signature: tensors[i] <-> Signature::tensors[i].
struct Args
{
std::array<TensorArg, kMaxTensors> tensors; // 16 x 80 = 1280 bytes
std::array<ScalarValue, kMaxScalars> scalars; // 16 x 8 = 128 bytes
index_t batch_count = 0; // 4 bytes
std::array<long_index_t, kMaxTensors> batch_strides = {}; // 16 x 8 = 128 bytes
void* workspace_ptr = nullptr; // 8 bytes
};
constexpr std::array<index_t, kMaxRank> makeShape(
index_t d0, index_t d1 = 0, index_t d2 = 0, index_t d3 = 0, index_t d4 = 0, index_t d5 = 0)
{
return {d0, d1, d2, d3, d4, d5};
}
constexpr std::array<long_index_t, kMaxRank> makeStrides(long_index_t s0,
long_index_t s1 = 0,
long_index_t s2 = 0,
long_index_t s3 = 0,
long_index_t s4 = 0,
long_index_t s5 = 0)
{
return {s0, s1, s2, s3, s4, s5};
}
} // namespace rocm_ck
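
The slot layout described in `args.hpp` can be exercised on the host without a GPU. The sketch below is illustrative only: it redeclares `TensorArg` locally, assumes `index_t` is `int32_t` and `long_index_t` is `int64_t` (as the byte-size comments suggest), and introduces a hypothetical helper `packRowMajor` that is not part of this PR.

```cpp
#include <array>
#include <cstdint>
#include <type_traits>

// Local stand-ins for the rocm_ck types (assumed widths: int32 / int64).
using index_t      = std::int32_t;
using long_index_t = std::int64_t;
constexpr int kMaxRank = 6;

struct TensorArg
{
    const void* ptr;
    std::array<index_t, kMaxRank> lengths;
    std::array<long_index_t, kMaxRank> strides;
};

// Hypothetical helper (not in the PR): describe a dense row-major
// rows x cols matrix. Unused trailing dimensions stay zero.
inline TensorArg packRowMajor(const void* data, index_t rows, index_t cols)
{
    TensorArg t{}; // zero-fills pointer, lengths, and strides
    t.ptr        = data;
    t.lengths[0] = rows;
    t.lengths[1] = cols;
    t.strides[0] = cols; // stepping one row skips `cols` elements
    t.strides[1] = 1;    // elements within a row are contiguous
    return t;
}

// ABI properties the header relies on (checked here for a 64-bit host).
static_assert(std::is_trivially_copyable_v<TensorArg>);
static_assert(std::is_standard_layout_v<TensorArg>);
static_assert(sizeof(TensorArg) == 80, "8 (ptr) + 6*4 (lengths) + 6*8 (strides)");
```

On a 64-bit host the struct needs no padding: the 8-byte pointer plus 24 bytes of lengths place `strides` at offset 32, which is already 8-byte aligned.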


@@ -0,0 +1,101 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Role: types — DataType enum, constexpr queries. No runtime, no CK deps.
#pragma once
#include "rocm_ck/platform.hpp"
#include <cstdint>
namespace rocm_ck {
// FP8 = e4m3, BF8 = e5m2 (CK convention).
enum class DataType : uint8_t
{
// Floating point — standard widths
FP64,
FP32,
FP16,
BF16,
// FP8 variants — see note below
FP8_FNUZ,
BF8_FNUZ,
FP8_OCP,
BF8_OCP,
// Integer types — signed and unsigned at each width
I4,
I8,
I16,
I32,
I64,
U8,
U16,
U32,
U64
};
// FP8 variants — FNUZ and OCP are different number formats, not just HW hints.
// FNUZ: gfx942 native (higher bias, no Inf, e4m3 max 240)
// OCP: gfx950 native (OCP standard, e4m3 max 448; only e5m2 has Inf)
// Non-native formats run in software (slower) and produce different numerical
// results. Choose based on target GPU and model training format.
// We keep FNUZ and OCP explicit rather than a generic FP8 — the numerical
// differences matter for compatibility and schema-driven test coverage.
// TODO - We may introduce a generic FP8/BF8 that resolves to the hardware-native type.
// See: https://rocm.docs.amd.com/projects/HIP/en/latest/reference/fp8_numbers.html
// Bits (not bytes) so sub-byte types (I4) are clean integers.
constexpr int dataTypeBits(DataType dt)
{
switch(dt)
{
case DataType::FP64: return 64;
case DataType::FP32: return 32;
case DataType::FP16: return 16;
case DataType::BF16: return 16;
case DataType::FP8_FNUZ: return 8;
case DataType::BF8_FNUZ: return 8;
case DataType::FP8_OCP: return 8;
case DataType::BF8_OCP: return 8;
case DataType::I4: return 4;
case DataType::I8: return 8;
case DataType::I16: return 16;
case DataType::I32: return 32;
case DataType::I64: return 64;
case DataType::U8: return 8;
case DataType::U16: return 16;
case DataType::U32: return 32;
case DataType::U64: return 64;
}
ROCM_CK_UNREACHABLE();
}
constexpr const char* dataTypeName(DataType dt)
{
switch(dt)
{
case DataType::FP64: return "FP64";
case DataType::FP32: return "FP32";
case DataType::FP16: return "FP16";
case DataType::BF16: return "BF16";
case DataType::FP8_FNUZ: return "FP8_FNUZ";
case DataType::BF8_FNUZ: return "BF8_FNUZ";
case DataType::FP8_OCP: return "FP8_OCP";
case DataType::BF8_OCP: return "BF8_OCP";
case DataType::I4: return "I4";
case DataType::I8: return "I8";
case DataType::I16: return "I16";
case DataType::I32: return "I32";
case DataType::I64: return "I64";
case DataType::U8: return "U8";
case DataType::U16: return "U16";
case DataType::U32: return "U32";
case DataType::U64: return "U64";
}
ROCM_CK_UNREACHABLE();
}
} // namespace rocm_ck
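
The choice to report bits rather than bytes matters as soon as a buffer size is computed for a sub-byte type. The following sketch assumes densely packed elements and uses a reduced stand-in for `DataType` with only three variants; `packedBytes` is a hypothetical helper, not something this PR provides.

```cpp
#include <cstdint>

// Reduced stand-in for rocm_ck::DataType (three of the 17 variants).
enum class DataType : std::uint8_t { FP16, I8, I4 };

constexpr int dataTypeBits(DataType dt)
{
    switch(dt)
    {
    case DataType::FP16: return 16;
    case DataType::I8: return 8;
    case DataType::I4: return 4;
    }
    return 0; // not reached for valid enumerators
}

// Hypothetical helper (not in the PR): bytes needed for `count` densely
// packed elements, rounded up to a whole byte.
constexpr std::int64_t packedBytes(DataType dt, std::int64_t count)
{
    return (count * dataTypeBits(dt) + 7) / 8;
}

static_assert(packedBytes(DataType::FP16, 10) == 20);
static_assert(packedBytes(DataType::I4, 10) == 5);
static_assert(packedBytes(DataType::I4, 3) == 2); // 12 bits round up to 2 bytes
```

With a bytes-per-element query, `I4` would have to report either 0 or a fraction; keeping the unit in bits leaves the rounding decision to the caller.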


@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Role: types — FixedString. No runtime, no CK deps.
//
// A compile-time string for use in template parameters (NTTPs).
//
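// C++20 requires class-type template parameters to be "structural types":
// loosely, literal types whose base classes and data members are all public,
// non-mutable, and themselves structural. std::string and std::string_view
// fail this requirement (private members), and a pointer into a string
// literal is not a valid template argument in any case.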
//
// FixedString stores the string inline in a char array, making it structural:
//
// template <PhysicalTensor PT> // PhysicalTensor contains FixedString<16>
// void dispatch() { ... }
//
// When to use FixedString vs std::string_view:
// - FixedString: the type must be structural (template parameters).
// - string_view: consteval-only types that never become template parameters
// (e.g., ResolvedTensor — see resolved_tensor.hpp).
//
// The capacity is a template parameter so each use site documents its limit:
// FixedString<16> name("bias"); // tensor names: 15 chars max
#pragma once
#include <cstddef>
#include <string_view>
namespace rocm_ck {
template <std::size_t MaxLen>
struct FixedString
{
char data[MaxLen]{};
int len = 0;
constexpr FixedString() = default;
constexpr FixedString(std::string_view sv) : len(static_cast<int>(sv.size()))
{
if(sv.size() > MaxLen - 1)
throw "FixedString: input exceeds capacity";
for(int i = 0; i < len; ++i)
data[i] = sv[i];
}
constexpr bool operator==(std::string_view sv) const
{
if(len != static_cast<int>(sv.size()))
return false;
for(int i = 0; i < len; ++i)
if(data[i] != sv[i])
return false;
return true;
}
// Required: the string_view overload above suppresses the implicit == from <=>.
constexpr bool operator==(const FixedString&) const = default;
constexpr auto operator<=>(const FixedString&) const = default;
};
} // namespace rocm_ck
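
The capacity rule (at most `MaxLen - 1` characters) can be seen in a small host-side sketch. The struct below is a trimmed copy of `FixedString` with the defaulted comparison operators left out so that it also compiles as C++17; the constants `kBias` and the commented-out `kTooLong` are illustrative names only.

```cpp
#include <cstddef>
#include <string_view>

// Trimmed copy of FixedString: storage, converting constructor, and the
// string_view comparison. The defaulted == and <=> are omitted here.
template <std::size_t MaxLen>
struct FixedString
{
    char data[MaxLen]{};
    int len = 0;

    constexpr FixedString() = default;
    constexpr FixedString(std::string_view sv) : len(static_cast<int>(sv.size()))
    {
        if(sv.size() > MaxLen - 1)
            throw "FixedString: input exceeds capacity";
        for(int i = 0; i < len; ++i)
            data[i] = sv[i];
    }

    constexpr bool operator==(std::string_view sv) const
    {
        if(len != static_cast<int>(sv.size()))
            return false;
        for(int i = 0; i < len; ++i)
            if(data[i] != sv[i])
                return false;
        return true;
    }
};

// Evaluated entirely at compile time.
constexpr FixedString<16> kBias("bias");
static_assert(kBias == "bias");
static_assert(kBias.len == 4);

// In a constant expression the throw becomes a compile error:
// constexpr FixedString<4> kTooLong("bias"); // 4 chars need MaxLen >= 5
```

When the constructor runs outside a constant expression, the same check surfaces as a thrown `const char*` instead of a compile error.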


@@ -0,0 +1,24 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Role: types — GpuTarget enum. No runtime, no CK deps.
#pragma once
#include <cstdint>
namespace rocm_ck {
// ISA target identifiers (matching -mcpu flags), not marketing names.
enum class GpuTarget : uint8_t
{
gfx90a, // CDNA 2
gfx942, // CDNA 3
gfx950, // CDNA 4
gfx1100, // RDNA 3
gfx1101, // RDNA 3
gfx1102, // RDNA 3
gfx1150, // RDNA 3.5
gfx1151, // RDNA 3.5
};
} // namespace rocm_ck


@@ -0,0 +1,69 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Role: types — Layout enum, constexpr/consteval helpers. No runtime, no CK deps.
#pragma once
#include "rocm_ck/platform.hpp"
#include <array>
#include <cstddef>
#include <cstdint>
namespace rocm_ck {
// Auto is a resolve-time placeholder — Signature::resolve() replaces it with
// the concrete layout from the operator slot. It never reaches the kernel.
enum class Layout : uint8_t
{
Row,
Col,
Auto
};
constexpr const char* layoutName(Layout layout)
{
switch(layout)
{
case Layout::Row: return "Row";
case Layout::Col: return "Col";
case Layout::Auto: return "Auto";
}
ROCM_CK_UNREACHABLE();
}
constexpr bool isValidLayoutForRank(Layout layout, int rank)
{
switch(layout)
{
case Layout::Row: return rank == 2;
case Layout::Col: return rank == 2;
case Layout::Auto: return false;
}
ROCM_CK_UNREACHABLE();
}
template <typename T, std::size_t N>
constexpr T leadingDimStride(Layout layout, const std::array<T, N>& strides)
{
switch(layout)
{
case Layout::Row: return strides[0];
case Layout::Col: return strides[1];
case Layout::Auto: throw "leadingDimStride requires Row or Col layout";
}
ROCM_CK_UNREACHABLE();
}
constexpr std::array<int, 2> layoutStrides(Layout layout, int rows, int cols)
{
switch(layout)
{
case Layout::Row: return {cols, 1};
case Layout::Col: return {1, rows};
case Layout::Auto: throw "layoutStrides requires Row or Col layout";
}
ROCM_CK_UNREACHABLE();
}
} // namespace rocm_ck
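
To make the stride convention concrete, the sketch below turns the strides returned for each layout into a linear element offset. It reproduces only the `Row` and `Col` cases with local stand-ins, and `elementOffset` is a hypothetical helper that does not exist in `layout.hpp`.

```cpp
#include <array>
#include <cstdint>

// Stand-in covering the two concrete layouts (Auto is omitted).
enum class Layout : std::uint8_t { Row, Col };

// Same values as layoutStrides in layout.hpp: {row stride, column stride}.
constexpr std::array<int, 2> layoutStrides(Layout layout, int rows, int cols)
{
    return layout == Layout::Row ? std::array<int, 2>{cols, 1}
                                 : std::array<int, 2>{1, rows};
}

// Hypothetical helper (not in the PR): linear offset of element (r, c).
constexpr int elementOffset(Layout layout, int rows, int cols, int r, int c)
{
    const auto s = layoutStrides(layout, rows, cols);
    return r * s[0] + c * s[1];
}

// Element (2, 5) of a 32 x 64 matrix.
static_assert(elementOffset(Layout::Row, 32, 64, 2, 5) == 2 * 64 + 5);
static_assert(elementOffset(Layout::Col, 32, 64, 2, 5) == 2 + 5 * 32);
```

The non-unit entry of each stride pair is the leading dimension, which is why `leadingDimStride` reads index 0 for `Row` and index 1 for `Col`.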


@@ -0,0 +1,139 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Role: meta — operator structs, Op variant. No runtime, no CK deps.
//
// Operators are the edges of a Signature's compute graph. Each operator
// names its tensor slots as string_view labels (e.g., "A", "bias", "query")
// that refer to tensors declared elsewhere in the Signature. The Signature
// owns the tensor definitions; operators just reference them by name.
//
// This separation means operators are reusable across different tensor
// configurations — a GemmOp doesn't care whether its "lhs" is FP16 or BF16,
// Row or Col. That's resolved later when the Signature is validated.
//
// The Op variant is the closed set of supported operator types. Adding a
// new operator means adding a struct here and a variant alternative.
// Fused operations (like FMHA) are single operators — not chains of
// elementwise + GEMM — because CK Tile implements them as monolithic kernels.
#pragma once
#include <rocm_ck/datatype.hpp>
#include <string_view>
#include <variant>
namespace rocm_ck {
// Matrix multiplication: out = lhs x rhs.
// acc_dtype is the accumulation type — defaults to FP32, the universal safe
// choice across all input types.
struct GemmOp
{
std::string_view lhs;
std::string_view rhs;
std::string_view out;
DataType acc_dtype = DataType::FP32;
};
// Element-wise addition: out = lhs + rhs.
struct AddOp
{
std::string_view lhs;
std::string_view rhs;
std::string_view out;
};
// Element-wise multiplication: out = lhs * rhs.
struct MulOp
{
std::string_view lhs;
std::string_view rhs;
std::string_view out;
};
// ReLU activation: out = max(0, in).
struct ReluOp
{
std::string_view in;
std::string_view out;
};
// Fast GELU approximation: out = in * sigmoid(1.702 * in).
struct FastGeluOp
{
std::string_view in;
std::string_view out;
};
// Exact GELU: out = 0.5 * in * (1 + erf(in / sqrt(2))).
struct GeluOp
{
std::string_view in;
std::string_view out;
};
// SiLU (Swish) activation: out = in * sigmoid(in).
struct SiluOp
{
std::string_view in;
std::string_view out;
};
// Sigmoid activation: out = 1 / (1 + exp(-in)).
struct SigmoidOp
{
std::string_view in;
std::string_view out;
};
// Softmax: out[i] = exp(in[i]) / sum(exp(in)), reduction along last dimension.
struct SoftmaxOp
{
std::string_view in;
std::string_view out;
};
// Scalar multiply: out = in * scale.
// 'scale' names a Scalar in the Signature, not a tensor.
struct ScaleOp
{
std::string_view in;
std::string_view out;
std::string_view scale;
};
// Fused multi-head attention backward pass.
// Implemented as a single CK Tile kernel, not a chain of ops.
// Feature flags (mask, dropout, bias, deterministic) belong in the Algorithm.
struct FmhaBwdOp
{
std::string_view q; // query
std::string_view k; // key
std::string_view v; // value
std::string_view lse; // log-sum-exp from forward pass
std::string_view do_; // output gradient
std::string_view d; // dot(output_grad, output)
std::string_view dq; // query gradient
std::string_view dk; // key gradient
std::string_view dv; // value gradient
DataType acc_dtype = DataType::FP32;
};
// The closed set of supported operators. std::monostate marks empty slots.
using Op = std::variant<std::monostate,
GemmOp,
AddOp,
MulOp,
ReluOp,
FastGeluOp,
GeluOp,
SiluOp,
SigmoidOp,
SoftmaxOp,
ScaleOp,
FmhaBwdOp>;
} // namespace rocm_ck
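
Because `Op` is a closed `std::variant`, code that consumes operators would typically dispatch with `std::visit`. A minimal sketch follows, assuming a reduced variant with three of the operator structs; `outputOf` is a hypothetical helper and is not defined anywhere in this PR.

```cpp
#include <string_view>
#include <type_traits>
#include <variant>

// Reduced stand-ins for three operator structs and the Op variant.
struct GemmOp { std::string_view lhs, rhs, out; };
struct ReluOp { std::string_view in, out; };
struct ScaleOp { std::string_view in, out, scale; };
using Op = std::variant<std::monostate, GemmOp, ReluOp, ScaleOp>;

// Hypothetical helper (not in the PR): name of the tensor an operator
// writes, or an empty view for an empty slot (std::monostate).
inline std::string_view outputOf(const Op& op)
{
    return std::visit(
        [](const auto& o) -> std::string_view {
            using T = std::decay_t<decltype(o)>;
            if constexpr(std::is_same_v<T, std::monostate>)
                return {};
            else
                return o.out; // every operator struct here has an `out` slot
        },
        op);
}
```

The generic lambda works only because every alternative in this reduced variant exposes an `out` member; an operator with several outputs, such as `FmhaBwdOp`, would need its own branch.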


@@ -0,0 +1,29 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Role: types — PhysicalTensor. No runtime, no CK deps.
//
// A PhysicalTensor maps a named tensor from the Signature graph to a slot
// in the generic Args buffer. Not every tensor in a compute graph is physical —
// intermediate values (e.g., the S matrix in FMHA = Q*K^T) live only in
// registers and never appear in device memory. The physical tensor table
// describes exactly what the host needs to pack into Args.
#pragma once
#include <rocm_ck/datatype.hpp>
#include <rocm_ck/fixed_string.hpp>
#include <rocm_ck/layout.hpp>
namespace rocm_ck {
inline constexpr int kMaxPhysicalTensors = 8;
struct PhysicalTensor
{
FixedString<16> name;
DataType dtype = DataType::FP32;
Layout layout = Layout::Row;
int args_slot = 0;
};
} // namespace rocm_ck
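
One way to picture the physical tensor table is as a small compile-time array searched by name. The sketch below is a simplification under stated assumptions: it stores the name as a `std::string_view` instead of `FixedString<16>`, drops `dtype` and `layout`, and uses a hypothetical lookup `findSlot` and table `kGemmTensors` that are not part of the PR.

```cpp
#include <array>
#include <cstddef>
#include <string_view>

// Simplified stand-in: the real PhysicalTensor uses FixedString<16> for the
// name and also carries dtype and layout.
struct PhysicalTensor
{
    std::string_view name;
    int args_slot = 0;
};

// Hypothetical lookup (not in the PR): Args slot for `name`, or -1 if the
// tensor has no physical slot.
template <std::size_t N>
constexpr int findSlot(const std::array<PhysicalTensor, N>& table, std::string_view name)
{
    for(const auto& pt : table)
        if(pt.name == name)
            return pt.args_slot;
    return -1;
}

// Illustrative table for a GEMM with operands A, B and result C.
constexpr std::array<PhysicalTensor, 3> kGemmTensors{{{"A", 0}, {"B", 1}, {"C", 2}}};
static_assert(findSlot(kGemmTensors, "B") == 1);
static_assert(findSlot(kGemmTensors, "S") == -1); // register-only intermediate
```

An intermediate such as the FMHA `S` matrix never appears in the table, so the host has nothing to pack into `Args` for it.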


@@ -0,0 +1,13 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Compiler portability macros for LLVM/Clang/GCC and MSVC.
// C++23 provides std::unreachable(): https://en.cppreference.com/w/cpp/utility/unreachable
#pragma once
#ifdef _MSC_VER
#define ROCM_CK_UNREACHABLE() __assume(false)
#else
#define ROCM_CK_UNREACHABLE() __builtin_unreachable()
#endif


@@ -0,0 +1,59 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Role: types — ResolvedTensor, ResolvedQuantization. No runtime, no CK deps.
//
// ResolvedTensor is the intermediate result of consteval resolution. It exists
// only at compile time — produced by Signature::resolve() and consumed by
// makeSpec(), both consteval. It never appears in compiled code.
//
// In the user-facing Signature, tensors can have Layout::Auto (inherit from
// operator slot) and omit fields with sensible defaults. After resolution,
// every field is concrete. The base fields (name, dtype, rank, layout)
// describe a plain dense tensor — enough for most operands (GEMM inputs,
// outputs, bias vectors). Some tensors carry additional metadata beyond the
// dense description. Block-quantized tensors (e.g., INT4 weights) need a
// scale tensor and group size. We use optional sub-structs for these
// extensions, keeping the common case clean without bloating every instance.
//
// Why std::string_view instead of FixedString?
// ResolvedTensor is consteval-only — produced and consumed entirely at
// compile time. No library loading, no runtime lifetime concerns. The
// string_views point to string literals from user code (e.g.,
// GemmOp{.lhs = "A"}), which have static storage duration — no dangling.
// FixedString is required for PhysicalTensor because it IS used as a
// template parameter (NTTP), which requires a structural type; string_view
// is not one. ResolvedTensor is never a template parameter.
//
// Plain aggregate — no methods, no validation. Resolution validates; this
// type just carries the result to makeSpec().
#pragma once
#include <rocm_ck/datatype.hpp>
#include <rocm_ck/layout.hpp>
#include <optional>
#include <string_view>
namespace rocm_ck {
// Present when a tensor carries block-quantized data (e.g., INT4 weights).
// The scale tensor is a separate entry in the Signature; this struct ties
// the quantized tensor to its scale.
struct ResolvedQuantization
{
std::string_view scale_name;
DataType scale_dtype;
int group_size; // elements per quantization group
};
struct ResolvedTensor
{
std::string_view name;
DataType dtype;
int rank = 2;
Layout layout = Layout::Row;
std::optional<ResolvedQuantization> quantize = std::nullopt;
};
} // namespace rocm_ck
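
`Signature::resolve()` arrives in a later PR, so its exact behavior is not shown here. As a rough illustration of the "Auto inherits from the operator slot" rule described above, the sketch below assumes that an explicitly declared layout is kept unchanged and that the slot always supplies a concrete layout; `resolveLayout` is a hypothetical function, not the real API.

```cpp
#include <cstdint>

// Stand-in for rocm_ck::Layout.
enum class Layout : std::uint8_t { Row, Col, Auto };

// Hypothetical rule (assumption, not the PR's implementation): a declared
// Auto takes the operator slot's layout; an explicit layout is kept as-is.
constexpr Layout resolveLayout(Layout declared, Layout slot_layout)
{
    if(slot_layout == Layout::Auto)
        throw "operator slot must supply a concrete layout";
    return declared == Layout::Auto ? slot_layout : declared;
}

// After resolution no Auto remains.
static_assert(resolveLayout(Layout::Auto, Layout::Col) == Layout::Col);
static_assert(resolveLayout(Layout::Row, Layout::Col) == Layout::Row);
```

Whether the real resolver keeps or rejects an explicit layout that disagrees with the slot is not stated in this PR, so that branch should be read as a placeholder.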


@@ -4,7 +4,7 @@
# rocm_ck tests
#
# Test tiers:
-# ROCM_CK_SMOKE — Fast host-only tests (< 1s total). No GPU, no HIP.
+# ROCM_CK_SMOKE — Fast host-only tests (< 1s total). No GPU required.
# ROCM_CK_KERNEL — GPU kernel tests. Require HIP and a GPU.
#
# Usage:
@@ -22,37 +22,33 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/../../cmake/gtest.cmake)
# ---------------------------------------------------------------------------
function(add_rocm_ck_test test_name)
add_executable(${test_name} ${ARGN})
target_link_libraries(${test_name} PRIVATE rocm_ck GTest::gtest_main)
target_link_libraries(${test_name} PRIVATE rocm_ck GTest::gtest_main GTest::gmock)
target_compile_options(${test_name} PRIVATE
-Wno-global-constructors # GTest registration macros
-Wno-undef # GTest internal headers
-Wno-global-constructors # GTest registration macros
-Wno-undef # GTest internal headers
-Wno-zero-as-null-pointer-constant # C++20 <=> comparisons to 0
)
endfunction()
# ---------------------------------------------------------------------------
# Smoke tests (fast, host-only, no GPU)
# ---------------------------------------------------------------------------
set(ROCM_CK_SMOKE_TESTS
add_rocm_ck_test(rocm_ck_unit
unit/unit_args.cpp
unit/unit_datatype.cpp
unit/unit_fixed_string.cpp
unit/unit_index_t.cpp
unit/unit_layout.cpp
unit/unit_physical_tensor.cpp
)
set(ROCM_CK_SMOKE_TARGETS)
foreach(test_source ${ROCM_CK_SMOKE_TESTS})
get_filename_component(test_name ${test_source} NAME_WLE)
set(target_name "rocm_ck_${test_name}")
add_rocm_ck_test(${target_name} ${test_source})
add_test(NAME ${target_name} COMMAND ${target_name})
set_tests_properties(${target_name} PROPERTIES LABELS "ROCM_CK_SMOKE")
list(APPEND ROCM_CK_SMOKE_TARGETS ${target_name})
endforeach()
# rocm_ck_unit_index_t verifies rocm_ck index types match ck_tile
target_include_directories(rocm_ck_unit_index_t PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
target_include_directories(rocm_ck_unit PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
add_test(NAME rocm_ck_unit COMMAND rocm_ck_unit)
set_tests_properties(rocm_ck_unit PROPERTIES LABELS "ROCM_CK_SMOKE")
# ---------------------------------------------------------------------------
# Convenience targets
# ---------------------------------------------------------------------------
add_custom_target(build-smoke-rocm-ck DEPENDS ${ROCM_CK_SMOKE_TARGETS})
add_custom_target(build-smoke-rocm-ck DEPENDS rocm_ck_unit)
add_custom_target(smoke-rocm-ck
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -L "ROCM_CK_SMOKE"


@@ -0,0 +1,216 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <rocm_ck/args.hpp>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
using ::rocm_ck::Args;
using ::rocm_ck::kMaxRank;
using ::rocm_ck::kMaxScalars;
using ::rocm_ck::kMaxTensors;
using ::rocm_ck::makeShape;
using ::rocm_ck::makeStrides;
using ::rocm_ck::ScalarValue;
using ::rocm_ck::TensorArg;
using ::testing::ElementsAre;
namespace {
// ============================================================================
// TensorArg ABI
// ============================================================================
TEST(TensorArg, IsTriviallyCopyable) { EXPECT_TRUE(std::is_trivially_copyable_v<TensorArg>); }
TEST(TensorArg, HasStandardLayout) { EXPECT_TRUE(std::is_standard_layout_v<TensorArg>); }
TEST(TensorArg, Occupies80Bytes)
{
// ptr(8) + lengths(6*4=24) + strides(6*8=48) = 80
EXPECT_EQ(sizeof(TensorArg), 80);
}
TEST(TensorArg, AlignsTo8Bytes) { EXPECT_EQ(alignof(TensorArg), 8); }
TEST(TensorArg, PlacesFieldsAtExpectedOffsets)
{
EXPECT_EQ(offsetof(TensorArg, ptr), 0);
EXPECT_EQ(offsetof(TensorArg, lengths), 8);
EXPECT_EQ(offsetof(TensorArg, strides), 32);
}
// ============================================================================
// ScalarValue ABI
// ============================================================================
TEST(ScalarValue, IsTriviallyCopyable) { EXPECT_TRUE(std::is_trivially_copyable_v<ScalarValue>); }
TEST(ScalarValue, Occupies8Bytes)
{
// Union of float(4), int32(4), uint32(4), double(8), int64(8), uint64(8) -> 8 bytes
EXPECT_EQ(sizeof(ScalarValue), 8);
}
// ============================================================================
// Args ABI
// ============================================================================
TEST(Args, IsTriviallyCopyable) { EXPECT_TRUE(std::is_trivially_copyable_v<Args>); }
TEST(Args, HasStandardLayout) { EXPECT_TRUE(std::is_standard_layout_v<Args>); }
TEST(Args, Occupies1552Bytes)
{
// 16 tensors * 80 + 16 scalars * 8 + batch_count(4) + pad(4)
// + 16 batch_strides * 8 + workspace_ptr(8) = 1280 + 128 + 8 + 128 + 8 = 1552
EXPECT_EQ(sizeof(Args), 1552);
}
TEST(Args, AlignsTo8Bytes) { EXPECT_EQ(alignof(Args), 8); }
TEST(Args, FitsWithin4KBKernargBudget)
{
// args.hpp budgets 4096 bytes (4KB) of HSA kernarg space for Args
EXPECT_LE(sizeof(Args), 4096);
}
// ============================================================================
// Capacity constants
// ============================================================================
TEST(Args, DefinesExpectedCapacityLimits)
{
EXPECT_EQ(kMaxRank, 6);
EXPECT_EQ(kMaxTensors, 16);
EXPECT_EQ(kMaxScalars, 16);
}
// ============================================================================
// ScalarValue union access
// ============================================================================
TEST(ScalarValue, StoresAndRetrievesFloat)
{
ScalarValue sv{};
sv.f32 = 3.14f;
EXPECT_FLOAT_EQ(sv.f32, 3.14f);
}
TEST(ScalarValue, StoresAndRetrievesInt32)
{
ScalarValue sv{};
sv.i32 = -42;
EXPECT_EQ(sv.i32, -42);
}
TEST(ScalarValue, StoresAndRetrievesDouble)
{
ScalarValue sv{};
sv.f64 = 2.718281828;
EXPECT_DOUBLE_EQ(sv.f64, 2.718281828);
}
TEST(ScalarValue, StoresAndRetrievesUInt32)
{
ScalarValue sv{};
sv.u32 = 0xDEADBEEF;
EXPECT_EQ(sv.u32, 0xDEADBEEF);
}
// ============================================================================
// Args field coverage — batch_strides and workspace_ptr
// ============================================================================
TEST(Args, BatchStridesFieldExists)
{
Args args{};
args.batch_strides[0] = 12345;
args.batch_strides[kMaxTensors - 1] = -99;
EXPECT_EQ(args.batch_strides[0], 12345);
EXPECT_EQ(args.batch_strides[kMaxTensors - 1], -99);
}
TEST(Args, WorkspacePtrFieldExists)
{
Args args{};
int dummy = 42;
args.workspace_ptr = &dummy;
EXPECT_EQ(args.workspace_ptr, &dummy);
}
TEST(Args, BatchCountFieldExists)
{
Args args{};
args.batch_count = 8;
EXPECT_EQ(args.batch_count, 8);
}
// ============================================================================
// Boundary access tests
// ============================================================================
TEST(Args, BoundaryAccessToTensors)
{
Args args{};
// Access last tensor slot (kMaxTensors - 1 = 15)
args.tensors[kMaxTensors - 1].ptr = nullptr;
EXPECT_EQ(args.tensors[kMaxTensors - 1].ptr, nullptr);
}
TEST(Args, BoundaryAccessToScalars)
{
Args args{};
// Access last scalar slot (kMaxScalars - 1 = 15)
args.scalars[kMaxScalars - 1].f32 = 1.0f;
EXPECT_FLOAT_EQ(args.scalars[kMaxScalars - 1].f32, 1.0f);
}
TEST(TensorArg, BoundaryAccessToLengthsAndStrides)
{
TensorArg ta{};
// Access last rank dimension (kMaxRank - 1 = 5)
ta.lengths[kMaxRank - 1] = 42;
ta.strides[kMaxRank - 1] = 99;
EXPECT_EQ(ta.lengths[kMaxRank - 1], 42);
EXPECT_EQ(ta.strides[kMaxRank - 1], 99);
}
// ============================================================================
// makeShape
// ============================================================================
TEST(MakeShape, ZeroFillsUnusedDimensions)
{
EXPECT_THAT(makeShape(128, 64), ElementsAre(128, 64, 0, 0, 0, 0));
}
TEST(MakeShape, FillsAllSixDimensions)
{
EXPECT_THAT(makeShape(2, 3, 4, 5, 6, 7), ElementsAre(2, 3, 4, 5, 6, 7));
}
TEST(MakeShape, SingleDimension) { EXPECT_THAT(makeShape(1024), ElementsAre(1024, 0, 0, 0, 0, 0)); }
// ============================================================================
// makeStrides
// ============================================================================
TEST(MakeStrides, ZeroFillsUnusedDimensions)
{
EXPECT_THAT(makeStrides(256, 1), ElementsAre(256, 1, 0, 0, 0, 0));
}
TEST(MakeStrides, HandlesLargeInt64Values)
{
constexpr int64_t large = 1LL << 40;
EXPECT_THAT(makeStrides(large, 1), ElementsAre(large, 1, 0, 0, 0, 0));
}
} // namespace


@@ -0,0 +1,79 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <rocm_ck/datatype.hpp>
#include <gtest/gtest.h>
#include <string>
using ::rocm_ck::DataType;
using ::rocm_ck::dataTypeBits;
using ::rocm_ck::dataTypeName;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
namespace {
// ============================================================================
// Parameterized: one row per DataType variant
// ============================================================================
struct DataTypeEntry
{
DataType dt;
int bits;
const char* name;
};
class DataTypeTest : public TestWithParam<DataTypeEntry>
{
};
TEST_P(DataTypeTest, ReportsCorrectBits)
{
EXPECT_EQ(dataTypeBits(GetParam().dt), GetParam().bits);
}
TEST_P(DataTypeTest, MapsToExpectedName)
{
EXPECT_STREQ(dataTypeName(GetParam().dt), GetParam().name);
}
INSTANTIATE_TEST_SUITE_P(
AllTypes,
DataTypeTest,
Values(DataTypeEntry{.dt = DataType::FP64, .bits = 64, .name = "FP64"},
DataTypeEntry{.dt = DataType::FP32, .bits = 32, .name = "FP32"},
DataTypeEntry{.dt = DataType::FP16, .bits = 16, .name = "FP16"},
DataTypeEntry{.dt = DataType::BF16, .bits = 16, .name = "BF16"},
DataTypeEntry{.dt = DataType::FP8_FNUZ, .bits = 8, .name = "FP8_FNUZ"},
DataTypeEntry{.dt = DataType::BF8_FNUZ, .bits = 8, .name = "BF8_FNUZ"},
DataTypeEntry{.dt = DataType::FP8_OCP, .bits = 8, .name = "FP8_OCP"},
DataTypeEntry{.dt = DataType::BF8_OCP, .bits = 8, .name = "BF8_OCP"},
DataTypeEntry{.dt = DataType::I4, .bits = 4, .name = "I4"},
DataTypeEntry{.dt = DataType::I8, .bits = 8, .name = "I8"},
DataTypeEntry{.dt = DataType::I16, .bits = 16, .name = "I16"},
DataTypeEntry{.dt = DataType::I32, .bits = 32, .name = "I32"},
DataTypeEntry{.dt = DataType::I64, .bits = 64, .name = "I64"},
DataTypeEntry{.dt = DataType::U8, .bits = 8, .name = "U8"},
DataTypeEntry{.dt = DataType::U16, .bits = 16, .name = "U16"},
DataTypeEntry{.dt = DataType::U32, .bits = 32, .name = "U32"},
DataTypeEntry{.dt = DataType::U64, .bits = 64, .name = "U64"}),
[](const TestParamInfo<DataTypeEntry>& p) { return std::string(p.param.name); });
// ============================================================================
// constexpr validation
// ============================================================================
TEST(DataType, EvaluatesBitsAndNameAtCompileTime)
{
constexpr int fp32_bits = dataTypeBits(DataType::FP32);
EXPECT_EQ(fp32_bits, 32);
constexpr const char* fp32_name = dataTypeName(DataType::FP32);
EXPECT_STREQ(fp32_name, "FP32");
}
} // namespace


@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <rocm_ck/fixed_string.hpp>
#include <gtest/gtest.h>
using ::rocm_ck::FixedString;
namespace {
TEST(FixedString, MatchesSingleCharacter)
{
EXPECT_TRUE(FixedString<16>("A") == "A");
EXPECT_FALSE(FixedString<16>("A") == "B");
}
TEST(FixedString, MatchesExactStringOnly)
{
EXPECT_TRUE(FixedString<16>("bias") == "bias");
EXPECT_FALSE(FixedString<16>("bias") == "bia");
EXPECT_FALSE(FixedString<16>("bias") == "biases");
}
TEST(FixedString, AcceptsMaxCapacityMinusOne)
{
EXPECT_TRUE(FixedString<16>("123456789012345") == "123456789012345");
}
TEST(FixedString, SupportsEmptyString)
{
EXPECT_EQ(FixedString<16>("").len, 0);
EXPECT_TRUE(FixedString<16>("") == "");
EXPECT_FALSE(FixedString<16>("") == "A");
}
TEST(FixedString, EqualStringsCompareEqual)
{
EXPECT_EQ(FixedString<16>("A"), FixedString<16>("A"));
EXPECT_NE(FixedString<16>("A"), FixedString<16>("B"));
}
TEST(FixedString, OrderingIsLexicographic)
{
EXPECT_LT(FixedString<16>("A"), FixedString<16>("B"));
EXPECT_LT(FixedString<16>("B"), FixedString<16>("Z"));
EXPECT_GT(FixedString<16>("Z"), FixedString<16>("A"));
}
} // namespace


@@ -0,0 +1,88 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <rocm_ck/layout.hpp>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <array>
using ::rocm_ck::isValidLayoutForRank;
using ::rocm_ck::Layout;
using ::rocm_ck::layoutName;
using ::rocm_ck::layoutStrides;
using ::rocm_ck::leadingDimStride;
using ::testing::ElementsAre;
namespace {
// ============================================================================
// layoutName
// ============================================================================
TEST(Layout, MapsEnumValuesToExpectedStrings)
{
EXPECT_STREQ(layoutName(Layout::Row), "Row");
EXPECT_STREQ(layoutName(Layout::Col), "Col");
EXPECT_STREQ(layoutName(Layout::Auto), "Auto");
}
// ============================================================================
// isValidLayoutForRank
// ============================================================================
TEST(Layout, AllowsRowAndColOnlyForRank2)
{
EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 1));
EXPECT_TRUE(isValidLayoutForRank(Layout::Row, 2));
EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 1));
EXPECT_TRUE(isValidLayoutForRank(Layout::Col, 2));
}
TEST(Layout, RejectsAutoForAllRanks)
{
EXPECT_FALSE(isValidLayoutForRank(Layout::Auto, 0));
EXPECT_FALSE(isValidLayoutForRank(Layout::Auto, 1));
EXPECT_FALSE(isValidLayoutForRank(Layout::Auto, 2));
}
TEST(Layout, RejectsRowAndColForRankGreaterThan2)
{
EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 3));
EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 4));
EXPECT_FALSE(isValidLayoutForRank(Layout::Row, 6));
EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 3));
EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 4));
EXPECT_FALSE(isValidLayoutForRank(Layout::Col, 6));
}
// ============================================================================
// leadingDimStride
// ============================================================================
TEST(Layout, LeadingDimStrideReturnsFirstForRow)
{
EXPECT_EQ(leadingDimStride(Layout::Row, std::array{128, 1}), 128);
}
TEST(Layout, LeadingDimStrideReturnsSecondForCol)
{
EXPECT_EQ(leadingDimStride(Layout::Col, std::array{1, 64}), 64);
}
// ============================================================================
// layoutStrides
// ============================================================================
TEST(Layout, LayoutStridesRowMajor)
{
EXPECT_THAT(layoutStrides(Layout::Row, 32, 64), ElementsAre(64, 1));
}
TEST(Layout, LayoutStridesColMajor)
{
EXPECT_THAT(layoutStrides(Layout::Col, 32, 64), ElementsAre(1, 32));
}
} // namespace


@@ -0,0 +1,35 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <rocm_ck/physical_tensor.hpp>
#include <gtest/gtest.h>
using ::rocm_ck::DataType;
using ::rocm_ck::FixedString;
using ::rocm_ck::kMaxPhysicalTensors;
using ::rocm_ck::Layout;
using ::rocm_ck::PhysicalTensor;
namespace {
TEST(PhysicalTensor, InitializesWithFP32RowAndSlotZero)
{
constexpr PhysicalTensor pt{};
EXPECT_EQ(pt.dtype, DataType::FP32);
EXPECT_EQ(pt.layout, Layout::Row);
EXPECT_EQ(pt.args_slot, 0);
}
TEST(PhysicalTensor, StoresAllFieldsFromConstruction)
{
constexpr PhysicalTensor pt{FixedString<16>("bias"), DataType::FP16, Layout::Col, 3};
EXPECT_TRUE(pt.name == "bias");
EXPECT_EQ(pt.dtype, DataType::FP16);
EXPECT_EQ(pt.layout, Layout::Col);
EXPECT_EQ(pt.args_slot, 3);
}
TEST(PhysicalTensor, LimitsCapacityTo8) { EXPECT_EQ(kMaxPhysicalTensors, 8); }
} // namespace