mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[CK] Add rocm_ck spec factories: GemmSpec, makeSpec() (#7180)

## What this PR does

This is the third PR in the rocm_ck schema stack:

1. **#7150** — Foundation types (DataType, Layout, Args, Ops)
2. **#7163** — Schema engine (Signature, resolve(), ArchProperties)
3. **#7180 (this)** — Spec factories (GemmSpec, makeSpec())

`makeSpec()` is the bridge between user intent and kernel instantiation. It takes a **Signature** (WHAT to compute — operator graph, dtypes, layouts) and a **GemmAlgorithm** (HOW to compute it — tile sizes, pipeline, partitioning) and produces a validated `GemmSpec` — a structural type ready to use as a non-type template parameter.

The key property: **every constraint is enforced at compile time.** An invalid GEMM configuration is a compile error, not a runtime crash or silent corruption. The 33 compile-fail tests are the executable specification of what's allowed.

## What's interesting

**Physical tensor table.** Not every tensor in a compute graph needs device memory. The intermediate result of `C = A * B` in a fused GEMM+Add+ReLU lives only in registers. `makeSpec()` walks the operator chain and determines which tensors are physical (need Args slots) and which are intermediate. The output is a fixed-layout table: `[lhs, rhs, output, D0?, D1?, scale?]`.

**Epilogue composition.** Instead of a combinatorial explosion of named patterns (GemmAdd, GemmAddRelu, GemmMulSilu, ...), the epilogue is a composable chain of ops. `{GemmOp, AddOp, ReluOp}` produces `epilogue_ops = {Add, Relu}` with the bias tensor automatically slotted as D0. Two consecutive AddOps fold into a single Add with two D tensors via CK Tile's parameter pack.

**Signature/Algorithm split.** The same Signature can pair with multiple GemmAlgorithms to produce different tuning variants without changing the mathematical result. This is the foundation for the dispatcher — one operation description, many tile configurations.
## New types

| Type | Role |
|------|------|
| `GemmSpec` | Validated NTTP kernel descriptor — physical tensors, tile geometry, epilogue chain |
| `GemmAlgorithm` | User-facing tuning input — tile sizes, pipeline, partitioning, padding |
| `EpilogueOp` | NTTP-compatible projection of the Op variant for epilogue chains |
| `Dim3` | M x N x K triple for tile geometry |

## Test coverage

- **69 unit tests** — happy paths, layouts, dtypes, quantization, epilogue chains, algorithm variants
- **33 compile-fail tests** — one per constraint (tile divisibility, INT8 rules, pipeline restrictions, etc.)
- **6 schema compatibility baselines** — frozen specs that break if the schema changes

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
225 lines
8.9 KiB
C++
225 lines
8.9 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
//
|
|
// Schema compatibility tests: frozen baseline configs from example 04.
|
|
//
|
|
// These tests verify that schema changes (new fields, modified defaults,
|
|
// validation rules) do NOT break existing variants. Each test freezes the
|
|
// exact makeSpec() call from a .hip variant file and asserts on the
|
|
// full GemmSpec output.
|
|
//
|
|
// If a test fails after a schema change, the change is NOT backwards-
|
|
// compatible. Fix the schema or update the variant (and document why).
|
|
|
|
#include <rocm_ck/gemm_spec.hpp>
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
using ::rocm_ck::AddOp;
|
|
using ::rocm_ck::DataType;
|
|
using ::rocm_ck::EpilogueOp;
|
|
using ::rocm_ck::GemmAlgorithm;
|
|
using ::rocm_ck::GemmOp;
|
|
using ::rocm_ck::GemmSpec;
|
|
using ::rocm_ck::Layout;
|
|
using ::rocm_ck::makeSpec;
|
|
using ::rocm_ck::ReluOp;
|
|
using ::rocm_ck::Signature;
|
|
using ::rocm_ck::TargetSet;
|
|
|
|
// Frozen baseline tests: these assert ALL fields of each spec variant.
|
|
// This is intentionally brittle — adding a new field to GemmSpec will
|
|
// break these tests, forcing explicit review of the change's impact on
|
|
// existing variants. Update the expected values when making intentional
|
|
// schema changes.
|
|
|
|
// ============================================================================
|
|
// gemm_fp32: FP32 plain GEMM, 16x16x16 MFMA tile
|
|
// ============================================================================
|
|
|
|
TEST(SchemaCompat, GemmFP32)
{
    // Frozen baseline: the makeSpec() call for the gemm_fp32 variant.
    // Keep the arguments exactly as written; the point of this test is to
    // detect schema changes that alter the resulting spec.
    constexpr auto spec = makeSpec(
        Signature{.dtype = DataType::FP32,
                  .ops   = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
        GemmAlgorithm{.block_tile  = {128, 128, 32},
                      .block_waves = {2, 2, 1},
                      .wave_tile   = {16, 16, 16}},
        TargetSet::cdna());

    // Physical tensors: count and slot index of each named tensor.
    EXPECT_EQ(spec.num_physical_tensors, 3);
    EXPECT_EQ(slot(spec, "A"), 0);
    EXPECT_EQ(slot(spec, "B"), 1);
    EXPECT_EQ(slot(spec, "C"), 2);

    // Element type of each tensor follows the signature dtype.
    EXPECT_EQ(dtype(spec, "A"), DataType::FP32);
    EXPECT_EQ(dtype(spec, "B"), DataType::FP32);
    EXPECT_EQ(dtype(spec, "C"), DataType::FP32);

    // Layouts were not specified in the signature; these are the values
    // makeSpec() produced when the baseline was frozen.
    EXPECT_EQ(layout(spec, "A"), Layout::Row);
    EXPECT_EQ(layout(spec, "B"), Layout::Col);
    EXPECT_EQ(layout(spec, "C"), Layout::Row);

    // Accumulator type and epilogue chain (plain GEMM: no epilogue ops).
    EXPECT_EQ(spec.acc_dtype, DataType::FP32);
    EXPECT_EQ(spec.num_epilogue_ops, 0);

    // Workgroup size and wave tile geometry.
    EXPECT_EQ(spec.workgroup_size, 256);
    EXPECT_EQ(spec.wave_tile.m, 16);
    EXPECT_EQ(spec.wave_tile.n, 16);
    EXPECT_EQ(spec.wave_tile.k, 16);
}
|
|
|
|
// ============================================================================
|
|
// gemm_fp16: FP16 plain GEMM, 16x16x16 MFMA tile
|
|
// ============================================================================
|
|
|
|
TEST(SchemaCompat, GemmFP16)
{
    // Frozen baseline: the makeSpec() call for the gemm_fp16 variant.
    // Keep the arguments exactly as written; the point of this test is to
    // detect schema changes that alter the resulting spec.
    constexpr auto spec = makeSpec(
        Signature{.dtype = DataType::FP16,
                  .ops   = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
        GemmAlgorithm{.block_tile  = {128, 128, 32},
                      .block_waves = {2, 2, 1},
                      .wave_tile   = {16, 16, 16}},
        TargetSet::cdna());

    // Physical tensors: count and slot index of each named tensor.
    EXPECT_EQ(spec.num_physical_tensors, 3);
    EXPECT_EQ(slot(spec, "A"), 0);
    EXPECT_EQ(slot(spec, "B"), 1);
    EXPECT_EQ(slot(spec, "C"), 2);

    // Element type of each tensor follows the signature dtype.
    EXPECT_EQ(dtype(spec, "A"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "B"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "C"), DataType::FP16);

    // Layouts were not specified in the signature; these are the values
    // makeSpec() produced when the baseline was frozen.
    EXPECT_EQ(layout(spec, "A"), Layout::Row);
    EXPECT_EQ(layout(spec, "B"), Layout::Col);
    EXPECT_EQ(layout(spec, "C"), Layout::Row);

    // FP16 inputs are expected to accumulate in FP32; plain GEMM has no
    // epilogue ops.
    EXPECT_EQ(spec.acc_dtype, DataType::FP32);
    EXPECT_EQ(spec.num_epilogue_ops, 0);

    // Workgroup size and wave tile geometry.
    EXPECT_EQ(spec.workgroup_size, 256);
    EXPECT_EQ(spec.wave_tile.m, 16);
    EXPECT_EQ(spec.wave_tile.n, 16);
    EXPECT_EQ(spec.wave_tile.k, 16);
}
|
|
|
|
// ============================================================================
|
|
// gemm_bf16: BF16 plain GEMM, 16x16x16 MFMA tile
|
|
// ============================================================================
|
|
|
|
TEST(SchemaCompat, GemmBF16)
{
    // Frozen baseline: the makeSpec() call for the gemm_bf16 variant.
    // Keep the arguments exactly as written; the point of this test is to
    // detect schema changes that alter the resulting spec.
    constexpr auto spec = makeSpec(
        Signature{.dtype = DataType::BF16,
                  .ops   = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
        GemmAlgorithm{.block_tile  = {128, 128, 32},
                      .block_waves = {2, 2, 1},
                      .wave_tile   = {16, 16, 16}},
        TargetSet::cdna());

    // Physical tensors: count and slot index of each named tensor.
    EXPECT_EQ(spec.num_physical_tensors, 3);
    EXPECT_EQ(slot(spec, "A"), 0);
    EXPECT_EQ(slot(spec, "B"), 1);
    EXPECT_EQ(slot(spec, "C"), 2);

    // Element type of each tensor follows the signature dtype.
    EXPECT_EQ(dtype(spec, "A"), DataType::BF16);
    EXPECT_EQ(dtype(spec, "B"), DataType::BF16);
    EXPECT_EQ(dtype(spec, "C"), DataType::BF16);

    // Layouts were not specified in the signature; these are the values
    // makeSpec() produced when the baseline was frozen.
    EXPECT_EQ(layout(spec, "A"), Layout::Row);
    EXPECT_EQ(layout(spec, "B"), Layout::Col);
    EXPECT_EQ(layout(spec, "C"), Layout::Row);

    // BF16 inputs are expected to accumulate in FP32; plain GEMM has no
    // epilogue ops.
    EXPECT_EQ(spec.acc_dtype, DataType::FP32);
    EXPECT_EQ(spec.num_epilogue_ops, 0);

    // Workgroup size and wave tile geometry.
    EXPECT_EQ(spec.workgroup_size, 256);
    EXPECT_EQ(spec.wave_tile.m, 16);
    EXPECT_EQ(spec.wave_tile.n, 16);
    EXPECT_EQ(spec.wave_tile.k, 16);
}
|
|
|
|
// ============================================================================
|
|
// gemm_fp16_w32: FP16 plain GEMM, 32x32x16 MFMA tile
|
|
// ============================================================================
|
|
|
|
TEST(SchemaCompat, GemmFP16W32)
{
    // Frozen baseline: the makeSpec() call for the gemm_fp16_w32 variant.
    // Identical to the gemm_fp16 baseline except for the 32x32x16 wave tile.
    constexpr auto spec = makeSpec(
        Signature{.dtype = DataType::FP16,
                  .ops   = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
        GemmAlgorithm{.block_tile  = {128, 128, 32},
                      .block_waves = {2, 2, 1},
                      .wave_tile   = {32, 32, 16}},
        TargetSet::cdna());

    // Physical tensors: count and slot index of each named tensor.
    EXPECT_EQ(spec.num_physical_tensors, 3);
    EXPECT_EQ(slot(spec, "A"), 0);
    EXPECT_EQ(slot(spec, "B"), 1);
    EXPECT_EQ(slot(spec, "C"), 2);

    // Element type of each tensor follows the signature dtype.
    EXPECT_EQ(dtype(spec, "A"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "B"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "C"), DataType::FP16);

    // Layouts were not specified in the signature; these are the values
    // makeSpec() produced when the baseline was frozen.
    EXPECT_EQ(layout(spec, "A"), Layout::Row);
    EXPECT_EQ(layout(spec, "B"), Layout::Col);
    EXPECT_EQ(layout(spec, "C"), Layout::Row);

    // FP16 inputs are expected to accumulate in FP32; plain GEMM has no
    // epilogue ops.
    EXPECT_EQ(spec.acc_dtype, DataType::FP32);
    EXPECT_EQ(spec.num_epilogue_ops, 0);

    // Workgroup size is unchanged by the larger wave tile; the wave tile
    // itself must round-trip as requested.
    EXPECT_EQ(spec.workgroup_size, 256);
    EXPECT_EQ(spec.wave_tile.m, 32);
    EXPECT_EQ(spec.wave_tile.n, 32);
    EXPECT_EQ(spec.wave_tile.k, 16);
}
|
|
|
|
// ============================================================================
|
|
// gemm_fp16_add: FP16 GEMM + Add (1 D tensor)
|
|
// ============================================================================
|
|
|
|
TEST(SchemaCompat, GemmFP16Add)
{
    // Frozen baseline: the makeSpec() call for the gemm_fp16_add variant.
    // The op chain is GEMM (A, B -> C) followed by Add (C, bias -> D).
    constexpr auto spec = makeSpec(
        Signature{.dtype = DataType::FP16,
                  .ops   = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
                            AddOp{.lhs = "C", .rhs = "bias", .out = "D"}}},
        GemmAlgorithm{.block_tile  = {128, 128, 32},
                      .block_waves = {2, 2, 1},
                      .wave_tile   = {16, 16, 16}},
        TargetSet::cdna());

    // Physical tensors: the two GEMM inputs, the final output "D", and the
    // extra Add operand "bias". Four slots in total.
    EXPECT_EQ(spec.num_physical_tensors, 4);
    EXPECT_EQ(slot(spec, "A"), 0);
    EXPECT_EQ(slot(spec, "B"), 1);
    EXPECT_EQ(slot(spec, "D"), 2);    // final output
    EXPECT_EQ(slot(spec, "bias"), 3); // D0 slot

    // Element type of each physical tensor follows the signature dtype.
    EXPECT_EQ(dtype(spec, "A"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "B"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "D"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "bias"), DataType::FP16);

    // Input layouts as produced by makeSpec() when the baseline was frozen.
    EXPECT_EQ(layout(spec, "A"), Layout::Row);
    EXPECT_EQ(layout(spec, "B"), Layout::Col);

    // Accumulator type and epilogue chain: exactly one op, the Add.
    EXPECT_EQ(spec.acc_dtype, DataType::FP32);
    EXPECT_EQ(spec.num_epilogue_ops, 1);
    EXPECT_TRUE(spec.hasEpilogueOp(EpilogueOp::Add));

    // Workgroup size and wave tile geometry.
    EXPECT_EQ(spec.workgroup_size, 256);
    EXPECT_EQ(spec.wave_tile.m, 16);
    EXPECT_EQ(spec.wave_tile.n, 16);
    EXPECT_EQ(spec.wave_tile.k, 16);
}
|
|
|
|
// ============================================================================
|
|
// gemm_fp16_add_relu: FP16 GEMM + Add + Relu (1 D tensor)
|
|
// ============================================================================
|
|
|
|
TEST(SchemaCompat, GemmFP16AddRelu)
{
    // Frozen baseline: the makeSpec() call for the gemm_fp16_add_relu variant.
    // The op chain is GEMM (A, B -> C), Add (C, bias -> D), Relu (D -> E).
    constexpr auto spec = makeSpec(
        Signature{.dtype = DataType::FP16,
                  .ops   = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
                            AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
                            ReluOp{.in = "D", .out = "E"}}},
        GemmAlgorithm{.block_tile  = {128, 128, 32},
                      .block_waves = {2, 2, 1},
                      .wave_tile   = {16, 16, 16}},
        TargetSet::cdna());

    // Physical tensors: the two GEMM inputs, the final output "E", and the
    // extra Add operand "bias". Relu adds no tensor, so the count stays at
    // four.
    EXPECT_EQ(spec.num_physical_tensors, 4);
    EXPECT_EQ(slot(spec, "A"), 0);
    EXPECT_EQ(slot(spec, "B"), 1);
    EXPECT_EQ(slot(spec, "E"), 2);    // final output
    EXPECT_EQ(slot(spec, "bias"), 3); // D0 slot

    // Element type of each physical tensor follows the signature dtype.
    EXPECT_EQ(dtype(spec, "A"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "B"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "E"), DataType::FP16);
    EXPECT_EQ(dtype(spec, "bias"), DataType::FP16);

    // Input layouts as produced by makeSpec() when the baseline was frozen.
    EXPECT_EQ(layout(spec, "A"), Layout::Row);
    EXPECT_EQ(layout(spec, "B"), Layout::Col);

    // Accumulator type and epilogue chain: two ops, Add and Relu.
    EXPECT_EQ(spec.acc_dtype, DataType::FP32);
    EXPECT_EQ(spec.num_epilogue_ops, 2);
    EXPECT_TRUE(spec.hasEpilogueOp(EpilogueOp::Add));
    EXPECT_TRUE(spec.hasEpilogueOp(EpilogueOp::Relu));

    // Workgroup size and wave tile geometry.
    EXPECT_EQ(spec.workgroup_size, 256);
    EXPECT_EQ(spec.wave_tile.m, 16);
    EXPECT_EQ(spec.wave_tile.n, 16);
    EXPECT_EQ(spec.wave_tile.k, 16);
}
|