mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[CK] Add rocm_ck schema engine: Signature, resolve(), ArchProperties (#7179) ## Summary A `Signature` is a directed compute graph: tensors are nodes, operators are edges. Shared names between operator outputs and inputs form the graph structure. `resolve()` walks this graph at compile time (`consteval`), inferring dtype, rank, and layout for every tensor — invalid configs become compiler errors, not runtime crashes. **Key design decisions:** - **Operators teach the system about tensors.** `GemmOp` implies rank 2 and Row/Col/Row layout. `AddOp` and `ReluOp` propagate from connected slots. The dtype cascade fills in the rest: per-tensor → signature-wide → error. - **Adding a new op is zero lines in the resolution engine** if it's structurally binary (`lhs/rhs/out`) or unary (`in/out`) — C++20 concepts handle dispatch automatically. Only ops with special semantics need explicit branches. - **TargetSet is a compile-time bitset over GPU targets.** The wave tile validation table is the single source of truth for valid instruction shapes, traced from CK Tile's WarpGemmDispatcher. FP8 tiles are available on gfx942+ via IterateK composition, not gfx950-only. **Reading order:** signature.hpp (data model) → arch_properties.hpp (TargetSet, wave tiles) → resolve.hpp (resolution engine). 3 new headers, 3 unit tests (including diamond DAG coverage), 3 compile-fail tests. Introduces tests/compile_fail/ infrastructure. **Stack**: PR 2 of 3 porting the rocm_ck constexpr schema from experimental to production. 1. Foundation types — DataType, Layout, Args, Ops (#7114) 2. **This PR** — Schema engine (graph resolution) 3. Spec factories — GemmSpec, makeSpec() (#7180) Note: We also removed `FmhaBwdOp` for clarity, since that was introduced early and doesn't have tests set up.
**Depends on**: #7114 ## Test plan - [x] ctest --test-dir build --output-on-failure — unit tests + compile-fail tests pass - [x] Compile-fail tests correctly reject: mixed CDNA+RDNA TargetSet, conflicting layouts, empty quantization scale names --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
177 lines
5.0 KiB
C++
177 lines
5.0 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include <rocm_ck/signature.hpp>

#include <variant>

#include <gtest/gtest.h>
|
|
|
|
using ::rocm_ck::AddOp;
|
|
using ::rocm_ck::DataType;
|
|
using ::rocm_ck::FastGeluOp;
|
|
using ::rocm_ck::GemmOp;
|
|
using ::rocm_ck::kMaxOps;
|
|
using ::rocm_ck::kMaxScalars;
|
|
using ::rocm_ck::kMaxTensors;
|
|
using ::rocm_ck::Layout;
|
|
using ::rocm_ck::MulOp;
|
|
using ::rocm_ck::Op;
|
|
using ::rocm_ck::Quantization;
|
|
using ::rocm_ck::ReluOp;
|
|
using ::rocm_ck::Scalar;
|
|
using ::rocm_ck::SigmoidOp;
|
|
using ::rocm_ck::Signature;
|
|
using ::rocm_ck::Tensor;
|
|
|
|
// ============================================================================
|
|
// Signature construction
|
|
// ============================================================================
|
|
|
|
// A value-initialized Signature leaves its signature-wide dtype unset, so the
// dtype cascade has nothing to fall back on unless the caller provides one.
TEST(Signature, DefaultsToNoDtype)
{
    constexpr Signature default_signature{};
    EXPECT_FALSE(default_signature.dtype.has_value());
}
|
|
|
|
// An explicitly supplied signature-wide dtype is stored verbatim.
TEST(Signature, StoresExplicitDtype)
{
    constexpr Signature sig{.dtype = DataType::FP16};
    // ASSERT (fatal), not EXPECT: if the optional were empty, the dereference
    // below would be undefined behavior, so the test must stop here.
    ASSERT_TRUE(sig.dtype.has_value());
    EXPECT_EQ(*sig.dtype, DataType::FP16);
}
|
|
|
|
// ============================================================================
|
|
// Tensor
|
|
// ============================================================================
|
|
|
|
// A Tensor constructed with only a name keeps the name and leaves everything
// else at its defaults: no dtype, rank 0, and Layout::Auto.
TEST(Tensor, DefaultsToAutoLayoutAndRankZero)
{
    constexpr Tensor named_only{.name = "A"};

    EXPECT_EQ(named_only.name, "A");
    EXPECT_FALSE(named_only.dtype.has_value());
    EXPECT_EQ(named_only.rank, 0);
    EXPECT_EQ(named_only.layout, Layout::Auto);
}
|
|
|
|
// Every explicitly supplied Tensor field (name, dtype, rank, layout) is stored
// exactly as given.
TEST(Tensor, StoresAllExplicitFields)
{
    constexpr Tensor t{.name = "Q", .dtype = DataType::FP32, .rank = 3, .layout = Layout::Row};
    EXPECT_EQ(t.name, "Q");
    // Guard the dereference below: reading from an empty optional is UB.
    ASSERT_TRUE(t.dtype.has_value());
    EXPECT_EQ(*t.dtype, DataType::FP32);
    EXPECT_EQ(t.rank, 3);
    EXPECT_EQ(t.layout, Layout::Row);
}
|
|
|
|
// ============================================================================
|
|
// Scalar
|
|
// ============================================================================
|
|
|
|
// A Scalar built from just a name gets DataType::FP32 as its dtype.
TEST(Scalar, DefaultsToFP32Dtype)
{
    constexpr Scalar alpha_scalar{.name = "alpha"};

    EXPECT_EQ(alpha_scalar.name, "alpha");
    EXPECT_EQ(alpha_scalar.dtype, DataType::FP32);
}
|
|
|
|
// An explicit Scalar dtype overrides the FP32 default.
TEST(Scalar, StoresExplicitDtype)
{
    constexpr Scalar fp16_scalar{.name = "scale", .dtype = DataType::FP16};
    EXPECT_EQ(fp16_scalar.dtype, DataType::FP16);
}
|
|
|
|
// ============================================================================
|
|
// Op variant
|
|
// ============================================================================
|
|
|
|
// A value-initialized Op variant holds std::monostate, meaning "no operator".
TEST(Op, DefaultsToMonostate)
{
    constexpr Op empty_op{};
    EXPECT_TRUE(std::holds_alternative<std::monostate>(empty_op));
}
|
|
|
|
// Assigning a GemmOp into the Op variant selects the GemmOp alternative.
TEST(Op, HoldsGemmOp)
{
    constexpr Op gemm_op = GemmOp{.lhs = "A", .rhs = "B", .out = "C"};
    EXPECT_TRUE(std::holds_alternative<GemmOp>(gemm_op));
}
|
|
|
|
// Each unary operator (in -> out) can be stored in the Op variant, and the
// variant reports the matching alternative.
TEST(Op, HoldsAllUnaryOpTypes)
{
    constexpr Op relu_op = ReluOp{.in = "X", .out = "Y"};
    EXPECT_TRUE(std::holds_alternative<ReluOp>(relu_op));

    constexpr Op fast_gelu_op = FastGeluOp{.in = "X", .out = "Y"};
    EXPECT_TRUE(std::holds_alternative<FastGeluOp>(fast_gelu_op));

    constexpr Op sigmoid_op = SigmoidOp{.in = "X", .out = "Y"};
    EXPECT_TRUE(std::holds_alternative<SigmoidOp>(sigmoid_op));
}
|
|
|
|
// Each elementwise binary operator (lhs, rhs -> out) can be stored in the Op
// variant, and the variant reports the matching alternative.
TEST(Op, HoldsAllBinaryOpTypes)
{
    constexpr Op add_op = AddOp{.lhs = "X", .rhs = "Y", .out = "Z"};
    EXPECT_TRUE(std::holds_alternative<AddOp>(add_op));

    constexpr Op mul_op = MulOp{.lhs = "X", .rhs = "Y", .out = "Z"};
    EXPECT_TRUE(std::holds_alternative<MulOp>(mul_op));
}
|
|
|
|
// ============================================================================
|
|
// GemmOp defaults
|
|
// ============================================================================
|
|
|
|
// A GemmOp that names only its operands uses DataType::FP32 as its
// accumulator dtype.
TEST(GemmOp, DefaultsAccDtypeToFP32)
{
    constexpr GemmOp default_gemm{.lhs = "A", .rhs = "B", .out = "C"};
    EXPECT_EQ(default_gemm.acc_dtype, DataType::FP32);
}
|
|
|
|
// ============================================================================
|
|
// Quantization
|
|
// ============================================================================
|
|
|
|
// A Quantization that names only its scale tensor defaults to an FP32 scale
// dtype and a group size of 128.
TEST(Quantization, DefaultsToFP32ScaleAndGroupSize128)
{
    constexpr Quantization default_quant{.scale_name = "scale"};

    EXPECT_EQ(default_quant.scale_name, "scale");
    EXPECT_EQ(default_quant.scale_dtype, DataType::FP32);
    EXPECT_EQ(default_quant.group_size, 128);
}
|
|
|
|
// Explicit Quantization fields override every default.
TEST(Quantization, StoresExplicitFields)
{
    constexpr Quantization custom_quant{
        .scale_name = "bq", .scale_dtype = DataType::FP16, .group_size = 64};

    EXPECT_EQ(custom_quant.scale_name, "bq");
    EXPECT_EQ(custom_quant.scale_dtype, DataType::FP16);
    EXPECT_EQ(custom_quant.group_size, 64);
}
|
|
|
|
// A Tensor carries no quantization metadata unless one is supplied explicitly.
TEST(Tensor, DefaultsToNoQuantize)
{
    constexpr Tensor plain_tensor{.name = "B"};
    EXPECT_FALSE(plain_tensor.quantize.has_value());
}
|
|
|
|
// A Tensor with attached Quantization metadata stores that metadata
// field-for-field.
TEST(Tensor, StoresQuantizeMetadata)
{
    constexpr Tensor t{
        .name  = "B",
        .dtype = DataType::I4,
        .quantize =
            Quantization{.scale_name = "scale", .scale_dtype = DataType::FP32, .group_size = 128}};
    // ASSERT (fatal), not EXPECT: the -> accesses below would be undefined
    // behavior on an empty optional, so the test must stop if it is missing.
    ASSERT_TRUE(t.quantize.has_value());
    EXPECT_EQ(t.quantize->scale_name, "scale");
    EXPECT_EQ(t.quantize->scale_dtype, DataType::FP32);
    EXPECT_EQ(t.quantize->group_size, 128);
}
|
|
|
|
// ============================================================================
|
|
// Capacity constants
|
|
// ============================================================================
|
|
|
|
// Pins the fixed-capacity limits of a Signature so that any change to them is
// a deliberate, test-visible decision.
TEST(Signature, DefinesExpectedCapacityLimits)
{
    EXPECT_EQ(kMaxTensors, 16); // tensor slots
    EXPECT_EQ(kMaxScalars, 16); // scalar slots
    EXPECT_EQ(kMaxOps, 8);      // operator slots
}
|