mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
[rocm-libraries] ROCm/rocm-libraries#7114 (commit ecef372)
[CK] Add rocm_ck foundation types: DataType, Layout, Args, Ops (#7114)

## Summary

- Add the vocabulary types that all rocm_ck schema headers build on
- 9 new headers under `include/rocm_ck/`, 6 unit test files
- Pure C++20, host-only — no CK Tile dependencies

**Headers:**

| Header | Purpose |
|--------|---------|
| `index_t.hpp` | `index_t`, `long_index_t` (matches ck_tile) |
| `gpu_target.hpp` | `GpuTarget` enum (ISA targets) |
| `datatype.hpp` | `DataType` enum (17 variants) |
| `layout.hpp` | `Layout` enum (Row, Col, Auto) + stride helpers |
| `fixed_string.hpp` | `FixedString<N>` — structural string for NTTPs |
| `args.hpp` | Generic kernel argument buffer (ABI) |
| `ops.hpp` | Operator structs (`GemmOp`, `AddOp`, ...) + `Op` variant |
| `physical_tensor.hpp` | `PhysicalTensor` — maps names to Args slots |
| `resolved_tensor.hpp` | `ResolvedTensor` — output of `Signature::resolve()` |

**Stack**: This is PR 1 of 3 porting the rocm_ck constexpr schema from experimental to production, #7143.

1. **This PR** — Foundation types (vocabulary)
2. Schema engine — `Signature`, `resolve()`, `ArchProperties`
3. Spec factories — `GemmSpec`, `ElementwiseSpec`, `makeSpec()`

## Test plan

- [ ] `ninja build-smoke-rocm-ck` builds all tests
- [ ] `ctest -L ROCM_CK_SMOKE --output-on-failure` — 6 unit tests pass (86 test cases)
- [ ] Default CK build (`CK_ENABLE_ROCM_CK=OFF`) unaffected

🤖 Generated with [Claude Code](https://claude.com/claude-code)
This commit is contained in:
committed by
assistant-librarian[bot]
parent
187ef8ac94
commit
3e110e1718
@@ -3,11 +3,10 @@
|
||||
A C++20 constexpr API for configuring and distributing
|
||||
[CK Tile](../include/ck_tile/) GPU kernels across multiple architectures.
|
||||
|
||||
> **Status**: Early development. The current code establishes the directory
|
||||
> structure, build integration, and CI pipeline. A single unit test verifies
|
||||
> that the build and test infrastructure works end-to-end in Jenkins.
|
||||
> The schema types, device bridge, and kernel tests described below are
|
||||
> under active development.
|
||||
> **Status**: Early development. Foundation types are in place (DataType,
|
||||
> Layout, Args, operators, FixedString, PhysicalTensor, ResolvedTensor).
|
||||
> The schema engine (Signature, resolve(), Algorithm) and device bridge
|
||||
> are under active development.
|
||||
|
||||
## Why rocm_ck exists
|
||||
|
||||
|
||||
89
rocm_ck/include/rocm_ck/args.hpp
Normal file
89
rocm_ck/include/rocm_ck/args.hpp
Normal file
@@ -0,0 +1,89 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Role: abi — shared between host and device. Trivially copyable, no CK deps.
|
||||
//
|
||||
// Args is a hardware buffer for passing data between CPU and GPU during a
|
||||
// kernel call. It carries raw pointers, shapes, strides, and scalar values —
|
||||
// nothing more. All semantic meaning (which tensor is "A", which scalar is
|
||||
// "alpha", input vs output) lives in the Signature, not here.
|
||||
//
|
||||
// This is deliberately one type for all operations. Per-operation structs
|
||||
// (GemmArgs, FmhaArgs, ...) would make the dispatcher a closed set — adding
|
||||
// an operation means adding a type, updating launch code, and changing the
|
||||
// kpack format. A generic buffer keeps the dispatcher open.
|
||||
//
|
||||
// Capacity limits (kMaxRank=6, kMaxTensors=16, kMaxScalars=16) are sized to
|
||||
// the most demanding current operation (FMHA backward: ~12 tensors, ~12
|
||||
// scalars, rank-6 for grouped 3D conv). If a future operation exceeds these,
|
||||
// bump the constants — the layout is not versioned, and the 4KB HSA kernarg
|
||||
// budget has room. Don't over-provision speculatively.
|
||||
//
|
||||
// Key constraints:
|
||||
// - Trivially copyable, standard layout — required for HSA kernarg passing.
|
||||
// - Fixed-capacity arrays, no heap — sizeof fits the 4KB kernarg budget.
|
||||
// - const void* for all tensor pointers — the entry kernel casts to the
|
||||
// concrete type. Input vs output semantics live in the Signature.
|
||||
// - No runtime type tags on scalars — the Signature declares types at
|
||||
// compile time. The entry kernel reads the correct union member.
|
||||
// - Slot ordering is the invariant: tensors[i] maps to Signature::tensors[i].
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <rocm_ck/index_t.hpp>
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
namespace rocm_ck {
|
||||
|
||||
// Capacity limits of the generic argument buffer.
// When changing these, update the byte-size comments on TensorArg and Args fields.
//
// Declared `inline constexpr` so that every translation unit including this
// header refers to the same entity, rather than each getting its own
// internal-linkage copy.
inline constexpr int kMaxRank    = 6;  // grouped 3D conv = GNCDHW = rank 6
inline constexpr int kMaxTensors = 16; // FMHA backward uses ~12
inline constexpr int kMaxScalars = 16; // FMHA with masking+dropout needs ~12
|
||||
|
||||
struct TensorArg
|
||||
{
|
||||
const void* ptr; // 8 bytes (offset 0)
|
||||
std::array<index_t, kMaxRank> lengths; // 24 bytes (offset 8) — int32
|
||||
std::array<long_index_t, kMaxRank> strides; // 48 bytes (offset 32) — int64
|
||||
};
|
||||
|
||||
// Untagged storage for one scalar kernel argument.
// There is no member narrower than 32 bits: FP16/BF16/FP8 scalars use f32,
// so scalar precision is always >= tensor precision.
union ScalarValue
{
    // 32-bit alternatives
    float f32;
    std::int32_t i32;
    std::uint32_t u32;

    // 64-bit alternatives
    double f64;
    std::int64_t i64;
    std::uint64_t u64;
};
|
||||
|
||||
// Slot ordering matches Signature: tensors[i] <-> Signature::tensors[i].
|
||||
struct Args
|
||||
{
|
||||
std::array<TensorArg, kMaxTensors> tensors; // 16 x 80 = 1280 bytes
|
||||
std::array<ScalarValue, kMaxScalars> scalars; // 16 x 8 = 128 bytes
|
||||
|
||||
index_t batch_count = 0; // 4 bytes
|
||||
std::array<long_index_t, kMaxTensors> batch_strides = {}; // 16 x 8 = 128 bytes
|
||||
void* workspace_ptr = nullptr; // 8 bytes
|
||||
};
|
||||
|
||||
// Builds a lengths array from up to six extents.
// Only d0 is mandatory; any dimension not supplied is stored as 0.
constexpr std::array<index_t, kMaxRank> makeShape(
    index_t d0, index_t d1 = 0, index_t d2 = 0, index_t d3 = 0, index_t d4 = 0, index_t d5 = 0)
{
    const std::array<index_t, kMaxRank> shape = {{d0, d1, d2, d3, d4, d5}};
    return shape;
}
|
||||
|
||||
// Builds a strides array from up to six per-dimension strides.
// Only s0 is mandatory; any stride not supplied is stored as 0.
constexpr std::array<long_index_t, kMaxRank> makeStrides(long_index_t s0,
                                                         long_index_t s1 = 0,
                                                         long_index_t s2 = 0,
                                                         long_index_t s3 = 0,
                                                         long_index_t s4 = 0,
                                                         long_index_t s5 = 0)
{
    const std::array<long_index_t, kMaxRank> strides = {{s0, s1, s2, s3, s4, s5}};
    return strides;
}
|
||||
|
||||
} // namespace rocm_ck
|
||||
101
rocm_ck/include/rocm_ck/datatype.hpp
Normal file
101
rocm_ck/include/rocm_ck/datatype.hpp
Normal file
@@ -0,0 +1,101 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Role: types — DataType enum, constexpr queries. No runtime, no CK deps.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "rocm_ck/platform.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace rocm_ck {
|
||||
|
||||
// Element types understood by rocm_ck. Stored in one byte.
// Naming follows the CK convention: FP8 = e4m3, BF8 = e5m2.
enum class DataType : uint8_t
{
    // --- Floating point, standard widths ---
    FP64,
    FP32,
    FP16,
    BF16,

    // --- 8-bit floating point ---
    // FNUZ and OCP are distinct number formats; see the FNUZ/OCP note that
    // follows this enum.
    FP8_FNUZ,
    BF8_FNUZ,
    FP8_OCP,
    BF8_OCP,

    // --- Signed integers ---
    I4,
    I8,
    I16,
    I32,
    I64,

    // --- Unsigned integers (there is no unsigned 4-bit type) ---
    U8,
    U16,
    U32,
    U64
};
|
||||
|
||||
// FP8 variants — FNUZ and OCP are different number formats, not just HW hints.
|
||||
// FNUZ: gfx942 native (higher bias, no Inf, max 240)
|
||||
// OCP: gfx950 native (OCP standard, has Inf, max 448)
|
||||
// Non-native formats run in software (slower) and produce different numerical
|
||||
// results. Choose based on target GPU and model training format.
|
||||
// We keep FNUZ and OCP explicit rather than a generic FP8 — the numerical
|
||||
// differences matter for compatibility and schema-driven test coverage.
|
||||
// TODO - We may introduce a generic FP8/BF8 that resolves to the hardware-native type.
|
||||
// See: https://rocm.docs.amd.com/projects/HIP/en/latest/reference/fp8_numbers.html
|
||||
|
||||
// Bits (not bytes) so sub-byte types (I4) are clean integers.
|
||||
constexpr int dataTypeBits(DataType dt)
|
||||
{
|
||||
switch(dt)
|
||||
{
|
||||
case DataType::FP64: return 64;
|
||||
case DataType::FP32: return 32;
|
||||
case DataType::FP16: return 16;
|
||||
case DataType::BF16: return 16;
|
||||
case DataType::FP8_FNUZ: return 8;
|
||||
case DataType::BF8_FNUZ: return 8;
|
||||
case DataType::FP8_OCP: return 8;
|
||||
case DataType::BF8_OCP: return 8;
|
||||
case DataType::I4: return 4;
|
||||
case DataType::I8: return 8;
|
||||
case DataType::I16: return 16;
|
||||
case DataType::I32: return 32;
|
||||
case DataType::I64: return 64;
|
||||
case DataType::U8: return 8;
|
||||
case DataType::U16: return 16;
|
||||
case DataType::U32: return 32;
|
||||
case DataType::U64: return 64;
|
||||
}
|
||||
ROCM_CK_UNREACHABLE();
|
||||
}
|
||||
|
||||
// Returns the enumerator's spelling as a string literal (e.g. DataType::BF16
// -> "BF16").
constexpr const char* dataTypeName(DataType dt)
{
    const char* text = nullptr;

    // No `default`: every enumerator is listed explicitly.
    switch(dt)
    {
    case DataType::FP64: text = "FP64"; break;
    case DataType::FP32: text = "FP32"; break;
    case DataType::FP16: text = "FP16"; break;
    case DataType::BF16: text = "BF16"; break;
    case DataType::FP8_FNUZ: text = "FP8_FNUZ"; break;
    case DataType::BF8_FNUZ: text = "BF8_FNUZ"; break;
    case DataType::FP8_OCP: text = "FP8_OCP"; break;
    case DataType::BF8_OCP: text = "BF8_OCP"; break;
    case DataType::I4: text = "I4"; break;
    case DataType::I8: text = "I8"; break;
    case DataType::I16: text = "I16"; break;
    case DataType::I32: text = "I32"; break;
    case DataType::I64: text = "I64"; break;
    case DataType::U8: text = "U8"; break;
    case DataType::U16: text = "U16"; break;
    case DataType::U32: text = "U32"; break;
    case DataType::U64: text = "U64"; break;
    }

    if(text != nullptr)
        return text;

    // Only reached for a value that is not a declared enumerator.
    ROCM_CK_UNREACHABLE();
}
|
||||
|
||||
} // namespace rocm_ck
|
||||
62
rocm_ck/include/rocm_ck/fixed_string.hpp
Normal file
62
rocm_ck/include/rocm_ck/fixed_string.hpp
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Role: types — FixedString. No runtime, no CK deps.
|
||||
//
|
||||
// A compile-time string for use in template parameters (NTTPs).
|
||||
//
|
||||
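// C++20 requires class-type non-type template parameters to be "structural
// types" — loosely, types whose base classes and non-static data members are
// all public, non-mutable, and themselves structural.
// std::string and std::string_view fail this requirement (their data members
// are private), and a pointer to a string literal is not a permitted template
// argument either.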
|
||||
//
|
||||
// FixedString stores the string inline in a char array, making it structural:
|
||||
//
|
||||
// template <PhysicalTensor PT> // PhysicalTensor contains FixedString<16>
|
||||
// void dispatch() { ... }
|
||||
//
|
||||
// When to use FixedString vs std::string_view:
|
||||
// - FixedString: the type must be structural (template parameters).
|
||||
// - string_view: consteval-only types that never become template parameters
|
||||
// (e.g., ResolvedTensor — see resolved_tensor.hpp).
|
||||
//
|
||||
// The capacity is a template parameter so each use site documents its limit:
|
||||
// FixedString<16> name("bias"); // tensor names: 15 chars max
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <string_view>
|
||||
|
||||
namespace rocm_ck {
|
||||
|
||||
template <std::size_t MaxLen>
|
||||
struct FixedString
|
||||
{
|
||||
char data[MaxLen]{};
|
||||
int len = 0;
|
||||
|
||||
constexpr FixedString() = default;
|
||||
|
||||
constexpr FixedString(std::string_view sv) : len(static_cast<int>(sv.size()))
|
||||
{
|
||||
if(sv.size() > MaxLen - 1)
|
||||
throw "FixedString: input exceeds capacity";
|
||||
for(int i = 0; i < len; ++i)
|
||||
data[i] = sv[i];
|
||||
}
|
||||
|
||||
constexpr bool operator==(std::string_view sv) const
|
||||
{
|
||||
if(len != static_cast<int>(sv.size()))
|
||||
return false;
|
||||
for(int i = 0; i < len; ++i)
|
||||
if(data[i] != sv[i])
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Required: the string_view overload above suppresses the implicit == from <=>.
|
||||
constexpr bool operator==(const FixedString&) const = default;
|
||||
constexpr auto operator<=>(const FixedString&) const = default;
|
||||
};
|
||||
|
||||
} // namespace rocm_ck
|
||||
24
rocm_ck/include/rocm_ck/gpu_target.hpp
Normal file
24
rocm_ck/include/rocm_ck/gpu_target.hpp
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Role: types — GpuTarget enum. No runtime, no CK deps.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace rocm_ck {
|
||||
|
||||
// GPU instruction-set targets, stored in one byte.
// The enumerators are ISA identifiers (matching -mcpu flags), not marketing
// names; the architecture family is noted per group.
enum class GpuTarget : uint8_t
{
    // CDNA 2
    gfx90a,

    // CDNA 3
    gfx942,

    // CDNA 4
    gfx950,

    // RDNA 3
    gfx1100,
    gfx1101,
    gfx1102,

    // RDNA 3.5
    gfx1150,
    gfx1151,
};
|
||||
|
||||
} // namespace rocm_ck
|
||||
69
rocm_ck/include/rocm_ck/layout.hpp
Normal file
69
rocm_ck/include/rocm_ck/layout.hpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Role: types — Layout enum, constexpr/consteval helpers. No runtime, no CK deps.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "rocm_ck/platform.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
namespace rocm_ck {
|
||||
|
||||
// Memory layout of a tensor, stored in one byte.
enum class Layout : uint8_t
{
    Row,
    Col,

    // Resolve-time placeholder: Signature::resolve() replaces it with the
    // concrete layout from the operator slot. It never reaches the kernel.
    Auto
};
|
||||
|
||||
// Returns the enumerator's spelling as a string literal ("Row", "Col", "Auto").
constexpr const char* layoutName(Layout layout)
{
    const char* text = nullptr;

    // No `default`: every enumerator is listed explicitly.
    switch(layout)
    {
    case Layout::Row: text = "Row"; break;
    case Layout::Col: text = "Col"; break;
    case Layout::Auto: text = "Auto"; break;
    }

    if(text != nullptr)
        return text;

    // Only reached for a value that is not a declared enumerator.
    ROCM_CK_UNREACHABLE();
}
|
||||
|
||||
// Reports whether `layout` can describe a tensor of the given rank.
// Row and Col are accepted for rank 2 only; Auto is never accepted.
constexpr bool isValidLayoutForRank(Layout layout, int rank)
{
    switch(layout)
    {
    case Layout::Row:
    case Layout::Col: return rank == 2;

    case Layout::Auto: return false;
    }
    // Only reached for a value that is not a declared enumerator.
    ROCM_CK_UNREACHABLE();
}
|
||||
|
||||
template <typename T, std::size_t N>
|
||||
constexpr T leadingDimStride(Layout layout, const std::array<T, N>& strides)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case Layout::Row: return strides[0];
|
||||
case Layout::Col: return strides[1];
|
||||
case Layout::Auto: throw "leadingDimStride requires Row or Col layout";
|
||||
}
|
||||
ROCM_CK_UNREACHABLE();
|
||||
}
|
||||
|
||||
// Packed strides of a rows x cols matrix stored in `layout`:
//   Row -> {cols, 1}
//   Col -> {1, rows}
// Throws a `const char*` message for Layout::Auto; in a constant expression
// this surfaces as a compile error.
constexpr std::array<int, 2> layoutStrides(Layout layout, int rows, int cols)
{
    // Reject the placeholder layout up front.
    if(layout == Layout::Auto)
        throw "layoutStrides requires Row or Col layout";

    switch(layout)
    {
    case Layout::Row: return {cols, 1};
    case Layout::Col: return {1, rows};
    case Layout::Auto: break; // already handled by the guard above
    }
    // Only reached for a value that is not a declared enumerator.
    ROCM_CK_UNREACHABLE();
}
|
||||
|
||||
} // namespace rocm_ck
|
||||
139
rocm_ck/include/rocm_ck/ops.hpp
Normal file
139
rocm_ck/include/rocm_ck/ops.hpp
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Role: meta — operator structs, Op variant. No runtime, no CK deps.
|
||||
//
|
||||
// Operators are the edges of a Signature's compute graph. Each operator
|
||||
// names its tensor slots as string_view labels (e.g., "A", "bias", "query")
|
||||
// that refer to tensors declared elsewhere in the Signature. The Signature
|
||||
// owns the tensor definitions; operators just reference them by name.
|
||||
//
|
||||
// This separation means operators are reusable across different tensor
|
||||
// configurations — a GemmOp doesn't care whether its "lhs" is FP16 or BF16,
|
||||
// Row or Col. That's resolved later when the Signature is validated.
|
||||
//
|
||||
// The Op variant is the closed set of supported operator types. Adding a
|
||||
// new operator means adding a struct here and a variant alternative.
|
||||
// Fused operations (like FMHA) are single operators — not chains of
|
||||
// elementwise + GEMM — because CK Tile implements them as monolithic kernels.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <rocm_ck/datatype.hpp>
|
||||
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
|
||||
namespace rocm_ck {
|
||||
|
||||
// Matrix multiplication: out = lhs x rhs.
|
||||
// acc_dtype is the accumulation type — defaults to FP32, the universal safe
|
||||
// choice across all input types.
|
||||
struct GemmOp
|
||||
{
|
||||
std::string_view lhs;
|
||||
std::string_view rhs;
|
||||
std::string_view out;
|
||||
DataType acc_dtype = DataType::FP32;
|
||||
};
|
||||
|
||||
// Element-wise addition of two named tensors: out = lhs + rhs.
struct AddOp
{
    std::string_view lhs; // first addend
    std::string_view rhs; // second addend
    std::string_view out; // sum
};
|
||||
|
||||
// Element-wise multiplication of two named tensors: out = lhs * rhs.
struct MulOp
{
    std::string_view lhs; // first factor
    std::string_view rhs; // second factor
    std::string_view out; // product
};
|
||||
|
||||
// ReLU activation applied to a named tensor: out = max(0, in).
struct ReluOp
{
    std::string_view in;  // source tensor name
    std::string_view out; // destination tensor name
};
|
||||
|
||||
// Fast GELU approximation applied to a named tensor:
// out = in * sigmoid(1.702 * in).
struct FastGeluOp
{
    std::string_view in;  // source tensor name
    std::string_view out; // destination tensor name
};
|
||||
|
||||
// Exact GELU applied to a named tensor:
// out = 0.5 * in * (1 + erf(in / sqrt(2))).
struct GeluOp
{
    std::string_view in;  // source tensor name
    std::string_view out; // destination tensor name
};
|
||||
|
||||
// SiLU (Swish) activation applied to a named tensor: out = in * sigmoid(in).
struct SiluOp
{
    std::string_view in;  // source tensor name
    std::string_view out; // destination tensor name
};
|
||||
|
||||
// Sigmoid activation applied to a named tensor: out = 1 / (1 + exp(-in)).
struct SigmoidOp
{
    std::string_view in;  // source tensor name
    std::string_view out; // destination tensor name
};
|
||||
|
||||
// Softmax over a named tensor, reducing along the last dimension:
// out[i] = exp(in[i]) / sum(exp(in)).
struct SoftmaxOp
{
    std::string_view in;  // source tensor name
    std::string_view out; // destination tensor name
};
|
||||
|
||||
// Multiplication of a named tensor by a scalar: out = in * scale.
struct ScaleOp
{
    std::string_view in;  // source tensor name
    std::string_view out; // destination tensor name

    // Names a Scalar in the Signature, not a tensor.
    std::string_view scale;
};
|
||||
|
||||
// Fused multi-head attention backward pass.
|
||||
// Implemented as a single CK Tile kernel, not a chain of ops.
|
||||
// Feature flags (mask, dropout, bias, deterministic) belong in the Algorithm.
|
||||
struct FmhaBwdOp
|
||||
{
|
||||
std::string_view q; // query
|
||||
std::string_view k; // key
|
||||
std::string_view v; // value
|
||||
std::string_view lse; // log-sum-exp from forward pass
|
||||
std::string_view do_; // output gradient
|
||||
std::string_view d; // dot(output_grad, output)
|
||||
|
||||
std::string_view dq; // query gradient
|
||||
std::string_view dk; // key gradient
|
||||
std::string_view dv; // value gradient
|
||||
|
||||
DataType acc_dtype = DataType::FP32;
|
||||
};
|
||||
|
||||
// The closed set of supported operators. std::monostate marks empty slots.
|
||||
using Op = std::variant<std::monostate,
|
||||
GemmOp,
|
||||
AddOp,
|
||||
MulOp,
|
||||
ReluOp,
|
||||
FastGeluOp,
|
||||
GeluOp,
|
||||
SiluOp,
|
||||
SigmoidOp,
|
||||
SoftmaxOp,
|
||||
ScaleOp,
|
||||
FmhaBwdOp>;
|
||||
|
||||
} // namespace rocm_ck
|
||||
29
rocm_ck/include/rocm_ck/physical_tensor.hpp
Normal file
29
rocm_ck/include/rocm_ck/physical_tensor.hpp
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Role: types — PhysicalTensor. No runtime, no CK deps.
|
||||
//
|
||||
// A PhysicalTensor maps a named tensor from the Signature graph to a slot
|
||||
// in the generic Args buffer. Not every tensor in a compute graph is physical —
|
||||
// intermediate values (e.g., the S matrix in FMHA = Q*K^T) live only in
|
||||
// registers and never appear in device memory. The physical tensor table
|
||||
// describes exactly what the host needs to pack into Args.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <rocm_ck/datatype.hpp>
|
||||
#include <rocm_ck/fixed_string.hpp>
|
||||
#include <rocm_ck/layout.hpp>
|
||||
|
||||
namespace rocm_ck {
|
||||
|
||||
// Upper bound on the number of PhysicalTensor entries.
// NOTE(review): FmhaBwdOp in ops.hpp names nine tensor slots and kMaxTensors in
// args.hpp is 16 — confirm that 8 is large enough for every operator.
inline constexpr int kMaxPhysicalTensors{8};
|
||||
|
||||
struct PhysicalTensor
|
||||
{
|
||||
FixedString<16> name;
|
||||
DataType dtype = DataType::FP32;
|
||||
Layout layout = Layout::Row;
|
||||
int args_slot = 0;
|
||||
};
|
||||
|
||||
} // namespace rocm_ck
|
||||
13
rocm_ck/include/rocm_ck/platform.hpp
Normal file
13
rocm_ck/include/rocm_ck/platform.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Compiler portability macros for LLVM/Clang/GCC and MSVC.
|
||||
// C++23 will provide std::unreachable(): https://en.cppreference.com/w/cpp/utility/unreachable
|
||||
|
||||
#pragma once
|
||||
|
||||
// ROCM_CK_UNREACHABLE() tells the optimizer that control flow never gets here.
// It expands to the compiler's own intrinsic: __builtin_unreachable() by
// default, __assume(false) when _MSC_VER is defined.
#if !defined(_MSC_VER)
#define ROCM_CK_UNREACHABLE() __builtin_unreachable()
#else
#define ROCM_CK_UNREACHABLE() __assume(false)
#endif
|
||||
59
rocm_ck/include/rocm_ck/resolved_tensor.hpp
Normal file
59
rocm_ck/include/rocm_ck/resolved_tensor.hpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Role: types — ResolvedTensor, ResolvedQuantization. No runtime, no CK deps.
|
||||
//
|
||||
// ResolvedTensor is the intermediate result of consteval resolution. It exists
|
||||
// only at compile time — produced by Signature::resolve() and consumed by
|
||||
// makeSpec(), both consteval. It never appears in compiled code.
|
||||
//
|
||||
// In the user-facing Signature, tensors can have Layout::Auto (inherit from
|
||||
// operator slot) and omit fields with sensible defaults. After resolution,
|
||||
// every field is concrete. The base fields (name, dtype, rank, layout)
|
||||
// describe a plain dense tensor — enough for most operands (GEMM inputs,
|
||||
// outputs, bias vectors). Some tensors carry additional metadata beyond the
|
||||
// dense description. Block-quantized tensors (e.g., INT4 weights) need a
|
||||
// scale tensor and group size. We use optional sub-structs for these
|
||||
// extensions, keeping the common case clean without bloating every instance.
|
||||
//
|
||||
// Why std::string_view instead of FixedString?
|
||||
// ResolvedTensor is consteval-only — produced and consumed entirely at
|
||||
// compile time. No library loading, no runtime lifetime concerns. The
|
||||
// string_views point to string literals from user code (e.g.,
|
||||
// GemmOp{.lhs = "A"}), which have static storage duration — no dangling.
|
||||
// FixedString is required for PhysicalTensor because it IS used as a
// template parameter (NTTP), which requires a structural type; std::string_view
// is not one (its data members are private).
|
||||
// ResolvedTensor is never a template parameter.
|
||||
//
|
||||
// Plain aggregate — no methods, no validation. Resolution validates; this
|
||||
// type just carries the result to makeSpec().
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <rocm_ck/datatype.hpp>
|
||||
#include <rocm_ck/layout.hpp>
|
||||
|
||||
#include <optional>
|
||||
#include <string_view>
|
||||
|
||||
namespace rocm_ck {
|
||||
|
||||
// Present when a tensor carries block-quantized data (e.g., INT4 weights).
|
||||
// The scale tensor is a separate entry in the Signature; this struct ties
|
||||
// the quantized tensor to its scale.
|
||||
struct ResolvedQuantization
|
||||
{
|
||||
std::string_view scale_name;
|
||||
DataType scale_dtype;
|
||||
int group_size; // elements per quantization group
|
||||
};
|
||||
|
||||
struct ResolvedTensor
|
||||
{
|
||||
std::string_view name;
|
||||
DataType dtype;
|
||||
int rank = 2;
|
||||
Layout layout = Layout::Row;
|
||||
std::optional<ResolvedQuantization> quantize = std::nullopt;
|
||||
};
|
||||
|
||||
} // namespace rocm_ck
|
||||
@@ -4,7 +4,7 @@
|
||||
# rocm_ck tests
|
||||
#
|
||||
# Test tiers:
|
||||
# ROCM_CK_SMOKE — Fast host-only tests (< 1s total). No GPU, no HIP.
|
||||
# ROCM_CK_SMOKE — Fast host-only tests (< 1s total). No GPU required.
|
||||
# ROCM_CK_KERNEL — GPU kernel tests. Require HIP and a GPU.
|
||||
#
|
||||
# Usage:
|
||||
@@ -22,37 +22,33 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/../../cmake/gtest.cmake)
|
||||
# ---------------------------------------------------------------------------
|
||||
function(add_rocm_ck_test test_name)
|
||||
add_executable(${test_name} ${ARGN})
|
||||
target_link_libraries(${test_name} PRIVATE rocm_ck GTest::gtest_main)
|
||||
target_link_libraries(${test_name} PRIVATE rocm_ck GTest::gtest_main GTest::gmock)
|
||||
target_compile_options(${test_name} PRIVATE
|
||||
-Wno-global-constructors # GTest registration macros
|
||||
-Wno-undef # GTest internal headers
|
||||
-Wno-global-constructors # GTest registration macros
|
||||
-Wno-undef # GTest internal headers
|
||||
-Wno-zero-as-null-pointer-constant # C++20 <=> comparisons to 0
|
||||
)
|
||||
endfunction()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Smoke tests (fast, host-only, no GPU)
|
||||
# ---------------------------------------------------------------------------
|
||||
set(ROCM_CK_SMOKE_TESTS
|
||||
add_rocm_ck_test(rocm_ck_unit
|
||||
unit/unit_args.cpp
|
||||
unit/unit_datatype.cpp
|
||||
unit/unit_fixed_string.cpp
|
||||
unit/unit_index_t.cpp
|
||||
unit/unit_layout.cpp
|
||||
unit/unit_physical_tensor.cpp
|
||||
)
|
||||
|
||||
set(ROCM_CK_SMOKE_TARGETS)
|
||||
foreach(test_source ${ROCM_CK_SMOKE_TESTS})
|
||||
get_filename_component(test_name ${test_source} NAME_WLE)
|
||||
set(target_name "rocm_ck_${test_name}")
|
||||
add_rocm_ck_test(${target_name} ${test_source})
|
||||
add_test(NAME ${target_name} COMMAND ${target_name})
|
||||
set_tests_properties(${target_name} PROPERTIES LABELS "ROCM_CK_SMOKE")
|
||||
list(APPEND ROCM_CK_SMOKE_TARGETS ${target_name})
|
||||
endforeach()
|
||||
|
||||
# rocm_ck_unit_index_t verifies rocm_ck index types match ck_tile
|
||||
target_include_directories(rocm_ck_unit_index_t PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
|
||||
target_include_directories(rocm_ck_unit PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
|
||||
add_test(NAME rocm_ck_unit COMMAND rocm_ck_unit)
|
||||
set_tests_properties(rocm_ck_unit PROPERTIES LABELS "ROCM_CK_SMOKE")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience targets
|
||||
# ---------------------------------------------------------------------------
|
||||
add_custom_target(build-smoke-rocm-ck DEPENDS ${ROCM_CK_SMOKE_TARGETS})
|
||||
add_custom_target(build-smoke-rocm-ck DEPENDS rocm_ck_unit)
|
||||
|
||||
add_custom_target(smoke-rocm-ck
|
||||
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -L "ROCM_CK_SMOKE"
|
||||
|
||||
216
rocm_ck/tests/unit/unit_args.cpp
Normal file
216
rocm_ck/tests/unit/unit_args.cpp
Normal file
@@ -0,0 +1,216 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <rocm_ck/args.hpp>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
using ::rocm_ck::Args;
|
||||
using ::rocm_ck::kMaxRank;
|
||||
using ::rocm_ck::kMaxScalars;
|
||||
using ::rocm_ck::kMaxTensors;
|
||||
using ::rocm_ck::makeShape;
|
||||
using ::rocm_ck::makeStrides;
|
||||
using ::rocm_ck::ScalarValue;
|
||||
using ::rocm_ck::TensorArg;
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
namespace {
|
||||
|
||||
// ============================================================================
|
||||
// TensorArg ABI
|
||||
// ============================================================================
|
||||
|
||||
// args.hpp lists trivial copyability as a requirement for kernarg passing.
TEST(TensorArg, IsTriviallyCopyable)
{
    constexpr bool kTriviallyCopyable = std::is_trivially_copyable<TensorArg>::value;
    EXPECT_TRUE(kTriviallyCopyable);
}
|
||||
|
||||
// args.hpp lists standard layout as a requirement for kernarg passing.
TEST(TensorArg, HasStandardLayout)
{
    constexpr bool kStandardLayout = std::is_standard_layout<TensorArg>::value;
    EXPECT_TRUE(kStandardLayout);
}
|
||||
|
||||
TEST(TensorArg, Occupies80Bytes)
{
    // ptr (8) + lengths (6 x 4 = 24) + strides (6 x 8 = 48)
    constexpr std::size_t kExpectedBytes = 8 + 24 + 48;
    EXPECT_EQ(sizeof(TensorArg), kExpectedBytes);
}
|
||||
|
||||
TEST(TensorArg, AlignsTo8Bytes)
{
    constexpr std::size_t kExpectedAlignment = 8;
    EXPECT_EQ(alignof(TensorArg), kExpectedAlignment);
}
|
||||
|
||||
// Pins the byte offset of every TensorArg member.
TEST(TensorArg, PlacesFieldsAtExpectedOffsets)
{
    constexpr std::size_t kPtrOffset     = offsetof(TensorArg, ptr);
    constexpr std::size_t kLengthsOffset = offsetof(TensorArg, lengths);
    constexpr std::size_t kStridesOffset = offsetof(TensorArg, strides);

    EXPECT_EQ(kPtrOffset, 0u);
    EXPECT_EQ(kLengthsOffset, 8u);
    EXPECT_EQ(kStridesOffset, 32u);
}
|
||||
|
||||
// ============================================================================
|
||||
// ScalarValue ABI
|
||||
// ============================================================================
|
||||
|
||||
TEST(ScalarValue, IsTriviallyCopyable)
{
    constexpr bool kTriviallyCopyable = std::is_trivially_copyable<ScalarValue>::value;
    EXPECT_TRUE(kTriviallyCopyable);
}
|
||||
|
||||
TEST(ScalarValue, Occupies8Bytes)
{
    // Union of float/int32/uint32 (4 bytes each) and double/int64/uint64
    // (8 bytes each) -> sized by the widest member.
    constexpr std::size_t kExpectedBytes = 8;
    EXPECT_EQ(sizeof(ScalarValue), kExpectedBytes);
}
|
||||
|
||||
// ============================================================================
|
||||
// Args ABI
|
||||
// ============================================================================
|
||||
|
||||
// args.hpp lists trivial copyability as a requirement for kernarg passing.
TEST(Args, IsTriviallyCopyable)
{
    constexpr bool kTriviallyCopyable = std::is_trivially_copyable<Args>::value;
    EXPECT_TRUE(kTriviallyCopyable);
}
|
||||
|
||||
// args.hpp lists standard layout as a requirement for kernarg passing.
TEST(Args, HasStandardLayout)
{
    constexpr bool kStandardLayout = std::is_standard_layout<Args>::value;
    EXPECT_TRUE(kStandardLayout);
}
|
||||
|
||||
TEST(Args, Occupies1552Bytes)
{
    // tensors        16 x 80 = 1280
    // scalars        16 x 8  =  128
    // batch_count    4 + 4 padding = 8
    // batch_strides  16 x 8  =  128
    // workspace_ptr             8
    constexpr std::size_t kExpectedBytes = 1280 + 128 + 8 + 128 + 8;
    EXPECT_EQ(sizeof(Args), kExpectedBytes);
}
|
||||
|
||||
TEST(Args, AlignsTo8Bytes)
{
    constexpr std::size_t kExpectedAlignment = 8;
    EXPECT_EQ(alignof(Args), kExpectedAlignment);
}
|
||||
|
||||
TEST(Args, FitsWithin4KBKernargBudget)
{
    // HSA minimum kernarg size is 4096 bytes
    constexpr std::size_t kKernargBudgetBytes = 4096;
    EXPECT_LE(sizeof(Args), kKernargBudgetBytes);
}
|
||||
|
||||
// ============================================================================
|
||||
// Capacity constants
|
||||
// ============================================================================
|
||||
|
||||
// Pins the three capacity constants so that changing one is a deliberate act.
TEST(Args, DefinesExpectedCapacityLimits)
{
    constexpr int kExpectedRank    = 6;
    constexpr int kExpectedTensors = 16;
    constexpr int kExpectedScalars = 16;

    EXPECT_EQ(kMaxRank, kExpectedRank);
    EXPECT_EQ(kMaxTensors, kExpectedTensors);
    EXPECT_EQ(kMaxScalars, kExpectedScalars);
}
|
||||
|
||||
// ============================================================================
|
||||
// ScalarValue union access
|
||||
// ============================================================================
|
||||
|
||||
// Write the f32 member, read the same member back.
TEST(ScalarValue, StoresAndRetrievesFloat)
{
    constexpr float kValue = 3.14f;

    ScalarValue scalar{};
    scalar.f32 = kValue;

    EXPECT_FLOAT_EQ(scalar.f32, kValue);
}
|
||||
|
||||
// Write the i32 member, read the same member back.
TEST(ScalarValue, StoresAndRetrievesInt32)
{
    constexpr int kValue = -42;

    ScalarValue scalar{};
    scalar.i32 = kValue;

    EXPECT_EQ(scalar.i32, kValue);
}
|
||||
|
||||
// Write the f64 member, read the same member back.
TEST(ScalarValue, StoresAndRetrievesDouble)
{
    constexpr double kValue = 2.718281828;

    ScalarValue scalar{};
    scalar.f64 = kValue;

    EXPECT_DOUBLE_EQ(scalar.f64, kValue);
}
|
||||
|
||||
// Write the u32 member with a pattern that has the top bit set, then read the
// same member back.
TEST(ScalarValue, StoresAndRetrievesUInt32)
{
    constexpr unsigned int kPattern = 0xDEADBEEF;

    ScalarValue scalar{};
    scalar.u32 = kPattern;

    EXPECT_EQ(scalar.u32, kPattern);
}
|
||||
|
||||
// ============================================================================
|
||||
// Args field coverage — batch_strides and workspace_ptr
|
||||
// ============================================================================
|
||||
|
||||
// batch_strides is writable and readable at its first and last slots,
// including with a negative value.
TEST(Args, BatchStridesFieldExists)
{
    constexpr auto last_slot = kMaxTensors - 1;

    Args args{};
    auto& strides = args.batch_strides;

    strides[0]         = 12345;
    strides[last_slot] = -99;

    EXPECT_EQ(strides[0], 12345);
    EXPECT_EQ(strides[last_slot], -99);
}
|
||||
|
||||
// workspace_ptr stores and returns the address assigned to it.
TEST(Args, WorkspacePtrFieldExists)
{
    int dummy             = 42;
    int* const dummy_addr = &dummy;

    Args args{};
    args.workspace_ptr = dummy_addr;

    EXPECT_EQ(args.workspace_ptr, dummy_addr);
}
|
||||
|
||||
// batch_count stores and returns the value assigned to it.
TEST(Args, BatchCountFieldExists)
{
    constexpr int expected = 8;

    Args args{};
    args.batch_count = expected;

    EXPECT_EQ(args.batch_count, expected);
}
|
||||
|
||||
// ============================================================================
|
||||
// Boundary access tests
|
||||
// ============================================================================
|
||||
|
||||
// The final tensor slot (index kMaxTensors - 1) is addressable.
TEST(Args, BoundaryAccessToTensors)
{
    Args args{};
    auto& last_tensor = args.tensors[kMaxTensors - 1];

    last_tensor.ptr = nullptr;

    EXPECT_EQ(last_tensor.ptr, nullptr);
}
|
||||
|
||||
// The final scalar slot (index kMaxScalars - 1) is addressable.
TEST(Args, BoundaryAccessToScalars)
{
    Args args{};
    auto& last_scalar = args.scalars[kMaxScalars - 1];

    last_scalar.f32 = 1.0f;

    EXPECT_FLOAT_EQ(last_scalar.f32, 1.0f);
}
|
||||
|
||||
// The last dimension (index kMaxRank - 1) of both lengths and strides is addressable.
TEST(TensorArg, BoundaryAccessToLengthsAndStrides)
{
    constexpr auto last_dim = kMaxRank - 1;

    TensorArg tensor{};
    tensor.lengths[last_dim] = 42;
    tensor.strides[last_dim] = 99;

    EXPECT_EQ(tensor.lengths[last_dim], 42);
    EXPECT_EQ(tensor.strides[last_dim], 99);
}
|
||||
|
||||
// ============================================================================
|
||||
// makeShape
|
||||
// ============================================================================
|
||||
|
||||
// Two extents are followed by four zero-filled dimensions.
TEST(MakeShape, ZeroFillsUnusedDimensions)
{
    const auto shape = makeShape(128, 64);
    EXPECT_THAT(shape, ElementsAre(128, 64, 0, 0, 0, 0));
}
|
||||
|
||||
// Six extents fill every dimension, in argument order.
TEST(MakeShape, FillsAllSixDimensions)
{
    const auto shape = makeShape(2, 3, 4, 5, 6, 7);
    EXPECT_THAT(shape, ElementsAre(2, 3, 4, 5, 6, 7));
}
|
||||
|
||||
// A single extent is followed by five zero-filled dimensions.
TEST(MakeShape, SingleDimension)
{
    const auto shape = makeShape(1024);
    EXPECT_THAT(shape, ElementsAre(1024, 0, 0, 0, 0, 0));
}
|
||||
|
||||
// ============================================================================
|
||||
// makeStrides
|
||||
// ============================================================================
|
||||
|
||||
// Two strides are followed by four zero-filled entries.
TEST(MakeStrides, ZeroFillsUnusedDimensions)
{
    const auto strides = makeStrides(256, 1);
    EXPECT_THAT(strides, ElementsAre(256, 1, 0, 0, 0, 0));
}
|
||||
|
||||
// A stride that needs more than 32 bits must survive unchanged.
TEST(MakeStrides, HandlesLargeInt64Values)
{
    // 2^40 cannot be represented in a 32-bit integer.
    constexpr int64_t wide_stride = int64_t{1} << 40;

    const auto strides = makeStrides(wide_stride, 1);

    EXPECT_THAT(strides, ElementsAre(wide_stride, 1, 0, 0, 0, 0));
}
|
||||
|
||||
} // namespace
|
||||
79
rocm_ck/tests/unit/unit_datatype.cpp
Normal file
79
rocm_ck/tests/unit/unit_datatype.cpp
Normal file
@@ -0,0 +1,79 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <rocm_ck/datatype.hpp>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
using ::rocm_ck::DataType;
|
||||
using ::rocm_ck::dataTypeBits;
|
||||
using ::rocm_ck::dataTypeName;
|
||||
using ::testing::TestParamInfo;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
namespace {
|
||||
|
||||
// ============================================================================
|
||||
// Parameterized: one row per DataType variant
|
||||
// ============================================================================
|
||||
|
||||
struct DataTypeEntry
|
||||
{
|
||||
DataType dt;
|
||||
int bits;
|
||||
const char* name;
|
||||
};
|
||||
|
||||
// Value-parameterized fixture; each instantiation receives one DataTypeEntry.
// It adds no state or setup of its own.
class DataTypeTest : public TestWithParam<DataTypeEntry> {};
|
||||
|
||||
// dataTypeBits() must return the bit width recorded in the parameter row.
TEST_P(DataTypeTest, ReportsCorrectBits)
{
    const DataTypeEntry& entry = GetParam();
    EXPECT_EQ(dataTypeBits(entry.dt), entry.bits);
}
|
||||
|
||||
// dataTypeName() must return the name recorded in the parameter row.
TEST_P(DataTypeTest, MapsToExpectedName)
{
    const DataTypeEntry& entry = GetParam();
    EXPECT_STREQ(dataTypeName(entry.dt), entry.name);
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
AllTypes,
|
||||
DataTypeTest,
|
||||
Values(DataTypeEntry{.dt = DataType::FP64, .bits = 64, .name = "FP64"},
|
||||
DataTypeEntry{.dt = DataType::FP32, .bits = 32, .name = "FP32"},
|
||||
DataTypeEntry{.dt = DataType::FP16, .bits = 16, .name = "FP16"},
|
||||
DataTypeEntry{.dt = DataType::BF16, .bits = 16, .name = "BF16"},
|
||||
DataTypeEntry{.dt = DataType::FP8_FNUZ, .bits = 8, .name = "FP8_FNUZ"},
|
||||
DataTypeEntry{.dt = DataType::BF8_FNUZ, .bits = 8, .name = "BF8_FNUZ"},
|
||||
DataTypeEntry{.dt = DataType::FP8_OCP, .bits = 8, .name = "FP8_OCP"},
|
||||
DataTypeEntry{.dt = DataType::BF8_OCP, .bits = 8, .name = "BF8_OCP"},
|
||||
DataTypeEntry{.dt = DataType::I4, .bits = 4, .name = "I4"},
|
||||
DataTypeEntry{.dt = DataType::I8, .bits = 8, .name = "I8"},
|
||||
DataTypeEntry{.dt = DataType::I16, .bits = 16, .name = "I16"},
|
||||
DataTypeEntry{.dt = DataType::I32, .bits = 32, .name = "I32"},
|
||||
DataTypeEntry{.dt = DataType::I64, .bits = 64, .name = "I64"},
|
||||
DataTypeEntry{.dt = DataType::U8, .bits = 8, .name = "U8"},
|
||||
DataTypeEntry{.dt = DataType::U16, .bits = 16, .name = "U16"},
|
||||
DataTypeEntry{.dt = DataType::U32, .bits = 32, .name = "U32"},
|
||||
DataTypeEntry{.dt = DataType::U64, .bits = 64, .name = "U64"}),
|
||||
[](const TestParamInfo<DataTypeEntry>& p) { return std::string(p.param.name); });
|
||||
|
||||
// ============================================================================
|
||||
// constexpr validation
|
||||
// ============================================================================
|
||||
|
||||
// Both helpers must be usable in constant expressions: the `constexpr`
// initializers below only compile if the calls are evaluated at compile time.
TEST(DataType, EvaluatesBitsAndNameAtCompileTime)
{
    constexpr int bits_at_compile_time         = dataTypeBits(DataType::FP32);
    constexpr const char* name_at_compile_time = dataTypeName(DataType::FP32);

    EXPECT_EQ(bits_at_compile_time, 32);
    EXPECT_STREQ(name_at_compile_time, "FP32");
}
|
||||
|
||||
} // namespace
|
||||
50
rocm_ck/tests/unit/unit_fixed_string.cpp
Normal file
50
rocm_ck/tests/unit/unit_fixed_string.cpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <rocm_ck/fixed_string.hpp>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using ::rocm_ck::FixedString;
|
||||
|
||||
namespace {
|
||||
|
||||
// A one-character string equals the same literal and differs from another.
TEST(FixedString, MatchesSingleCharacter)
{
    const FixedString<16> letter("A");

    EXPECT_TRUE(letter == "A");
    EXPECT_FALSE(letter == "B");
}
|
||||
|
||||
// Equality requires an exact match: neither a shorter prefix nor a longer
// string that starts with the stored text compares equal.
TEST(FixedString, MatchesExactStringOnly)
{
    const FixedString<16> bias("bias");

    EXPECT_TRUE(bias == "bias");
    EXPECT_FALSE(bias == "bia");
    EXPECT_FALSE(bias == "biases");
}
|
||||
|
||||
// A 15-character string is accepted by FixedString<16> and round-trips.
TEST(FixedString, AcceptsMaxCapacityMinusOne)
{
    const FixedString<16> longest("123456789012345");

    EXPECT_TRUE(longest == "123456789012345");
}
|
||||
|
||||
// An empty string has length zero, equals "", and differs from non-empty text.
TEST(FixedString, SupportsEmptyString)
{
    const FixedString<16> empty("");

    EXPECT_EQ(empty.len, 0);
    EXPECT_TRUE(empty == "");
    EXPECT_FALSE(empty == "A");
}
|
||||
|
||||
// Two FixedStrings compare equal exactly when they hold the same text.
TEST(FixedString, EqualStringsCompareEqual)
{
    const FixedString<16> first_a("A");
    const FixedString<16> second_a("A");
    const FixedString<16> b("B");

    EXPECT_EQ(first_a, second_a);
    EXPECT_NE(first_a, b);
}
|
||||
|
||||
// Relational operators order single-character strings alphabetically.
TEST(FixedString, OrderingIsLexicographic)
{
    const FixedString<16> a("A");
    const FixedString<16> b("B");
    const FixedString<16> z("Z");

    EXPECT_LT(a, b);
    EXPECT_LT(b, z);
    EXPECT_GT(z, a);
}
|
||||
|
||||
} // namespace
|
||||
88
rocm_ck/tests/unit/unit_layout.cpp
Normal file
88
rocm_ck/tests/unit/unit_layout.cpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <rocm_ck/layout.hpp>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
using ::rocm_ck::isValidLayoutForRank;
|
||||
using ::rocm_ck::Layout;
|
||||
using ::rocm_ck::layoutName;
|
||||
using ::rocm_ck::layoutStrides;
|
||||
using ::rocm_ck::leadingDimStride;
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
namespace {
|
||||
|
||||
// ============================================================================
|
||||
// layoutName
|
||||
// ============================================================================
|
||||
|
||||
// layoutName() returns the enumerator's spelling for each Layout value.
TEST(Layout, MapsEnumValuesToExpectedStrings)
{
    const char* const row_name  = layoutName(Layout::Row);
    const char* const col_name  = layoutName(Layout::Col);
    const char* const auto_name = layoutName(Layout::Auto);

    EXPECT_STREQ(row_name, "Row");
    EXPECT_STREQ(col_name, "Col");
    EXPECT_STREQ(auto_name, "Auto");
}
|
||||
|
||||
// ============================================================================
|
||||
// isValidLayoutForRank
|
||||
// ============================================================================
|
||||
|
||||
// Row and Col are accepted for rank 2 and rejected for rank 1.
TEST(Layout, AllowsRowAndColOnlyForRank2)
{
    // Row
    const bool row_rank1 = isValidLayoutForRank(Layout::Row, 1);
    const bool row_rank2 = isValidLayoutForRank(Layout::Row, 2);
    EXPECT_FALSE(row_rank1);
    EXPECT_TRUE(row_rank2);

    // Col
    const bool col_rank1 = isValidLayoutForRank(Layout::Col, 1);
    const bool col_rank2 = isValidLayoutForRank(Layout::Col, 2);
    EXPECT_FALSE(col_rank1);
    EXPECT_TRUE(col_rank2);
}
|
||||
|
||||
// Layout::Auto is not valid for ranks 0, 1 or 2.
TEST(Layout, RejectsAutoForAllRanks)
{
    constexpr Layout layout = Layout::Auto;

    EXPECT_FALSE(isValidLayoutForRank(layout, 0));
    EXPECT_FALSE(isValidLayoutForRank(layout, 1));
    EXPECT_FALSE(isValidLayoutForRank(layout, 2));
}
|
||||
|
||||
// Row and Col are rejected for ranks 3, 4 and 6.
TEST(Layout, RejectsRowAndColForRankGreaterThan2)
{
    constexpr Layout row = Layout::Row;
    constexpr Layout col = Layout::Col;

    EXPECT_FALSE(isValidLayoutForRank(row, 3));
    EXPECT_FALSE(isValidLayoutForRank(row, 4));
    EXPECT_FALSE(isValidLayoutForRank(row, 6));

    EXPECT_FALSE(isValidLayoutForRank(col, 3));
    EXPECT_FALSE(isValidLayoutForRank(col, 4));
    EXPECT_FALSE(isValidLayoutForRank(col, 6));
}
|
||||
|
||||
// ============================================================================
|
||||
// leadingDimStride
|
||||
// ============================================================================
|
||||
|
||||
// For Layout::Row with strides {128, 1} the leading-dimension stride is the first entry.
TEST(Layout, LeadingDimStrideReturnsFirstForRow)
{
    const auto leading = leadingDimStride(Layout::Row, std::array{128, 1});
    EXPECT_EQ(leading, 128);
}
|
||||
|
||||
// For Layout::Col with strides {1, 64} the leading-dimension stride is the second entry.
TEST(Layout, LeadingDimStrideReturnsSecondForCol)
{
    const auto leading = leadingDimStride(Layout::Col, std::array{1, 64});
    EXPECT_EQ(leading, 64);
}
|
||||
|
||||
// ============================================================================
|
||||
// layoutStrides
|
||||
// ============================================================================
|
||||
|
||||
// layoutStrides(Layout::Row, 32, 64) is expected to produce {64, 1}.
TEST(Layout, LayoutStridesRowMajor)
{
    const auto strides = layoutStrides(Layout::Row, 32, 64);
    EXPECT_THAT(strides, ElementsAre(64, 1));
}
|
||||
|
||||
// layoutStrides(Layout::Col, 32, 64) is expected to produce {1, 32}.
TEST(Layout, LayoutStridesColMajor)
{
    const auto strides = layoutStrides(Layout::Col, 32, 64);
    EXPECT_THAT(strides, ElementsAre(1, 32));
}
|
||||
|
||||
} // namespace
|
||||
35
rocm_ck/tests/unit/unit_physical_tensor.cpp
Normal file
35
rocm_ck/tests/unit/unit_physical_tensor.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <rocm_ck/physical_tensor.hpp>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using ::rocm_ck::DataType;
|
||||
using ::rocm_ck::FixedString;
|
||||
using ::rocm_ck::kMaxPhysicalTensors;
|
||||
using ::rocm_ck::Layout;
|
||||
using ::rocm_ck::PhysicalTensor;
|
||||
|
||||
namespace {
|
||||
|
||||
// A value-initialized PhysicalTensor defaults to FP32, Row layout and Args slot 0.
TEST(PhysicalTensor, InitializesWithFP32RowAndSlotZero)
{
    constexpr auto defaults = PhysicalTensor{};

    EXPECT_EQ(defaults.dtype, DataType::FP32);
    EXPECT_EQ(defaults.layout, Layout::Row);
    EXPECT_EQ(defaults.args_slot, 0);
}
|
||||
|
||||
// Every field passed at construction is readable afterwards, and the whole
// object can be built in a constant expression.
TEST(PhysicalTensor, StoresAllFieldsFromConstruction)
{
    constexpr auto tensor =
        PhysicalTensor{FixedString<16>("bias"), DataType::FP16, Layout::Col, 3};

    EXPECT_TRUE(tensor.name == "bias");
    EXPECT_EQ(tensor.dtype, DataType::FP16);
    EXPECT_EQ(tensor.layout, Layout::Col);
    EXPECT_EQ(tensor.args_slot, 3);
}
|
||||
|
||||
// Pins the kMaxPhysicalTensors capacity constant.
TEST(PhysicalTensor, LimitsCapacityTo8)
{
    constexpr auto capacity = kMaxPhysicalTensors;
    EXPECT_EQ(capacity, 8);
}
|
||||
|
||||
} // namespace
|
||||
Reference in New Issue
Block a user