composable_kernel/rocm_ck/tests/unit/unit_gemm_spec.cpp
John Afaganis 96c39b331e [rocm-libraries] ROCm/rocm-libraries#7829 (commit 13af7da)
[ck] Enforce ASCII-only C/C++ sources for hipRTC compatibility (#7829)

## Summary

CK source files must be compilable via **hipRTC (HIP runtime
compilation)**, whose preprocessor does not accept non-ASCII bytes
anywhere in a translation unit — **including in comments**. Bytes that
are harmless under `hipcc` (em-dashes, smart quotes, multiplication
signs, Greek letters, box-drawing glyphs, etc.) cause hipRTC to fail at
preprocessing time. These regularly leak in via LLM-assisted authoring
or copy/paste from formatted documents and silently break hipRTC paths
that are not exercised by the default `hipcc`-based build matrix.
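
The rule described above reduces to a plain byte test: a byte is ASCII exactly when its value is below `0x80`. A minimal sketch of that test follows. `firstNonAscii` is a hypothetical helper written for illustration; it is not part of CK and is not how `check_ascii_only.sh` is implemented.

```cpp
#include <cstddef>
#include <string_view>

// Hypothetical helper (not from the CK tree): returns the offset of the first
// byte outside the 7-bit ASCII range, or std::string_view::npos if none exists.
constexpr std::size_t firstNonAscii(std::string_view text)
{
    for(std::size_t i = 0; i < text.size(); ++i)
    {
        // Compare as unsigned char so the test also works where char is signed.
        if(static_cast<unsigned char>(text[i]) >= 0x80)
            return i;
    }
    return std::string_view::npos;
}

// Plain ASCII comment text passes.
static_assert(firstNonAscii("// k=16 invalid at 32x32") == std::string_view::npos);
// A UTF-8 em-dash (bytes E2 80 94) is flagged at its first byte, even in a comment.
static_assert(firstNonAscii("// a \xe2\x80\x94 b") == 5);
```

The em-dash is spelled with byte escapes so that the example itself stays ASCII-only.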

This PR (a) cleans every existing violation (53 files) and (b) adds a
pre-checkin gate so new violations are rejected before merge.

## File extensions covered

Both the cleanup scan and the new Jenkins enforcement stage use the same
predicate:

```
*.h  *.hpp  *.cpp  *.h.in  *.hpp.in  *.cpp.in  *.inc  *.cl
```

(excluding `*/build/*` and `*/include/rapidjson/*`). This is a strict
superset of the existing `Clang Format` stage's predicate — `*.inc` is
added so test-fixture include files are also gated. The local pre-commit
hook's `c++/inc` type filter covers the same set.
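
As a rough illustration of the predicate above, the sketch below expresses the same suffix list and the two path exclusions as a C++ function. The helper names and the sample paths are hypothetical; the real gate is the `find` expression used by the Jenkins stage.

```cpp
#include <string_view>

// Hypothetical helpers (names are not from the CK tree).
constexpr bool endsWith(std::string_view s, std::string_view suffix)
{
    return s.size() >= suffix.size() &&
           s.substr(s.size() - suffix.size()) == suffix;
}

constexpr bool contains(std::string_view s, std::string_view needle)
{
    return s.find(needle) != std::string_view::npos;
}

// Mirrors the predicate: eight file suffixes, minus two excluded directories.
constexpr bool isGatedSource(std::string_view path)
{
    constexpr std::string_view suffixes[] = {
        ".h", ".hpp", ".cpp", ".h.in", ".hpp.in", ".cpp.in", ".inc", ".cl"};

    if(contains(path, "/build/") || contains(path, "/include/rapidjson/"))
        return false;

    for(std::string_view suffix : suffixes)
    {
        if(endsWith(path, suffix))
            return true;
    }
    return false;
}

// Sample paths are made up for the example.
static_assert(isGatedSource("./test/fixture.inc"));
static_assert(isGatedSource("./include/ck/config.h.in"));
static_assert(!isGatedSource("./build/generated.hpp"));
static_assert(!isGatedSource("./include/rapidjson/document.h"));
static_assert(!isGatedSource("./script/check_ascii_only.sh"));
```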

## Why no enforcement today

CK is opted out of the rocm-libraries root `.pre-commit-config.yaml`, so
the existing `pre-commit` workflow doesn't touch CK. The local CK
`.pre-commit-config.yaml` only runs for developers who installed hooks.
The **authoritative gate is therefore the new Jenkins stage** in this
PR; the local hook is a convenience.

## Commit layout (bisect-friendly)

1. `79798aa6261` — **`[ck] Convert reflect/ rendering to ASCII for
hipRTC compatibility`**
Behavior change, isolated. `TreeFormatter` swaps `├─ / └─ / │ ` for
`|- / +- / | ` (3-column width preserved, so alignment is unchanged).
`conv_description.hpp` swaps `×` for `x` as the dimension separator.
`test_conv_description.cpp` expected strings updated in lockstep so the
snapshot test stays green. This is the only commit in the series with
observable runtime impact.

2. `738fdb0d81c` — **`[ck] Strip non-ASCII bytes from C++ sources for
hipRTC compatibility`**
Mechanical text cleanup across 53 files. Replacements happen in comments
or in `std::cout` strings that no test asserts on. None of
the 174 `.inc` files in the tree required edits, but they were included in
the scan's predicate, so the enforcement stage's predicate matches
what was scanned. The full replacement table is in the commit message.

3. `1d7cd8ba235` — **`[ck] Enforce ASCII-only C/C++ sources for hipRTC
compatibility`**
- New `projects/composablekernel/script/check_ascii_only.sh` (modeled on
`check_copyright_year.sh`).
- New entry in `projects/composablekernel/.pre-commit-config.yaml` under
the local-hooks block (`types_or: [c++, inc]`).
- New `ASCII Only Check` parallel stage in
`projects/composablekernel/Jenkinsfile`'s `Static checks` block,
mirroring the existing `Clang Format` stage but with `*.inc` added to
the find predicate. Always-on, no `RUN_CPPCHECK` gate.

The tree is buildable at every commit boundary. Commit 1 leaves 50 known
violations; commit 2 leaves 0; commit 3 wires the gate.
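
The alignment claim in commit 1 rests on each connector occupying three display columns both before and after the swap, even though the byte counts differ. The sketch below checks that for one connector pair. It is an illustration only: the exact connector strings (including the trailing space) and the names `displayColumns`, `kOldBranch`, and `kNewBranch` are assumptions, not taken from `TreeFormatter`.

```cpp
#include <cstddef>
#include <string_view>

// Counts UTF-8 code points by skipping continuation bytes (10xxxxxx).
// Treating one code point as one display column is an approximation; it is
// assumed to hold for ASCII and for box-drawing glyphs in typical terminals.
constexpr std::size_t displayColumns(std::string_view s)
{
    std::size_t columns = 0;
    for(char c : s)
    {
        if((static_cast<unsigned char>(c) & 0xC0) != 0x80)
            ++columns;
    }
    return columns;
}

// Assumed connector strings. The non-ASCII one is spelled with byte escapes
// (U+251C, U+2500, then a space) so that this example is itself ASCII-only.
constexpr std::string_view kOldBranch = "\xe2\x94\x9c\xe2\x94\x80 ";
constexpr std::string_view kNewBranch = "|- ";

static_assert(kOldBranch.size() == 7);          // seven bytes of UTF-8
static_assert(displayColumns(kOldBranch) == 3); // but three columns
static_assert(kNewBranch.size() == 3);          // three bytes of ASCII
static_assert(displayColumns(kNewBranch) == 3); // and still three columns
```

Because the column count is the same, text that follows a connector starts at the same offset in both renderings, which is why only the glyphs in the expected strings had to change.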

## Demo

Script output on a synthesized violation:

```
$ printf '// em-dash test \xe2\x80\x94 here\n' > /tmp/bad.cpp
$ projects/composablekernel/script/check_ascii_only.sh /tmp/bad.cpp
ERROR: /tmp/bad.cpp contains non-ASCII bytes:
1:// em-dash test — here
  Fix: replace with ASCII (em-dash -> --, smart quotes -> ", arrows -> ->, etc.)
$ echo $?
1
```
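
The report above is line-oriented: the `1:` prefix is the 1-based number of the offending line. A minimal sketch of that kind of lookup is shown here, under the assumption that a line is flagged as soon as it contains any byte of value `0x80` or above. `firstOffendingLine` is a hypothetical name, and the script's actual implementation may differ.

```cpp
#include <cstddef>
#include <string_view>

// Hypothetical helper: returns the 1-based number of the first line that
// contains a non-ASCII byte, or 0 if every byte in the text is ASCII.
constexpr std::size_t firstOffendingLine(std::string_view text)
{
    std::size_t line = 1;
    for(char c : text)
    {
        if(static_cast<unsigned char>(c) >= 0x80)
            return line;
        if(c == '\n')
            ++line; // the next byte belongs to the following line
    }
    return 0;
}

// Same bytes as the synthesized /tmp/bad.cpp in the demo.
static_assert(firstOffendingLine("// em-dash test \xe2\x80\x94 here\n") == 1);
// A clean first line followed by a multiplication sign (C3 97) on line 2.
static_assert(firstOffendingLine("int x;\n// 3 \xc3\x97 4\n") == 2);
static_assert(firstOffendingLine("int x;\n") == 0);
```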

Full repo scan after the cleanup commits (note the `-name '*.inc'`
clause):

```
$ cd projects/composablekernel && find . -type f \( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' \
    -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.inc' -o -name '*.cl' \) \
    -not -path '*/build/*' -not -path '*/include/rapidjson/*' -print0 \
  | xargs -0 -P 8 -n 64 script/check_ascii_only.sh
$ echo $?
0
```

## Test plan

- [ ] Jenkins PR build: confirm new `Static checks -> ASCII Only Check`
stage runs green over the full predicate (incl. `*.inc`) and existing
`Clang Format` stage is unaffected.
- [ ] `test_conv_description` passes against the ASCII tree-formatter
output (touched in commit 1).
- [ ] Local: `pre-commit run ascii-only-checker --all-files` runs
cleanly after installing CK pre-commit hooks via
`script/install_precommit.sh`.
- [ ] Manually inject a non-ASCII byte in any `.cpp/.hpp/.inc` file,
push: confirm Jenkins fails the new stage with a clear error.
- [ ] Spot-check a representative subset of touched files under hipRTC
compilation to confirm no remaining hipRTC-blocking content (optional,
since the static byte check is a sufficient condition for hipRTC
preprocessor acceptance on this dimension).

🤖 Generated with [Claude Code](https://claude.com/claude-code)
2026-06-04 15:00:17 +00:00

1120 lines, 46 KiB, C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <rocm_ck/gemm_spec.hpp>
#include <gtest/gtest.h>
using ::rocm_ck::AddOp;
using ::rocm_ck::DataType;
using ::rocm_ck::EpilogueOp;
using ::rocm_ck::FastGeluOp;
using ::rocm_ck::GeluOp;
using ::rocm_ck::GemmAlgorithm;
using ::rocm_ck::GemmOp;
using ::rocm_ck::GemmSpec;
using ::rocm_ck::GpuTarget;
using ::rocm_ck::isValidWaveTile;
using ::rocm_ck::Layout;
using ::rocm_ck::makeSpec;
using ::rocm_ck::MulOp;
using ::rocm_ck::Pipeline;
using ::rocm_ck::PipelineScheduler;
using ::rocm_ck::Quantization;
using ::rocm_ck::ReluOp;
using ::rocm_ck::SigmoidOp;
using ::rocm_ck::Signature;
using ::rocm_ck::SiluOp;
using ::rocm_ck::StoreStrategy;
using ::rocm_ck::TargetSet;
using ::rocm_ck::Tensor;
using ::rocm_ck::TilePartitioner;
// ============================================================================
// isValidWaveTile
// ============================================================================
TEST(WaveTileValidation, AcceptsFP32With16x16Tile)
{
EXPECT_TRUE(isValidWaveTile(DataType::FP32, 16, 16, 4, TargetSet::cdna()));
EXPECT_TRUE(isValidWaveTile(DataType::FP32, 16, 16, 8, TargetSet::cdna()));
EXPECT_TRUE(isValidWaveTile(DataType::FP32, 16, 16, 16, TargetSet::cdna()));
}
TEST(WaveTileValidation, AcceptsFP32With32x32OnlyForSmallK)
{
EXPECT_TRUE(isValidWaveTile(DataType::FP32, 32, 32, 4, TargetSet::cdna()));
EXPECT_TRUE(isValidWaveTile(DataType::FP32, 32, 32, 8, TargetSet::cdna()));
EXPECT_FALSE(isValidWaveTile(
DataType::FP32, 32, 32, 16, TargetSet::cdna())); // k=16 invalid at 32x32 for fp32
}
TEST(WaveTileValidation, AcceptsFP16With16x16Tile)
{
EXPECT_TRUE(isValidWaveTile(DataType::FP16, 16, 16, 16, TargetSet::cdna()));
EXPECT_TRUE(isValidWaveTile(DataType::FP16, 16, 16, 32, TargetSet::cdna()));
EXPECT_FALSE(isValidWaveTile(DataType::FP16, 16, 16, 4, TargetSet::cdna()));
}
TEST(WaveTileValidation, AcceptsFP16With32x32Tile)
{
EXPECT_TRUE(isValidWaveTile(DataType::FP16, 32, 32, 8, TargetSet::cdna()));
EXPECT_TRUE(isValidWaveTile(DataType::FP16, 32, 32, 16, TargetSet::cdna()));
EXPECT_FALSE(isValidWaveTile(
DataType::FP16, 32, 32, 4, TargetSet::cdna())); // k=4 invalid at 32x32 for fp16
}
TEST(WaveTileValidation, AcceptsSameTilesForBF16AsFP16)
{
EXPECT_TRUE(isValidWaveTile(DataType::BF16, 16, 16, 16, TargetSet::cdna()));
EXPECT_TRUE(isValidWaveTile(DataType::BF16, 32, 32, 16, TargetSet::cdna()));
EXPECT_FALSE(isValidWaveTile(DataType::BF16, 32, 32, 4, TargetSet::cdna()));
}
TEST(WaveTileValidation, RejectsAsymmetricAndIntegerConfigs)
{
// Asymmetric tiles not supported
EXPECT_FALSE(isValidWaveTile(DataType::FP32, 16, 32, 8, TargetSet::cdna()));
EXPECT_FALSE(isValidWaveTile(DataType::FP16, 32, 16, 16, TargetSet::cdna()));
// Integer types not yet in wave tile validation table
EXPECT_FALSE(isValidWaveTile(DataType::I32, 16, 16, 4, TargetSet::cdna()));
}
// ============================================================================
// makeSpec: plain GEMM
// ============================================================================
TEST(MakeSpec, ProducesThreePhysicalTensorsForPlainGemm)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_physical_tensors, 3);
}
TEST(MakeSpec, MapsGemmTensorsToSequentialSlots)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(slot(k, "A"), 0);
EXPECT_EQ(slot(k, "B"), 1);
EXPECT_EQ(slot(k, "C"), 2);
}
TEST(MakeSpec, PropagatesDtypeToAllGemmTensors)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(dtype(k, "A"), DataType::FP16);
EXPECT_EQ(dtype(k, "B"), DataType::FP16);
EXPECT_EQ(dtype(k, "C"), DataType::FP16);
}
TEST(MakeSpec, ComputesThreadBlockSizeFromWaves)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
// block_waves 2 * 2 * 1 = 4 waves, 64 threads per wave: 4 * 64 = 256
EXPECT_EQ(k.workgroup_size, 256);
}
TEST(MakeSpec, ReportsZeroEpilogueOpsForPlainGemm)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_epilogue_ops, 0);
}
// ============================================================================
// makeSpec: GEMM + Add
// ============================================================================
TEST(MakeSpec, RegistersAddAsEpilogueOp)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_epilogue_ops, 1);
EXPECT_EQ(k.epilogue_ops[0], EpilogueOp::Add);
EXPECT_EQ(k.num_physical_tensors, 4); // A, B, D(output), bias(D0)
}
TEST(MakeSpec, PlacesBiasInD0SlotForGemmAdd)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(slot(k, "D"), 2); // output slot
EXPECT_EQ(slot(k, "bias"), 3); // D0 slot
}
TEST(MakeSpec, PropagatesDtypeToBiasTensor)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(dtype(k, "bias"), DataType::FP16);
}
// ============================================================================
// makeSpec: GEMM + Add + Relu
// ============================================================================
TEST(MakeSpec, RegistersAddAndReluAsEpilogueOps)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
ReluOp{.in = "D", .out = "E"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_epilogue_ops, 2);
EXPECT_TRUE(k.hasEpilogueOp(EpilogueOp::Add));
EXPECT_TRUE(k.hasEpilogueOp(EpilogueOp::Relu));
}
TEST(MakeSpec, UsesFinalOutputSlotForGemmAddRelu)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
ReluOp{.in = "D", .out = "E"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(slot(k, "E"), 2); // final output in slot 2
EXPECT_EQ(k.num_physical_tensors, 4); // A, B, E(output), bias(D0)
}
// ============================================================================
// makeSpec: 32x32 wave tile
// ============================================================================
TEST(MakeSpec, Accepts32x32WaveTileWithCorrectBlockSize)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {32, 32, 16}},
TargetSet::cdna());
EXPECT_EQ(k.workgroup_size, 256);
EXPECT_EQ(k.wave_tile.m, 32);
EXPECT_EQ(k.wave_tile.n, 32);
EXPECT_EQ(k.wave_tile.k, 16);
}
// ============================================================================
// makeSpec: layout defaults
// ============================================================================
TEST(MakeSpec, AssignsRowColRowLayoutByDefault)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP32, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(layout(k, "A"), Layout::Row);
EXPECT_EQ(layout(k, "B"), Layout::Col);
EXPECT_EQ(layout(k, "C"), Layout::Row);
}
TEST(MakeSpec, OverridesBLayoutToRowForRR)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B", .layout = Layout::Row}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(layout(k, "A"), Layout::Row);
EXPECT_EQ(layout(k, "B"), Layout::Row);
EXPECT_EQ(layout(k, "C"), Layout::Row);
}
TEST(MakeSpec, OverridesBothLayoutsForCC)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "A", .layout = Layout::Col},
Tensor{.name = "B", .layout = Layout::Col}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(layout(k, "A"), Layout::Col);
EXPECT_EQ(layout(k, "B"), Layout::Col);
EXPECT_EQ(layout(k, "C"), Layout::Row);
}
TEST(MakeSpec, OverridesALayoutForCR)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "A", .layout = Layout::Col},
Tensor{.name = "B", .layout = Layout::Row}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(layout(k, "A"), Layout::Col);
EXPECT_EQ(layout(k, "B"), Layout::Row);
EXPECT_EQ(layout(k, "C"), Layout::Row);
}
TEST(MakeSpec, LayoutOverrideFlowsToPhysicalTensorTable)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B", .layout = Layout::Row}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
// Verify the physical tensor table (what the device code sees)
EXPECT_EQ(k.lhs().layout, Layout::Row);
EXPECT_EQ(k.rhs().layout, Layout::Row);
EXPECT_EQ(k.output().layout, Layout::Row);
}
// ============================================================================
// GemmSpec named accessors
// ============================================================================
TEST(GemmSpec, ProvidesLhsRhsOutputNamedAccessors)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.lhs().args_slot, 0);
EXPECT_EQ(k.rhs().args_slot, 1);
EXPECT_EQ(k.output().args_slot, 2);
EXPECT_EQ(k.lhs().dtype, DataType::FP16);
}
// ============================================================================
// Accumulator dtype
// ============================================================================
TEST(MakeSpec, DefaultsAccDtypeToFP32)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.acc_dtype, DataType::FP32); // GemmOp default acc_dtype
}
// ============================================================================
// Multiple data types
// ============================================================================
TEST(MakeSpec, ProducesFP32GemmWithMatchingAccDtype)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP32, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(dtype(k, "A"), DataType::FP32);
EXPECT_EQ(k.acc_dtype, DataType::FP32);
}
TEST(MakeSpec, ProducesBF16GemmWithCorrectDtype)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::BF16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(dtype(k, "A"), DataType::BF16);
}
// ============================================================================
// Split-K (k_batch)
// ============================================================================
TEST(MakeSpec, DefaultsKBatchToOne)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.k_batch, 1);
}
TEST(MakeSpec, AcceptsExplicitKBatch)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.k_batch = 4},
TargetSet::cdna());
EXPECT_EQ(k.k_batch, 4);
}
TEST(MakeSpec, KBatchPreservesOtherFields)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.k_batch = 4},
TargetSet::cdna());
EXPECT_EQ(k.num_physical_tensors, 3);
EXPECT_EQ(k.workgroup_size, 256);
EXPECT_EQ(k.block_tile.k, 32);
}
TEST(MakeSpec, KBatchWorksWithEpilogueOps)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.k_batch = 2},
TargetSet::cdna());
EXPECT_EQ(k.k_batch, 2);
EXPECT_EQ(k.num_epilogue_ops, 1);
EXPECT_TRUE(k.hasEpilogueOp(EpilogueOp::Add));
}
// ============================================================================
// isValidWaveTile: GpuTarget-specific validation
// ============================================================================
TEST(WaveTileValidation, AcceptsFP8TilesForGfx942)
{
// gfx942 base MFMA: 32x32x16, 16x16x32
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 16, GpuTarget::gfx942));
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 16, 16, 32, GpuTarget::gfx942));
// IterateK compositions available on gfx942+
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 32, GpuTarget::gfx942));
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 64, GpuTarget::gfx942));
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 16, 16, 64, GpuTarget::gfx942));
}
TEST(WaveTileValidation, AcceptsFP8TilesForGfx950)
{
// gfx950 MFMA: 32x32x{16,32,64}, 16x16x{32,64}
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 16, GpuTarget::gfx950));
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 32, GpuTarget::gfx950));
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 64, GpuTarget::gfx950));
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 16, 16, 32, GpuTarget::gfx950));
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 16, 16, 64, GpuTarget::gfx950));
}
TEST(WaveTileValidation, TargetSetAllMeansIntersectionAcrossAllTargets)
{
// all() = intersection across ALL targets (CDNA + RDNA).
// Only 16x16x16 FP16/BF16 pass (valid on both MFMA and WMMA).
EXPECT_TRUE(isValidWaveTile(DataType::FP16, 16, 16, 16, TargetSet::all()));
EXPECT_TRUE(isValidWaveTile(DataType::BF16, 16, 16, 16, TargetSet::all()));
// I8 16x16x16 fails -- CDNA MFMA I8 tiles are 32x32x16 and 16x16x32, not 16x16x16
EXPECT_FALSE(isValidWaveTile(DataType::I8, 16, 16, 16, TargetSet::all()));
// 32x32 tiles fail -- WMMA only has 16x16x16
EXPECT_FALSE(isValidWaveTile(DataType::FP16, 32, 32, 16, TargetSet::all()));
// FP8 fails -- gfx90a has no FP8, gfx1151 has no FP8
EXPECT_FALSE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 16, TargetSet::all()));
EXPECT_FALSE(isValidWaveTile(DataType::FP8_FNUZ, 16, 16, 32, TargetSet::all()));
// FP32 fails -- WMMA doesn't support FP32
EXPECT_FALSE(isValidWaveTile(DataType::FP32, 16, 16, 4, TargetSet::all()));
}
TEST(WaveTileValidation, TargetSetCdnaRejectsFP8BecauseGfx90a)
{
// cdna() includes gfx90a which has no FP8 -- intersection rejects all FP8 tiles
EXPECT_FALSE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 16, TargetSet::cdna()));
EXPECT_FALSE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 64, TargetSet::cdna()));
}
TEST(WaveTileValidation, TargetSetGfx94AcceptsFP8)
{
// family_gfx94() = gfx942 + gfx950 -- both support FP8
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 16, TargetSet::family_gfx94()));
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 16, 16, 32, TargetSet::family_gfx94()));
// IterateK compositions valid across gfx94 family
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 64, TargetSet::family_gfx94()));
}
TEST(WaveTileValidation, Gfx90aAcceptsSameTilesAsCDNABaseline)
{
// gfx90a has same MFMA tile set as the baseline (no FP8)
EXPECT_TRUE(isValidWaveTile(DataType::FP16, 16, 16, 16, GpuTarget::gfx90a));
EXPECT_TRUE(isValidWaveTile(DataType::FP16, 32, 32, 16, GpuTarget::gfx90a));
// gfx90a has no FP8 MFMA support
EXPECT_FALSE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 16, GpuTarget::gfx90a));
}
TEST(WaveTileValidation, BF8HasSameTilesAsFP8)
{
EXPECT_TRUE(isValidWaveTile(DataType::BF8_FNUZ, 32, 32, 16, GpuTarget::gfx942));
EXPECT_TRUE(isValidWaveTile(DataType::BF8_FNUZ, 32, 32, 32, GpuTarget::gfx950));
EXPECT_TRUE(isValidWaveTile(DataType::BF8_FNUZ, 32, 32, 32, GpuTarget::gfx942));
}
// ============================================================================
// makeSpec: GpuTarget parameter
// ============================================================================
TEST(MakeSpec, AcceptsGpuTargetParameter)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
GpuTarget::gfx942);
EXPECT_EQ(k.workgroup_size, 256);
}
TEST(MakeSpec, AcceptsTargetSetCdna)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.workgroup_size, 256);
}
// ============================================================================
// makeSpec: Pipeline::Memory + Scheduling
// ============================================================================
TEST(MakeSpec, AcceptsMemoryPipelineWithIntrawaveScheduling)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.pipeline = Pipeline::Memory,
.pipeline_scheduler = PipelineScheduler::Intrawave},
TargetSet::cdna());
EXPECT_EQ(k.pipeline, Pipeline::Memory);
EXPECT_EQ(k.pipeline_scheduler, PipelineScheduler::Intrawave);
}
TEST(MakeSpec, AcceptsMemoryPipelineWithInterwaveScheduling)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.pipeline = Pipeline::Memory,
.pipeline_scheduler = PipelineScheduler::Interwave},
TargetSet::cdna());
EXPECT_EQ(k.pipeline, Pipeline::Memory);
EXPECT_EQ(k.pipeline_scheduler, PipelineScheduler::Interwave);
}
TEST(MakeSpec, DefaultsSchedulingToIntrawave)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.pipeline_scheduler, PipelineScheduler::Intrawave);
}
// ============================================================================
// makeSpec: quantized GEMM (INT4 weight with scale tensor)
// ============================================================================
TEST(MakeSpec, PlainGemmHasGroupSizeZero)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.group_size, 0);
}
TEST(MakeSpec, QuantizedBAddsScaleTensorToPhysicalTable)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale",
.scale_dtype = DataType::FP32,
.group_size = 128}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_physical_tensors, 4); // A, B, C, scale
}
TEST(MakeSpec, ScaleTensorGetsCorrectSlotAndDtype)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale",
.scale_dtype = DataType::FP32,
.group_size = 128}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(slot(k, "scale"), 3);
EXPECT_EQ(dtype(k, "scale"), DataType::FP32);
EXPECT_EQ(layout(k, "scale"), Layout::Row);
}
TEST(MakeSpec, GroupSizeMatchesQuantizationConfig)
{
constexpr auto k = makeSpec(
Signature{
.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale", .group_size = 64}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.group_size, 64);
}
TEST(MakeSpec, ScaleAccessorReturnsScaleTensor)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale",
.scale_dtype = DataType::FP32}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.scale().dtype, DataType::FP32);
EXPECT_EQ(k.scale().args_slot, 3);
}
TEST(MakeSpec, QuantizedGemmWithEpiloguePutsScaleAfterD0)
{
constexpr auto k =
makeSpec(Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale"}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
// A, B, D(output), bias(D0), scale = 5
EXPECT_EQ(k.num_physical_tensors, 5);
EXPECT_EQ(slot(k, "bias"), 3); // D0
EXPECT_EQ(slot(k, "scale"), 4); // scale after D0
}
TEST(MakeSpec, RhsDtypeIsI4InQuantizedGemm)
{
constexpr auto k =
makeSpec(Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale"}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.rhs().dtype, DataType::I4);
EXPECT_EQ(k.lhs().dtype, DataType::FP16);
}
// ============================================================================
// makeSpec: numDTensors() derivation
// ============================================================================
TEST(MakeSpec, PlainGemmHasZeroDTensors)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.numDTensors(), 0);
}
TEST(MakeSpec, GemmAddHasOneDTensor)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.numDTensors(), 1);
}
TEST(MakeSpec, QuantizedGemmHasZeroDTensors)
{
constexpr auto k =
makeSpec(Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale"}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
// Scale is NOT a D tensor -- num_d_tensors excludes it
EXPECT_EQ(k.numDTensors(), 0);
EXPECT_EQ(k.num_physical_tensors, 4); // A, B, C, scale
}
TEST(MakeSpec, QuantizedGemmAddHasOneDTensor)
{
constexpr auto k =
makeSpec(Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale"}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
// bias is D0, scale is separate -- num_d_tensors counts only bias
EXPECT_EQ(k.numDTensors(), 1);
EXPECT_EQ(k.num_physical_tensors, 5); // A, B, D, bias, scale
}
// ============================================================================
// GemmAlgorithm padding flags
// ============================================================================
TEST(GemmAlgorithm, PaddingFlagsDefaultToFalse)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_FALSE(k.pad_m);
EXPECT_FALSE(k.pad_n);
}
TEST(GemmAlgorithm, PadMCanBeSetToTrue)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.pad_m = true},
TargetSet::cdna());
EXPECT_TRUE(k.pad_m);
EXPECT_FALSE(k.pad_n);
}
TEST(GemmAlgorithm, PadNCanBeSetToTrue)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.pad_n = true},
TargetSet::cdna());
EXPECT_FALSE(k.pad_m);
EXPECT_TRUE(k.pad_n);
}
TEST(GemmAlgorithm, BothPaddingFlagsCanBeEnabled)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.pad_m = true,
.pad_n = true},
TargetSet::cdna());
EXPECT_TRUE(k.pad_m);
EXPECT_TRUE(k.pad_n);
}
// ============================================================================
// Pipeline enum variants
// ============================================================================
TEST(MakeSpec, AcceptsPipelineV3)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.pipeline = Pipeline::V3},
TargetSet::cdna());
EXPECT_EQ(k.pipeline, Pipeline::V3);
}
TEST(MakeSpec, AcceptsPipelineV4)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.pipeline = Pipeline::V4},
TargetSet::cdna());
EXPECT_EQ(k.pipeline, Pipeline::V4);
}
TEST(MakeSpec, AcceptsPipelinePreshuffle)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.pipeline = Pipeline::Preshuffle},
TargetSet::cdna());
EXPECT_EQ(k.pipeline, Pipeline::Preshuffle);
}
// ============================================================================
// TilePartitioner enum variants
// ============================================================================
TEST(MakeSpec, AcceptsTilePartitionerDirect)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.tile_partitioner = TilePartitioner::Direct},
TargetSet::cdna());
EXPECT_EQ(k.tile_partitioner, TilePartitioner::Direct);
}
TEST(MakeSpec, AcceptsTilePartitionerStreamK)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.tile_partitioner = TilePartitioner::StreamK},
TargetSet::cdna());
EXPECT_EQ(k.tile_partitioner, TilePartitioner::StreamK);
}
// ============================================================================
// StoreStrategy enum variants
// ============================================================================
TEST(MakeSpec, AcceptsStoreStrategyDirect2D)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}}},
GemmAlgorithm{.block_tile = {128, 128, 32},
.block_waves = {2, 2, 1},
.wave_tile = {16, 16, 16},
.store_strategy = StoreStrategy::Direct2D},
TargetSet::cdna());
EXPECT_EQ(k.store_strategy, StoreStrategy::Direct2D);
}
// ============================================================================
// Explicit acc_dtype override
// ============================================================================
TEST(MakeSpec, ExplicitAccDtypeIsPreserved)
{
constexpr auto k = makeSpec(
Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C", .acc_dtype = DataType::FP16}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.acc_dtype, DataType::FP16);
}
// ============================================================================
// isValidWaveTile with unsupported dtypes
// ============================================================================
TEST(WaveTileValidation, RejectsI64)
{
EXPECT_FALSE(isValidWaveTile(DataType::I64, 16, 16, 16, TargetSet::cdna()));
EXPECT_FALSE(isValidWaveTile(DataType::I64, 32, 32, 16, TargetSet::cdna()));
}
TEST(WaveTileValidation, RejectsFP64)
{
EXPECT_FALSE(isValidWaveTile(DataType::FP64, 16, 16, 4, TargetSet::cdna()));
EXPECT_FALSE(isValidWaveTile(DataType::FP64, 32, 32, 8, TargetSet::cdna()));
}
// ============================================================================
// Quantized GEMM + multiple epilogue ops
// ============================================================================
TEST(MakeSpec, QuantizedGemmWithMultipleEpilogueOps)
{
constexpr auto k =
makeSpec(Signature{.dtype = DataType::FP16,
.tensors = {Tensor{.name = "B",
.dtype = DataType::I4,
.quantize = Quantization{.scale_name = "scale"}}},
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
ReluOp{.in = "D", .out = "E"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
// Physical tensors: A, B, E(output), bias(D0), scale
EXPECT_EQ(k.num_physical_tensors, 5);
EXPECT_EQ(slot(k, "A"), 0);
EXPECT_EQ(slot(k, "B"), 1);
EXPECT_EQ(slot(k, "E"), 2); // final output
EXPECT_EQ(slot(k, "bias"), 3); // D0
EXPECT_EQ(slot(k, "scale"), 4); // scale tensor
// Verify epilogue ops
EXPECT_EQ(k.num_epilogue_ops, 2);
EXPECT_TRUE(k.hasEpilogueOp(EpilogueOp::Add));
EXPECT_TRUE(k.hasEpilogueOp(EpilogueOp::Relu));
// Verify dtypes
EXPECT_EQ(dtype(k, "B"), DataType::I4);
EXPECT_EQ(dtype(k, "scale"), DataType::FP32);
EXPECT_EQ(dtype(k, "bias"), DataType::FP16);
}
// ============================================================================
// makeSpec: two consecutive AddOps (Add+Add -> 2 D tensors)
// ============================================================================
TEST(MakeSpec, TwoConsecutiveAddOpsProduceTwoDTensors)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias0", .out = "D"},
AddOp{.lhs = "D", .rhs = "bias1", .out = "E"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.numDTensors(), 2);
EXPECT_EQ(k.num_physical_tensors, 5); // A, B, E(output), bias0(D0), bias1(D1)
EXPECT_EQ(slot(k, "bias0"), 3); // D0
EXPECT_EQ(slot(k, "bias1"), 4); // D1
EXPECT_EQ(slot(k, "E"), 2); // final output
EXPECT_EQ(k.num_epilogue_ops, 2);
EXPECT_EQ(k.epilogue_ops[0], EpilogueOp::Add);
EXPECT_EQ(k.epilogue_ops[1], EpilogueOp::Add);
}
// ============================================================================
// makeSpec: multi-op epilogue within the kMaxEpilogueOps=4 limit (2 of 4 used)
// ============================================================================
TEST(MakeSpec, AcceptsMaxEpilogueOps)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
ReluOp{.in = "D", .out = "E"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
// 2 epilogue ops (Add + Relu) -- well under the limit of 4
EXPECT_EQ(k.num_epilogue_ops, 2);
EXPECT_TRUE(k.hasEpilogueOp(EpilogueOp::Add));
EXPECT_TRUE(k.hasEpilogueOp(EpilogueOp::Relu));
}
// ============================================================================
// Epilogue generalization: ordering, chaining, interleaving
// ============================================================================
TEST(MakeSpec, UnaryOnlyWithoutBinaryOp)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
ReluOp{.in = "C", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_epilogue_ops, 1);
EXPECT_EQ(k.epilogue_ops[0], EpilogueOp::Relu);
// A unary op adds no D tensor. Physical tensors: A, B, D(output)
EXPECT_EQ(k.num_physical_tensors, 3);
EXPECT_EQ(k.numDTensors(), 0);
}
TEST(MakeSpec, ChainedUnaryOps)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
ReluOp{.in = "C", .out = "D"},
SigmoidOp{.in = "D", .out = "E"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_epilogue_ops, 2);
EXPECT_EQ(k.epilogue_ops[0], EpilogueOp::Relu);
EXPECT_EQ(k.epilogue_ops[1], EpilogueOp::Sigmoid);
// Intermediates C and D get no physical slot. Physical tensors: A, B, E(output)
EXPECT_EQ(k.num_physical_tensors, 3);
EXPECT_EQ(slot(k, "E"), 2);
}
TEST(MakeSpec, UnaryBeforeBinaryOp)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
ReluOp{.in = "C", .out = "D"},
AddOp{.lhs = "D", .rhs = "bias", .out = "E"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_epilogue_ops, 2);
EXPECT_EQ(k.epilogue_ops[0], EpilogueOp::Relu);
EXPECT_EQ(k.epilogue_ops[1], EpilogueOp::Add);
// Physical tensors: A, B, E(output), bias(D0)
EXPECT_EQ(k.num_physical_tensors, 4);
EXPECT_EQ(slot(k, "E"), 2);
EXPECT_EQ(slot(k, "bias"), 3);
}
TEST(MakeSpec, InterleavedBinaryUnaryBinary)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
AddOp{.lhs = "C", .rhs = "bias", .out = "D"},
ReluOp{.in = "D", .out = "E"},
MulOp{.lhs = "E", .rhs = "scale", .out = "F"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_epilogue_ops, 3);
EXPECT_EQ(k.epilogue_ops[0], EpilogueOp::Add);
EXPECT_EQ(k.epilogue_ops[1], EpilogueOp::Relu);
EXPECT_EQ(k.epilogue_ops[2], EpilogueOp::Mul);
// D tensors: bias(D0), scale(D1); F is the final output
EXPECT_EQ(k.numDTensors(), 2);
EXPECT_EQ(slot(k, "bias"), 3);
EXPECT_EQ(slot(k, "scale"), 4);
EXPECT_EQ(slot(k, "F"), 2);
}
TEST(MakeSpec, MulOpOnly)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
MulOp{.lhs = "C", .rhs = "scale", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_epilogue_ops, 1);
EXPECT_EQ(k.epilogue_ops[0], EpilogueOp::Mul);
EXPECT_EQ(k.numDTensors(), 1);
EXPECT_EQ(slot(k, "scale"), 3); // D0
}
TEST(MakeSpec, AllActivationVariants)
{
constexpr auto gelu = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
GeluOp{.in = "C", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(gelu.epilogue_ops[0], EpilogueOp::Gelu);
constexpr auto fast_gelu =
makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
FastGeluOp{.in = "C", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(fast_gelu.epilogue_ops[0], EpilogueOp::FastGelu);
constexpr auto silu = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
SiluOp{.in = "C", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(silu.epilogue_ops[0], EpilogueOp::Silu);
constexpr auto sigmoid = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
SigmoidOp{.in = "C", .out = "D"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(sigmoid.epilogue_ops[0], EpilogueOp::Sigmoid);
}
TEST(MakeSpec, EpilogueOpsPreserveInsertionOrder)
{
constexpr auto k = makeSpec(Signature{.dtype = DataType::FP16,
.ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"},
SigmoidOp{.in = "C", .out = "D"},
AddOp{.lhs = "D", .rhs = "bias", .out = "E"},
FastGeluOp{.in = "E", .out = "F"}}},
GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}},
TargetSet::cdna());
EXPECT_EQ(k.num_epilogue_ops, 3);
EXPECT_EQ(k.epilogue_ops[0], EpilogueOp::Sigmoid);
EXPECT_EQ(k.epilogue_ops[1], EpilogueOp::Add);
EXPECT_EQ(k.epilogue_ops[2], EpilogueOp::FastGelu);
}