composable_kernel/rocm_ck/tests/unit/unit_resolve.cpp
John Afaganis 96c39b331e [rocm-libraries] ROCm/rocm-libraries#7829 (commit 13af7da)
[ck] Enforce ASCII-only C/C++ sources for hipRTC compatibility (#7829)

## Summary

CK source files must be compilable via **hipRTC (HIP runtime
compilation)**, whose preprocessor does not accept non-ASCII bytes
anywhere in a translation unit — **including in comments**. Bytes that
are harmless under `hipcc` (em-dashes, smart quotes, multiplication
signs, Greek letters, box-drawing glyphs, etc.) cause hipRTC to fail at
preprocessing time. These regularly leak in via LLM-assisted authoring
or copy/paste from formatted documents and silently break hipRTC paths
that are not exercised by the default `hipcc`-based build matrix.

This PR (a) cleans every existing violation (53 files) and (b) adds a
pre-checkin gate so new violations are rejected before merge.
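
As a rough illustration of the rule, the C++ sketch below flags the first line that contains a byte above 0x7F. It is not the gate itself: enforcement in this PR is done by the `check_ascii_only.sh` shell script, and `first_non_ascii_line` is a hypothetical helper introduced only for this example. The inputs spell non-ASCII characters as hex escapes so the snippet itself stays ASCII-only.

```cpp
#include <cstddef>
#include <string_view>

// Hypothetical helper (not part of CK): returns the 1-based number of the
// first line containing a byte outside 7-bit ASCII, or 0 if the text is clean.
constexpr std::size_t first_non_ascii_line(std::string_view text)
{
    std::size_t line = 1;
    for(const char c : text)
    {
        if(static_cast<unsigned char>(c) > 0x7F)
            return line;
        if(c == '\n')
            ++line;
    }
    return 0;
}

// An em-dash is the three UTF-8 bytes E2 80 94. Any one of them trips the
// check, even though the character sits inside a comment.
static_assert(first_non_ascii_line("// em-dash test \xe2\x80\x94 here\n") == 1);
static_assert(first_non_ascii_line("int x = 0; // plain ASCII\n") == 0);
static_assert(first_non_ascii_line("// ok\n// bad \xc3\x97\n") == 2);
```

Because the check works on raw bytes rather than tokens, it makes no distinction between comments and code, which matches the hipRTC constraint described above.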

## File extensions covered

Both the cleanup scan and the new Jenkins enforcement stage use the same
predicate:

```
*.h  *.hpp  *.cpp  *.h.in  *.hpp.in  *.cpp.in  *.inc  *.cl
```

(excluding `*/build/*` and `*/include/rapidjson/*`). This is a strict
superset of the existing `Clang Format` stage's predicate — `*.inc` is
added so test-fixture include files are also gated. The local pre-commit
hook's `c++/inc` type filter covers the same set.
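
The predicate can be restated as a small C++ function, shown here as a sketch. The authoritative form is the `find` expression used by the Jenkins stage; `is_gated`, `has_suffix`, `contains_substr`, `kGatedSuffixes`, and the example paths are hypothetical names chosen for illustration.

```cpp
#include <array>
#include <string_view>

// Hypothetical helper: true if `s` ends with `suffix`.
constexpr bool has_suffix(std::string_view s, std::string_view suffix)
{
    return s.size() >= suffix.size() &&
           s.substr(s.size() - suffix.size()) == suffix;
}

// Hypothetical helper: true if `needle` occurs anywhere in `s`.
constexpr bool contains_substr(std::string_view s, std::string_view needle)
{
    return s.find(needle) != std::string_view::npos;
}

// The eight extensions listed above.
constexpr std::array<std::string_view, 8> kGatedSuffixes = {
    ".h", ".hpp", ".cpp", ".h.in", ".hpp.in", ".cpp.in", ".inc", ".cl"};

// True if a path (as printed by `find .`) falls under the ASCII-only gate.
constexpr bool is_gated(std::string_view path)
{
    // Exclusions take priority over extension matches.
    if(contains_substr(path, "/build/") || contains_substr(path, "/include/rapidjson/"))
        return false;
    for(const auto suffix : kGatedSuffixes)
    {
        if(has_suffix(path, suffix))
            return true;
    }
    return false;
}

// Example paths are made up for illustration.
static_assert(is_gated("./include/ck/foo.hpp"));
static_assert(is_gated("./test/data/fixture.inc"));
static_assert(!is_gated("./build/gen/foo.hpp"));
static_assert(!is_gated("./include/rapidjson/document.h"));
static_assert(!is_gated("./script/check_ascii_only.sh"));
```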

## Why no enforcement today

CK is opted out of the rocm-libraries root `.pre-commit-config.yaml`, so
the existing `pre-commit` workflow doesn't touch CK. The local CK
`.pre-commit-config.yaml` only runs for developers who installed hooks.
The **authoritative gate is therefore the new Jenkins stage** in this
PR; the local hook is a convenience.

## Commit layout (bisect-friendly)

1. `79798aa6261` — **`[ck] Convert reflect/ rendering to ASCII for
hipRTC compatibility`**
Behavior change, isolated. `TreeFormatter` swaps `├─ / └─ / │ ` for `|-
/ +- / | ` (3-col width preserved so alignment is unchanged).
`conv_description.hpp` swaps `×` for `x` as the dimension separator.
`test_conv_description.cpp` expected strings updated in lockstep so the
snapshot test stays green. This is the only commit in the series with
observable runtime impact.

2. `738fdb0d81c` — **`[ck] Strip non-ASCII bytes from C++ sources for
hipRTC compatibility`**
Mechanical text cleanup across 53 files. Replacements happen in comments
or in `std::cout` strings that are not asserted on by any test. None of
the 174 `.inc` files in the tree required edits, but they were included in
the scan's predicate, so the enforcement stage gates exactly the set of
files that was scanned. Full replacement table in the commit message.

3. `1d7cd8ba235` — **`[ck] Enforce ASCII-only C/C++ sources for hipRTC
compatibility`**
- New `projects/composablekernel/script/check_ascii_only.sh` (modeled on
`check_copyright_year.sh`).
- New entry in `projects/composablekernel/.pre-commit-config.yaml` under
the local-hooks block (`types_or: [c++, inc]`).
- New `ASCII Only Check` parallel stage in
`projects/composablekernel/Jenkinsfile`'s `Static checks` block,
mirroring the existing `Clang Format` stage but with `*.inc` added to
the find predicate. Always-on, no `RUN_CPPCHECK` gate.

The tree is buildable at every commit boundary. Commit 1 leaves 50 known
violations; commit 2 leaves 0; commit 3 wires the gate.
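
Commit 1's claim that alignment is unchanged rests on each ASCII connector occupying the same number of terminal columns as the glyph sequence it replaces. The sketch below checks this for the branch and last-child connectors by counting UTF-8 code points, which equals the column count for these particular narrow glyphs. `display_columns` is a hypothetical helper, not part of `TreeFormatter`, and the exact connector strings (including trailing spaces) are assumed from the description above.

```cpp
#include <cstddef>
#include <string_view>

// Hypothetical helper: counts UTF-8 code points by skipping continuation
// bytes (those of the form 10xxxxxx). For the narrow glyphs used here, one
// code point occupies one terminal column.
constexpr std::size_t display_columns(std::string_view s)
{
    std::size_t columns = 0;
    for(const char c : s)
    {
        if((static_cast<unsigned char>(c) & 0xC0) != 0x80)
            ++columns;
    }
    return columns;
}

// Old connectors, written as escaped UTF-8 so this snippet stays ASCII-only.
constexpr std::string_view kOldBranch = "\xe2\x94\x9c\xe2\x94\x80 "; // U+251C U+2500 space
constexpr std::string_view kOldLast   = "\xe2\x94\x94\xe2\x94\x80 "; // U+2514 U+2500 space

// Replacement connectors (assumed spelling).
constexpr std::string_view kNewBranch = "|- ";
constexpr std::string_view kNewLast   = "+- ";

// Same column count, far fewer bytes, and no byte above 0x7F.
static_assert(display_columns(kOldBranch) == 3 && kOldBranch.size() == 7);
static_assert(display_columns(kNewBranch) == 3 && kNewBranch.size() == 3);
static_assert(display_columns(kOldLast) == display_columns(kNewLast));
```

The byte length changes (7 bytes down to 3 per connector), so any code that pads by `size()` rather than by column count would shift; the snapshot test updated in commit 1 is what guards against that.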

## Demo

Script output on a synthesized violation:

```
$ printf '// em-dash test \xe2\x80\x94 here\n' > /tmp/bad.cpp
$ projects/composablekernel/script/check_ascii_only.sh /tmp/bad.cpp
ERROR: /tmp/bad.cpp contains non-ASCII bytes:
1:// em-dash test — here
  Fix: replace with ASCII (em-dash -> --, smart quotes -> ", arrows -> ->, etc.)
$ echo $?
1
```

Full repo scan after the cleanup commits (note the `-name '*.inc'`
clause):

```
$ cd projects/composablekernel && find . -type f \( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' \
    -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.inc' -o -name '*.cl' \) \
    -not -path '*/build/*' -not -path '*/include/rapidjson/*' -print0 \
  | xargs -0 -P 8 -n 64 script/check_ascii_only.sh
$ echo $?
0
```

## Test plan

- [ ] Jenkins PR build: confirm new `Static checks -> ASCII Only Check`
stage runs green over the full predicate (incl. `*.inc`) and existing
`Clang Format` stage is unaffected.
- [ ] `test_conv_description` passes against the ASCII tree-formatter
output (touched in commit 1).
- [ ] Local: `pre-commit run ascii-only-checker --all-files` runs
cleanly after installing CK pre-commit hooks via
`script/install_precommit.sh`.
- [ ] Manually inject a non-ASCII byte in any `.cpp/.hpp/.inc` file,
push: confirm Jenkins fails the new stage with a clear error.
- [ ] Spot-check a representative subset of touched files under hipRTC
compilation to confirm no remaining hipRTC-blocking content (optional,
since the static byte check is a sufficient condition for hipRTC
preprocessor acceptance on this dimension).

🤖 Generated with [Claude Code](https://claude.com/claude-code)
2026-06-04 15:00:17 +00:00

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <rocm_ck/resolve.hpp>
#include <gtest/gtest.h>
using ::rocm_ck::AddOp;
using ::rocm_ck::BinaryOpLike;
using ::rocm_ck::DataType;
using ::rocm_ck::FastGeluOp;
using ::rocm_ck::GeluOp;
using ::rocm_ck::GemmOp;
using ::rocm_ck::kMaxTensors;
using ::rocm_ck::Layout;
using ::rocm_ck::MulOp;
using ::rocm_ck::Quantization;
using ::rocm_ck::ReluOp;
using ::rocm_ck::resolve;
using ::rocm_ck::Scalar;
using ::rocm_ck::ScaleOp;
using ::rocm_ck::SigmoidOp;
using ::rocm_ck::Signature;
using ::rocm_ck::SiluOp;
using ::rocm_ck::SoftmaxOp;
using ::rocm_ck::Tensor;
using ::rocm_ck::UnaryOpLike;
// ============================================================================
// Simple GemmOp resolution
// ============================================================================
TEST(Resolve, ResolvesSimpleGemmToThreeTensors)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.num_tensors, 3);
}
TEST(Resolve, CascadesSignatureDtypeToAllGemmTensors)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("A").dtype, DataType::FP16);
EXPECT_EQ(r.tensor("B").dtype, DataType::FP16);
EXPECT_EQ(r.tensor("C").dtype, DataType::FP16);
}
TEST(Resolve, AssignsRank2ToGemmTensors)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("A").rank, 2);
EXPECT_EQ(r.tensor("B").rank, 2);
EXPECT_EQ(r.tensor("C").rank, 2);
}
TEST(Resolve, AssignsRowColRowLayoutToGemmTensors)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("A").layout, Layout::Row);
EXPECT_EQ(r.tensor("B").layout, Layout::Col);
EXPECT_EQ(r.tensor("C").layout, Layout::Row);
}
// ============================================================================
// Custom tensor names
// ============================================================================
TEST(Resolve, AcceptsCustomTensorNames)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "X", .rhs = "Y", .out = "Z"}}});
EXPECT_EQ(r.tensor("X").rank, 2);
EXPECT_EQ(r.tensor("Y").rank, 2);
EXPECT_EQ(r.tensor("Z").rank, 2);
}
// ============================================================================
// dtype cascade
// ============================================================================
TEST(Resolve, CascadesBF16DtypeToAllTensors)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::BF16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("A").dtype, DataType::BF16);
EXPECT_EQ(r.tensor("C").dtype, DataType::BF16);
}
TEST(Resolve, AllowsPerTensorDtypeOverride)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "C", .dtype = DataType::FP32}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("C").dtype, DataType::FP32);
EXPECT_EQ(r.tensor("A").dtype, DataType::FP16); // cascade still applies to A
}
// ============================================================================
// Explicit tensor rank/layout overrides
// ============================================================================
TEST(Resolve, AllowsPerTensorRankOverride)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "A", .rank = 3}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("A").rank, 3);
}
TEST(Resolve, AllowsPerTensorLayoutOverride)
{
// Override B from default Col to Row (RxR layout)
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B", .layout = Layout::Row}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("A").layout, Layout::Row); // default preserved
EXPECT_EQ(r.tensor("B").layout, Layout::Row); // overridden from Col
EXPECT_EQ(r.tensor("C").layout, Layout::Row); // default preserved
}
TEST(Resolve, AllowsMultipleLayoutOverrides)
{
// Override both A and B (CxC layout)
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "A", .layout = Layout::Col},
Tensor{.name = "B", .layout = Layout::Col}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("A").layout, Layout::Col);
EXPECT_EQ(r.tensor("B").layout, Layout::Col);
EXPECT_EQ(r.tensor("C").layout, Layout::Row); // default preserved
}
// ============================================================================
// GEMM + Add + Relu chain
// ============================================================================
TEST(Resolve, ResolvesGemmAddReluToSixTensors)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
ReluOp{.in = "D", .out = "E"}}});
EXPECT_EQ(r.num_tensors, 6); // A, B, C, bias, D, E
}
TEST(Resolve, PropagatesRankAndLayoutThroughEpilogueChain)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
ReluOp{.in = "D", .out = "E"}}});
EXPECT_EQ(r.tensor("C").rank, 2);
EXPECT_EQ(r.tensor("bias").rank, 2);
EXPECT_EQ(r.tensor("bias").layout, Layout::Row);
EXPECT_EQ(r.tensor("D").rank, 2);
EXPECT_EQ(r.tensor("D").layout, Layout::Row);
EXPECT_EQ(r.tensor("E").rank, 2);
EXPECT_EQ(r.tensor("E").layout, Layout::Row);
}
TEST(Resolve, PropagatesRankAndLayoutThroughDiamondDAG)
{
// Diamond: GEMM->C splits into two Add paths, then joins.
// C -> Add(C,bias1)->D1 --> Add(D1,D2)->E
// C -> Add(C,bias2)->D2 -+
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias1", .out = "D1"},
AddOp{.lhs = "C", .rhs = "bias2", .out = "D2"},
AddOp{.lhs = "D1", .rhs = "D2", .out = "E"}}});
EXPECT_EQ(r.num_tensors, 8); // A, B, C, bias1, D1, bias2, D2, E
EXPECT_EQ(r.tensor("D1").rank, 2);
EXPECT_EQ(r.tensor("D2").rank, 2);
EXPECT_EQ(r.tensor("E").rank, 2);
EXPECT_EQ(r.tensor("bias1").layout, Layout::Row);
EXPECT_EQ(r.tensor("E").layout, Layout::Row);
}
TEST(Resolve, AssignsSequentialIndicesToChainedOps)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
ReluOp{.in = "D", .out = "E"}}});
EXPECT_EQ(r.tensorIndex("A"), 0);
EXPECT_EQ(r.tensorIndex("B"), 1);
EXPECT_EQ(r.tensorIndex("C"), 2);
EXPECT_EQ(r.tensorIndex("bias"), 3);
EXPECT_EQ(r.tensorIndex("D"), 4);
EXPECT_EQ(r.tensorIndex("E"), 5);
}
// ============================================================================
// Standalone AddOp
// ============================================================================
TEST(Resolve, ResolvesStandaloneAddWithoutImpliedRank)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP32, .ops = {AddOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.num_tensors, 3);
EXPECT_EQ(r.tensor("A").rank, 0); // no op implies rank
EXPECT_EQ(r.tensor("A").layout, Layout::Auto); // no op implies layout
}
// ============================================================================
// Conflict detection -- redundant identical sets are silent
// ============================================================================
TEST(Resolve, AllowsRedundantIdenticalLayoutFromTwoGemmOps)
{
// GemmOp1 outputs "C" as Row. GemmOp2 uses "C" as lhs (also Row).
// Two ops set the same layout -> no conflict.
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
GemmOp{.lhs = "C", .rhs = "D", .out = "E"}}});
EXPECT_EQ(r.tensor("C").layout, Layout::Row);
EXPECT_EQ(r.tensor("C").rank, 2);
}
TEST(Resolve, AllowsPropagationThroughAddWithConsistentLayout)
{
// GemmOp sets C=Row. AddOp connects C to bias and D.
// Propagation sets bias and D to Row (matching C) -> no conflict.
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"}}});
EXPECT_EQ(r.tensor("C").layout, Layout::Row);
EXPECT_EQ(r.tensor("bias").layout, Layout::Row);
EXPECT_EQ(r.tensor("D").layout, Layout::Row);
}
// ============================================================================
// FMHA pattern: two GemmOps + SoftmaxOp
// ============================================================================
TEST(Resolve, ResolvesFMHATwoGemmSoftmaxPattern)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "Q", .rhs = "K", .out = "S"},
SoftmaxOp{.in = "S", .out = "P"},
GemmOp{.lhs = "P", .rhs = "V", .out = "O"}}});
EXPECT_EQ(r.num_tensors, 6); // Q, K, S, P, V, O
EXPECT_EQ(r.tensor("Q").rank, 2);
EXPECT_EQ(r.tensor("S").rank, 2);
EXPECT_EQ(r.tensor("P").rank, 2); // propagated via SoftmaxOp
EXPECT_EQ(r.tensor("O").rank, 2);
}
// ============================================================================
// Scalar tracking
// ============================================================================
TEST(Resolve, PreservesScalarNamesAndDtypes)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.scalars = {Scalar{.name = "alpha", .dtype = DataType::FP32},
Scalar{.name = "beta", .dtype = DataType::FP32}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.num_scalars, 2);
EXPECT_EQ(r.scalar("alpha").dtype, DataType::FP32);
EXPECT_EQ(r.scalar("beta").dtype, DataType::FP32);
EXPECT_EQ(r.scalarIndex("alpha"), 0);
EXPECT_EQ(r.scalarIndex("beta"), 1);
}
TEST(Resolve, ReportsZeroScalarsWhenNoneDeclared)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.num_scalars, 0);
}
// ============================================================================
// findTensor / findScalar (constexpr, not consteval -- returns -1 on miss)
// ============================================================================
TEST(Resolve, FindsTensorByName)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.findTensor("A"), 0);
EXPECT_EQ(r.findTensor("C"), 2);
}
TEST(Resolve, ReturnsNegativeOneForUnknownTensor)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.findTensor("Z"), -1);
}
TEST(Resolve, ReturnsNegativeOneForUnknownScalar)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.findScalar("nonexistent"), -1);
}
// ============================================================================
// Quantized tensors
// ============================================================================
TEST(Resolve, QuantizedBAutoRegistersScaleTensor)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale",
.scale_dtype = DataType::FP32,
.group_size = 128}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
// A, B, C from GemmOp + scale auto-registered = 4 tensors
EXPECT_EQ(r.num_tensors, 4);
}
TEST(Resolve, ScaleTensorGetsDtypeFromQuantization)
{
constexpr auto r = resolve(
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale",
.scale_dtype = DataType::FP32}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
// Scale tensor dtype comes from Quantization, not the signature cascade
EXPECT_EQ(r.tensor("scale").dtype, DataType::FP32);
}
TEST(Resolve, ScaleTensorGetsRank2RowLayout)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale"}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("scale").rank, 2);
EXPECT_EQ(r.tensor("scale").layout, Layout::Row);
}
TEST(Resolve, QuantizedTensorKeepsOwnDtype)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale"}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_EQ(r.tensor("B").dtype, DataType::I4);
EXPECT_EQ(r.tensor("A").dtype, DataType::FP16); // cascade still works
}
TEST(Resolve, QuantizedResolvedTensorCarriesQuantizeInfo)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale",
.scale_dtype = DataType::FP32,
.group_size = 64}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_TRUE(r.tensor("B").quantize.has_value());
EXPECT_EQ(r.tensor("B").quantize->scale_name, "scale");
EXPECT_EQ(r.tensor("B").quantize->scale_dtype, DataType::FP32);
EXPECT_EQ(r.tensor("B").quantize->group_size, 64);
}
TEST(Resolve, NonQuantizedTensorHasNoQuantizeInfo)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale"}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}});
EXPECT_FALSE(r.tensor("A").quantize.has_value());
EXPECT_FALSE(r.tensor("C").quantize.has_value());
EXPECT_FALSE(r.tensor("scale").quantize.has_value());
}
TEST(Resolve, QuantizedGemmWithEpiloguePreservesScaleTensor)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale"}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
ReluOp{.in = "D", .out = "E"}}});
// A, B, C, bias, D, E from ops + scale auto-registered = 7
EXPECT_EQ(r.num_tensors, 7);
EXPECT_EQ(r.tensor("scale").dtype, DataType::FP32);
EXPECT_TRUE(r.tensor("B").quantize.has_value());
}
// ============================================================================
// C++20 concepts
// ============================================================================
TEST(Concepts, ClassifiesAddAndMulAsBinaryOpLike)
{
EXPECT_TRUE(BinaryOpLike<AddOp>);
EXPECT_TRUE(BinaryOpLike<MulOp>);
EXPECT_FALSE(BinaryOpLike<ReluOp>);
EXPECT_FALSE(BinaryOpLike<SoftmaxOp>);
}
TEST(Concepts, ClassifiesActivationsAsUnaryOpLike)
{
EXPECT_TRUE(UnaryOpLike<ReluOp>);
EXPECT_TRUE(UnaryOpLike<FastGeluOp>);
EXPECT_TRUE(UnaryOpLike<GeluOp>);
EXPECT_TRUE(UnaryOpLike<SiluOp>);
EXPECT_TRUE(UnaryOpLike<SigmoidOp>);
EXPECT_TRUE(UnaryOpLike<SoftmaxOp>);
EXPECT_FALSE(UnaryOpLike<AddOp>);
EXPECT_FALSE(UnaryOpLike<GemmOp>);
}
TEST(Concepts, ClassifiesGemmOpAsBinaryButNotUnary)
{
// GemmOp has .lhs, .rhs, and .out, so it structurally satisfies BinaryOpLike.
// It is nevertheless handled separately in registerSlots rather than through
// the generic binary-op path.
EXPECT_TRUE(BinaryOpLike<GemmOp>); // structurally matches, but dispatch special-cases it
EXPECT_FALSE(UnaryOpLike<GemmOp>);
}
TEST(Concepts, ClassifiesScaleOpAsUnaryNotBinary)
{
EXPECT_TRUE(UnaryOpLike<ScaleOp>);
EXPECT_FALSE(BinaryOpLike<ScaleOp>);
}
// ============================================================================
// ScaleOp with explicit Scalar
// ============================================================================
TEST(Resolve, ScaleOpReferencesExplicitScalar)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.scalars = {Scalar{.name = "alpha", .dtype = DataType::FP32}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
ScaleOp{.in = "C", .out = "D", .scale = "alpha"}}});
EXPECT_EQ(r.num_tensors, 4); // A, B, C, D
EXPECT_EQ(r.num_scalars, 1);
EXPECT_EQ(r.scalar("alpha").dtype, DataType::FP32);
EXPECT_EQ(r.scalarIndex("alpha"), 0);
}
TEST(Resolve, ScaleOpPreservesScalarDtype)
{
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.scalars = {Scalar{.name = "scale_factor", .dtype = DataType::FP16}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
ScaleOp{.in = "C", .out = "D", .scale = "scale_factor"}}});
EXPECT_EQ(r.scalar("scale_factor").dtype, DataType::FP16);
}
// ============================================================================
// Boundary: signature approaching kMaxTensors
// ============================================================================
TEST(Resolve, HandlesSignatureWithManyTensors)
{
// Build a signature with many tensors, approaching (but not reaching) kMaxTensors.
// Three independent GemmOps register 3 tensors each (9 total). The two AddOps
// take existing tensors as inputs, so each contributes only its output tensor.
// kMaxTensors is 16, so the resulting 11 tensors stay within the limit.
constexpr auto r = resolve( //
Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A1", .rhs = "B1", .out = "C1"},
GemmOp{.lhs = "A2", .rhs = "B2", .out = "C2"},
GemmOp{.lhs = "A3", .rhs = "B3", .out = "C3"},
AddOp{.lhs = "C1", .rhs = "C2", .out = "D1"},
AddOp{.lhs = "D1", .rhs = "C3", .out = "D2"}}});
// A1, B1, C1, A2, B2, C2, A3, B3, C3, D1, D2 = 11 tensors
EXPECT_EQ(r.num_tensors, 11);
EXPECT_EQ(r.tensor("A1").dtype, DataType::FP16);
EXPECT_EQ(r.tensor("D2").dtype, DataType::FP16);
}