composable_kernel/dispatcher/tests/test_tile_math.py
Ville Pietilä 60b276647b [rocm-libraries] ROCm/rocm-libraries#8157 (commit b0d9d39)
[CK Tile] Rule-based configuration generation in CK
 Dispatcher codegen (#8157)

## Motivation

The CK Tile Dispatcher code generation for CK Tile Profiler relies on
flat JSON files to list the generated configurations. This approach has
the following problems:

- The JSON files are verbose.
- The JSON files easily get out of sync with the CK Builder .config
files from which they were generated.
- The JSON-file-based configuration makes it hard to state explicitly
the rules that govern the instance generation.

## Technical Details

Replaced the JSON files with a rule based configuration. To preserve the
existing functionality, the `profiler` and the `tests` instance sets are
generated directly from the CK Builder config files. The JSON config
files are removed from source control, and the "on-the-fly" generation
guarantees that the Dispatcher codegen uses up to date configurations.

This PR introduces six rule sets for the CK Tile Dispatcher
code generation:

1. `profiler`: matches with the old JSON set of profiler configurations.
2. `tests`: matches with the old JSON set of tests configurations.
3. `full`: full configuration set created from a rule-based config
selection
4. `full-tests`: a subset of `full` for generating configurations for
convolution integration tests.
5. `tiny`: a subset of `full-tests` to produce the minimal set of
configurations to test the Dispatcher codegen.
6. `default`: the default rules, which correspond to the existing
heuristic rules for configuration selection. This ensures that ML-based
kernel selection doesn't get broken.

The main use of the `full` rule set is to define a reasonable solution
space for the possible implicit GEMM configurations. We start from the
configurations that are allowed by the device architecture. The `full`
rule set defines the relevant tile sizes for each convolution direction.
From the tile size we have a curated mapping to the number of waves over
the different GEMM axes, i.e., we describe how many waves each GEMM
dimension corresponds to. The GEMM-K wave tile dimension can be
computed from the other parameters and does not need to be listed
explicitly.
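
The selection described above can be illustrated with a minimal sketch.
It is not the code in `tile_math.py`: the wave and warp-tile lists are
made-up placeholders, and the divisibility rule (the block tile must be
covered by whole warp tiles along every wave axis) is an assumption
chosen for illustration.

```python
from itertools import product
from typing import List, Tuple

Triple = Tuple[int, int, int]

# Hypothetical inputs: illustrative values only, not the real architecture tables.
SUPPORTED_WAVES: List[Triple] = [
    (1, 1, 1), (1, 2, 1), (2, 1, 1), (2, 2, 1), (4, 1, 1), (1, 4, 1),
]
WARP_TILES: List[Triple] = [(16, 16, 16), (32, 32, 8)]


def candidate_wave_warp_pairs(
    tile_m: int, tile_n: int, tile_k: int
) -> List[Tuple[Triple, Triple]]:
    """Enumerate (wave, warp_tile) pairs that tile the block evenly.

    Assumed rule: along each GEMM axis the block tile must be a multiple
    of (number of waves) * (warp tile size).
    """
    pairs = []
    for wave, warp in product(SUPPORTED_WAVES, WARP_TILES):
        wave_m, wave_n, wave_k = wave
        warp_m, warp_n, warp_k = warp
        if tile_m % (wave_m * warp_m) != 0:
            continue
        if tile_n % (wave_n * warp_n) != 0:
            continue
        if tile_k % (wave_k * warp_k) != 0:
            continue
        pairs.append((wave, warp))
    return pairs


pairs = candidate_wave_warp_pairs(128, 64, 32)
```

With these placeholder tables, a 128x64x32 tile keeps the pair
`((2, 2, 1), (32, 32, 8))` but rejects `((1, 4, 1), (32, 32, 8))`,
because four waves of 32 columns would need a tile that is 128 wide.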

An orthogonal axis to the tiling strategy is the vectorization strategy.
This is mainly defined by the data type and the hardware since, in
general, we want to use the maximum possible load widths. The maximum
sizes for each convolution direction variant are defined by the implicit
GEMM matrix dimensions. For cases where we have a low number of channels
per convolution group, we need smaller vector load sizes. These are
captured by the `VecStrategy` enumeration in the codegen rules.
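
The arithmetic behind a vector-size check can be sketched as follows.
This mirrors the constraints exercised by the unit tests in
`test_tile_math.py` (per-thread element budget, power-of-two load width
of at most 256 bits) and is not the `VecStrategy` implementation. The
helper names are hypothetical, and `WARP_SIZE = 64` is an assumption
inferred from the test comments.

```python
from typing import Optional, Tuple

WARP_SIZE = 64  # assumption: 64-lane wavefront


def per_thread_elements(
    tile_rows: int, tile_k: int, wave: Tuple[int, int, int]
) -> Optional[int]:
    """Elements of a (tile_rows x tile_k) operand tile that each thread loads.

    Returns None when the tile cannot be split evenly over the thread block.
    """
    wave_m, wave_n, wave_k = wave
    block_size = WARP_SIZE * wave_m * wave_n * wave_k
    total = tile_rows * tile_k
    if total % block_size != 0:
        return None
    return total // block_size


def vec_fits_budget(vec: int, budget: int) -> bool:
    """A vector fits if it divides the budget or is a whole multiple of it."""
    return budget % vec == 0 or vec % budget == 0


def load_width_ok(vec: int, elem_bytes: int, max_bits: int = 256) -> bool:
    """The load width in bits must be a power of two no larger than max_bits."""
    bits = vec * elem_bytes * 8
    return 0 < bits <= max_bits and bits & (bits - 1) == 0


# A (16 x 64) operand tile with wave layout (1, 4, 1): 1024 / 256 threads.
budget = per_thread_elements(16, 64, (1, 4, 1))
```

In this example each thread owns four elements, so a vector size of 8 is
still accepted under the "whole multiple" rule, while a vector size of 3
is not.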

The problem with the rule based configuration selection is that we "over
generate" configurations. The old JSON configurations make up
approximately 25% of all configurations that the `full` rule set creates.
The additional configurations are valid, but they may not provide any
performance benefits. Hence, we keep the `profiler` and `tests` rule sets
for now to avoid building an excessive number of configurations by
default. The `full` rule set can be enabled by specifying the CMake
configuration flag `-D DISPATCHER_RULE_SET=full`. By default, the
`tests` rule set is used, i.e., we don't change the existing behaviour.
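
The relation between a curated reference set and a rule-generated set
can be expressed as a small check. The sketch below uses toy data, not
the real configuration counts, and the function name is hypothetical. It
reports which reference entries the rules fail to generate (these would
be false negatives) and what share of the generated set the reference
entries make up.

```python
from typing import Hashable, Set, Tuple


def compare_sets(
    reference: Set[Hashable], generated: Set[Hashable]
) -> Tuple[Set[Hashable], float]:
    """Return (reference entries missing from generated, reference share)."""
    missing = reference - generated
    share = len(reference & generated) / len(generated) if generated else 0.0
    return missing, share


# Toy data: two curated tiles inside a generated set of eight tiles.
reference = {(128, 64, 32), (64, 64, 32)}
generated = reference | {
    (16, 16, 32), (32, 32, 32), (64, 32, 32),
    (128, 32, 32), (128, 128, 32), (256, 128, 32),
}
missing, share = compare_sets(reference, generated)
```

An empty `missing` set means the rules over-generate without dropping
any curated configuration; here the curated entries are a quarter of the
generated set.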

## Test Plan

Added a new stage in the CI/CD pipeline that ensures the Dispatcher
codegen rules are up to date. Otherwise the functionality is covered by
the existing CI/CD tests. There are no functional changes to the
convolution kernels; only the way the different instances are generated
changes.

## Test Result

If the CK Tile conv instances build without errors, the Dispatcher
codegen is generating valid code. If all tests in the CI/CD pipeline
pass, the Dispatcher codegen generates valid instances.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-18 01:22:50 +00:00

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Unit tests for codegen/tile_math.py.

The reference lookup tables TILE_TO_WAVE_WARP and _TILE_WTILK_TO_VECS are
used as ground-truth oracles. The mathematical functions in tile_math.py
must generate at least the set of configurations present in those tables
(no false negatives).

Run:
    python3 -m pytest dispatcher/tests/test_tile_math.py -v
or:
    cd projects/composablekernel/dispatcher
    python3 -m pytest tests/test_tile_math.py -v
"""

import sys
import unittest
from pathlib import Path
from typing import Dict, List, Tuple

SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))

from arch_specs_generated import WARP_SUPPORTED_COMBINATIONS, ELEMENT_SIZE_MAP  # noqa: E402
from grouped_conv.tile_math import (  # noqa: E402
    get_valid_wave_warp_pairs,
    get_valid_vec_sizes,
    get_vec_sizes_for_wave_warp,
    dtype_keys_for_warp_tile_k,
    WARP_SIZE,
)

# =============================================================================
# Reference oracle tables (ground truth extracted from CK profiler configs)
#
# They live here as static test fixtures: the math functions in tile_math.py must generate
# at least every entry present in these tables (zero false negatives).
# =============================================================================
TILE_TO_WAVE_WARP: Dict[Tuple[int, int, int], List[Tuple[Tuple, Tuple]]] = {
(16, 16, 32): [((1, 1, 1), (16, 16, 8)), ((1, 1, 1), (16, 16, 16))],
(16, 16, 64): [((1, 1, 1), (16, 16, 8)), ((1, 1, 1), (16, 16, 16))],
(16, 16, 128): [((1, 1, 1), (16, 16, 8)), ((1, 1, 1), (16, 16, 16))],
(16, 32, 64): [((1, 2, 1), (16, 16, 8)), ((1, 2, 1), (16, 16, 16))],
(16, 64, 32): [((1, 1, 1), (16, 16, 8)), ((1, 1, 1), (16, 16, 16))],
(16, 64, 64): [((1, 2, 1), (16, 16, 8)), ((1, 2, 1), (16, 16, 16))],
(16, 128, 32): [((1, 1, 1), (16, 16, 16))],
(16, 128, 64): [((1, 2, 1), (16, 16, 8)), ((1, 2, 1), (16, 16, 16))],
(16, 256, 32): [((1, 1, 1), (16, 16, 16))],
(16, 256, 64): [((1, 4, 1), (16, 16, 16))],
(32, 16, 64): [((2, 1, 1), (16, 16, 8)), ((2, 1, 1), (16, 16, 16))],
(32, 32, 32): [((1, 1, 1), (32, 32, 8))],
(32, 64, 16): [((1, 1, 1), (32, 32, 4))],
(32, 64, 32): [((1, 1, 1), (32, 32, 8))],
(32, 64, 64): [((1, 2, 1), (32, 32, 8))],
(32, 128, 16): [((1, 2, 1), (32, 32, 4))],
(32, 128, 32): [((1, 1, 1), (32, 32, 8)), ((1, 2, 1), (32, 32, 8))],
(32, 128, 64): [((1, 2, 1), (32, 32, 8))],
(32, 256, 64): [((1, 4, 1), (32, 32, 8))],
(64, 16, 16): [((1, 1, 1), (16, 16, 4)), ((1, 1, 1), (16, 16, 16)),
((4, 1, 1), (16, 16, 4)), ((4, 1, 1), (16, 16, 16))],
(64, 16, 32): [((1, 1, 1), (16, 16, 8)), ((1, 1, 1), (16, 16, 16)),
((4, 1, 1), (16, 16, 8)), ((4, 1, 1), (16, 16, 16))],
(64, 16, 64): [((2, 1, 1), (16, 16, 8)), ((2, 1, 1), (16, 16, 16)),
((4, 1, 1), (16, 16, 16))],
(64, 32, 16): [((1, 1, 1), (32, 32, 4))],
(64, 32, 32): [((1, 1, 1), (32, 32, 8))],
(64, 32, 64): [((2, 1, 1), (32, 32, 8))],
(64, 64, 8): [((2, 1, 1), (32, 32, 8))],
(64, 64, 16): [((1, 1, 1), (32, 32, 4))],
(64, 64, 32): [((1, 1, 1), (32, 32, 8)), ((2, 2, 1), (16, 16, 8)),
((2, 2, 1), (16, 16, 16))],
(64, 64, 64): [((2, 2, 1), (16, 16, 16)), ((2, 2, 1), (32, 32, 8))],
(64, 128, 16): [((1, 2, 1), (32, 32, 4)), ((2, 2, 1), (32, 32, 4))],
(64, 128, 32): [((1, 2, 1), (32, 32, 8)), ((2, 2, 1), (32, 32, 8))],
(64, 128, 64): [((2, 2, 1), (32, 32, 8))],
(128, 16, 64): [((2, 1, 1), (16, 16, 8)), ((2, 1, 1), (16, 16, 16))],
(128, 32, 16): [((2, 1, 1), (32, 32, 4)), ((4, 1, 1), (32, 32, 4)),
((4, 1, 1), (32, 32, 8))],
(128, 32, 32): [((1, 1, 1), (32, 32, 8)), ((2, 1, 1), (32, 32, 8)),
((2, 1, 2), (32, 32, 8)), ((4, 1, 1), (32, 32, 8))],
(128, 32, 64): [((2, 1, 1), (32, 32, 8)), ((4, 1, 1), (32, 32, 16))],
(128, 64, 8): [((2, 1, 1), (32, 32, 8)), ((2, 2, 1), (32, 32, 8))],
(128, 64, 16): [((2, 1, 1), (32, 32, 4)), ((2, 2, 1), (32, 32, 4)),
((2, 2, 1), (32, 32, 8))],
(128, 64, 32): [((2, 1, 1), (32, 32, 8)), ((2, 2, 1), (32, 32, 8))],
(128, 64, 64): [((2, 2, 1), (32, 32, 8))],
(128, 128, 16):[((1, 2, 1), (32, 32, 4)), ((2, 2, 1), (32, 32, 4))],
(128, 128, 32):[((1, 2, 1), (32, 32, 8)), ((2, 2, 1), (32, 32, 8))],
(128, 128, 64):[((2, 2, 1), (32, 32, 8))],
(128, 192, 16):[((2, 2, 1), (32, 32, 4))],
(128, 256, 16):[((2, 2, 1), (32, 32, 4))],
(128, 256, 32):[((2, 2, 1), (32, 32, 8))],
(224, 256, 64):[((2, 2, 1), (16, 16, 16))],
(256, 16, 64): [((4, 1, 1), (16, 16, 16))],
(256, 32, 64): [((4, 1, 1), (32, 32, 8))],
(256, 64, 8): [((2, 2, 1), (32, 32, 8))],
(256, 128, 16):[((2, 2, 1), (32, 32, 4))],
(256, 128, 32):[((2, 2, 1), (32, 32, 8))],
(256, 224, 64):[((2, 2, 1), (16, 16, 16))],
(256, 256, 32):[((2, 2, 1), (16, 16, 16)), ((2, 2, 1), (32, 32, 8))],
}
_TILE_WTILK_TO_VECS: Dict[Tuple[int, int, int, int], List[Tuple[int, int, int]]] = {
(16, 16, 32, 8): [(1, 1, 2)],
(16, 16, 32, 16): [(1, 1, 1), (1, 1, 2)],
(16, 16, 64, 8): [(4, 4, 4)],
(16, 16, 64, 16): [(1, 1, 1), (1, 4, 4), (4, 1, 1), (4, 4, 4), (8, 8, 4)],
(16, 16, 128, 8): [(4, 4, 4)],
(16, 16, 128, 16): [(8, 8, 4)],
(16, 32, 64, 8): [(4, 4, 4)],
(16, 32, 64, 16): [(1, 1, 1), (1, 2, 4), (1, 4, 4), (2, 1, 1), (2, 2, 4), (2, 4, 4), (8, 8, 4)],
(16, 64, 32, 8): [(1, 4, 4), (4, 1, 1), (4, 4, 4)],
(16, 64, 32, 16): [(1, 8, 4), (8, 1, 1), (8, 8, 4)],
(16, 64, 64, 8): [(4, 4, 4)],
(16, 64, 64, 16): [(1, 1, 1), (1, 8, 4), (2, 1, 1), (2, 8, 4), (8, 8, 4)],
(16, 128, 32, 16): [(4, 4, 1)],
(16, 128, 64, 8): [(4, 4, 4)],
(16, 128, 64, 16): [(1, 8, 4), (2, 8, 4), (8, 8, 4)],
(16, 256, 32, 16): [(8, 8, 1)],
(16, 256, 64, 16): [(1, 8, 4), (2, 8, 4), (8, 8, 4)],
(32, 16, 64, 8): [(4, 4, 2)],
(32, 16, 64, 16): [(1, 1, 1), (1, 2, 2), (2, 1, 1), (2, 2, 2), (4, 1, 1), (4, 2, 2), (8, 8, 2)],
(32, 32, 32, 8): [(2, 2, 1), (2, 2, 2)],
(32, 64, 16, 4): [(1, 4, 4), (4, 4, 4)],
(32, 64, 32, 8): [(1, 1, 8), (2, 2, 1), (2, 8, 8), (4, 4, 1), (4, 4, 2), (4, 4, 4), (8, 8, 8)],
(32, 64, 64, 8): [(4, 4, 8), (8, 8, 8)],
(32, 128, 16, 4): [(4, 4, 4)],
(32, 128, 32, 8): [(1, 1, 8), (4, 4, 4), (8, 8, 1), (8, 8, 2), (8, 8, 8)],
(32, 128, 64, 8): [(4, 4, 8), (8, 8, 8)],
(32, 256, 64, 8): [(8, 8, 8)],
(64, 16, 16, 4): [(4, 1, 1)],
(64, 16, 16, 16): [(4, 1, 1)],
(64, 16, 32, 8): [(1, 4, 4), (4, 1, 1), (4, 4, 4), (8, 1, 1), (8, 2, 2)],
(64, 16, 32, 16): [(1, 8, 4), (8, 1, 1), (8, 2, 2), (8, 8, 4)],
(64, 16, 64, 8): [(4, 4, 2)],
(64, 16, 64, 16): [(1, 1, 1), (1, 2, 2), (8, 1, 1), (8, 2, 2), (8, 8, 2), (16, 1, 1), (16, 2, 2)],
(64, 32, 16, 4): [(4, 4, 1), (4, 4, 4)],
(64, 32, 32, 8): [(1, 1, 8), (4, 4, 1), (4, 4, 2), (4, 4, 4), (8, 8, 1), (8, 8, 8)],
(64, 32, 64, 8): [(4, 4, 4), (8, 8, 4)],
(64, 64, 8, 8): [(1, 1, 8)],
(64, 64, 16, 4): [(1, 1, 1), (4, 4, 4)],
(64, 64, 32, 8): [(1, 1, 1), (1, 1, 8), (1, 2, 1), (2, 1, 2), (2, 2, 2), (4, 4, 4), (8, 8, 8)],
(64, 64, 32, 16): [(1, 2, 1), (2, 1, 2), (4, 4, 4), (8, 8, 8)],
(64, 64, 64, 8): [(1, 1, 1), (1, 4, 4), (2, 2, 2), (4, 1, 1), (4, 4, 4), (8, 8, 4), (8, 8, 8)],
(64, 64, 64, 16): [(2, 2, 4), (4, 1, 1), (8, 8, 2), (8, 8, 8)],
(64, 128, 16, 4): [(4, 4, 4)],
(64, 128, 32, 8): [(1, 1, 8), (1, 4, 4), (1, 8, 8), (4, 4, 4), (8, 8, 8)],
(64, 128, 64, 8): [(8, 8, 8)],
(128, 16, 64, 8): [(4, 4, 2)],
(128, 16, 64, 16): [(8, 1, 1), (8, 2, 2), (8, 8, 2)],
(128, 32, 16, 4): [(4, 1, 1), (4, 2, 2), (4, 4, 4)],
(128, 32, 16, 8): [(4, 1, 1), (4, 2, 2)],
(128, 32, 32, 8): [(1, 1, 8), (4, 1, 1), (4, 4, 4), (8, 1, 1), (8, 2, 2), (8, 8, 1), (8, 8, 2), (8, 8, 8)],
(128, 32, 64, 8): [(4, 4, 4), (8, 8, 4)],
(128, 32, 64, 16): [(16, 1, 1), (16, 2, 2), (16, 8, 8)],
(128, 64, 8, 8): [(1, 1, 8)],
(128, 64, 16, 4): [(4, 4, 4)],
(128, 64, 16, 8): [(1, 1, 8)],
(128, 64, 32, 8): [(1, 1, 8), (4, 4, 4), (8, 8, 8)],
(128, 64, 64, 8): [(8, 8, 8)],
(128, 128, 16, 4): [(1, 1, 4), (4, 4, 4)],
(128, 128, 32, 8): [(1, 1, 8), (4, 4, 4), (4, 4, 8), (8, 8, 8)],
(128, 128, 64, 8): [(4, 4, 4), (4, 4, 8), (8, 8, 4), (8, 8, 8)],
(128, 192, 16, 4): [(4, 4, 4)],
(128, 256, 16, 4): [(4, 4, 4)],
(128, 256, 32, 8): [(1, 1, 8), (4, 4, 4), (8, 4, 8), (8, 8, 8)],
(224, 256, 64, 16):[(8, 8, 8)],
(256, 16, 64, 16): [(8, 1, 1), (8, 2, 2), (8, 8, 2)],
(256, 32, 64, 8): [(8, 8, 4), (8, 8, 8)],
(256, 64, 8, 8): [(1, 1, 8)],
(256, 128, 16, 4): [(4, 4, 4)],
(256, 128, 32, 8): [(1, 1, 8), (2, 2, 2), (4, 4, 4), (8, 8, 8)],
(256, 224, 64, 16):[(8, 8, 8)],
(256, 256, 32, 8): [(4, 4, 4), (8, 8, 4), (8, 8, 8)],
(256, 256, 32, 16):[(8, 8, 8)],
}
# =============================================================================
# TestGetValidWaveWarpPairs
# =============================================================================
class TestGetValidWaveWarpPairs(unittest.TestCase):
    """Tests for get_valid_wave_warp_pairs()."""

    # --- Spot checks ---

    def test_spot_check_128_64_32(self):
        """Reference pairs for (128,64,32) must all be in the math result."""
        tile = (128, 64, 32)
        ref_pairs = TILE_TO_WAVE_WARP[tile]
        math_pairs = set(get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32"))
        for pair in ref_pairs:
            self.assertIn(
                pair, math_pairs,
                f"Missing reference pair {pair} for tile {tile}",
            )

    def test_spot_check_64_64_32(self):
        """Reference pairs for (64,64,32) must all be in the union of dtype results."""
        tile = (64, 64, 32)
        ref_pairs = TILE_TO_WAVE_WARP[tile]
        # Union across dtype_keys: the reference table is dtype-agnostic
        math_pairs = set(get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32"))
        math_pairs |= set(get_valid_wave_warp_pairs(*tile, "fp32_fp32_fp32"))
        math_pairs |= set(get_valid_wave_warp_pairs(*tile, "fp16_fp16_fp32"))
        for pair in ref_pairs:
            self.assertIn(
                pair, math_pairs,
                f"Missing reference pair {pair} for tile {tile}",
            )

    def test_spot_check_128_32_32_fp32(self):
        """Reference pairs for (128,32,32) include fp32 warp tiles."""
        tile = (128, 32, 32)
        ref_pairs = TILE_TO_WAVE_WARP[tile]
        # fp32 uses warp_tile_k in {4, 8}; bf16 uses {8, 16}
        # Try both dtype_keys and require the union covers all reference pairs
        math_pairs = set(get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32"))
        math_pairs |= set(get_valid_wave_warp_pairs(*tile, "fp32_fp32_fp32"))
        for pair in ref_pairs:
            self.assertIn(
                pair, math_pairs,
                f"Missing reference pair {pair} for tile {tile}",
            )

    # --- Full coverage of reference table ---

    def test_coverage_all_reference_tiles(self):
        """Every (tile, pair) in TILE_TO_WAVE_WARP must be generated by the math.

        The math is called with the bf16_bf16_fp32, fp32_fp32_fp32 and
        fp16_fp16_fp32 dtype_keys, and the union must contain every reference
        pair. This is the zero-false-negatives requirement.
        """
        DTYPE_KEYS = ["bf16_bf16_fp32", "fp32_fp32_fp32", "fp16_fp16_fp32"]
        missing = []
        for tile, ref_pairs in TILE_TO_WAVE_WARP.items():
            # Union across dtype_keys (reference table is dtype-agnostic)
            math_pairs: set = set()
            for dk in DTYPE_KEYS:
                math_pairs |= set(get_valid_wave_warp_pairs(*tile, dk))
            for pair in ref_pairs:
                if pair not in math_pairs:
                    missing.append((tile, pair))
        if missing:
            lines = [f" tile={t} pair={p}" for t, p in missing]
            self.fail(
                f"{len(missing)} reference pairs not generated by math:\n"
                + "\n".join(lines)
            )

    # --- Structural validity of math output ---

    def test_no_invalid_wave_combos(self):
        """Every wave combo returned must be in WARP_SUPPORTED_COMBINATIONS for gfx942."""
        valid_waves = {tuple(c) for c in WARP_SUPPORTED_COMBINATIONS["gfx942"]}
        tile = (128, 64, 32)
        for wave, _warp in get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32"):
            self.assertIn(
                wave, valid_waves,
                f"Wave combo {wave} not in gfx942 supported combinations",
            )

    def test_no_invalid_wave_combos_all_tiles(self):
        """All tiles in reference table: every returned wave must be arch-valid."""
        valid_waves = {tuple(c) for c in WARP_SUPPORTED_COMBINATIONS["gfx942"]}
        for tile in TILE_TO_WAVE_WARP:
            for wave, _warp in get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32"):
                self.assertIn(
                    wave, valid_waves,
                    f"tile={tile}: wave {wave} not in gfx942 supported combinations",
                )

    def test_tile_divisibility_enforced(self):
        """A tile not divisible by any warp_tile_m must return no pairs."""
        # (13, 17, 32): prime M and N dimensions that no MFMA shape divides
        pairs = get_valid_wave_warp_pairs(13, 17, 32, "bf16_bf16_fp32")
        self.assertEqual(pairs, [], "Expected no pairs for prime-dimension tile")

    def test_warp_tile_divides_tile_m_and_n(self):
        """For every returned pair, warp_tile_m | tile_m and warp_tile_n | tile_n."""
        for tile in TILE_TO_WAVE_WARP:
            tm, tn, tk = tile
            for _wave, (wt_m, wt_n, wt_k) in get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32"):
                self.assertEqual(tm % wt_m, 0, f"tile_m={tm} not divisible by warp_m={wt_m}")
                self.assertEqual(tn % wt_n, 0, f"tile_n={tn} not divisible by warp_n={wt_n}")

    # --- Special cases ---

    def test_wave_k2_special_case(self):
        """(128,32,32) must produce the ((2,1,2), (32,32,8)) pair (wave_k=2)."""
        tile = (128, 32, 32)
        math_pairs = set(get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32"))
        expected = ((2, 1, 2), (32, 32, 8))
        self.assertIn(
            expected, math_pairs,
            f"Expected wave_k=2 pair {expected} not found for tile {tile}",
        )

    def test_pipeline_async_constraint(self):
        """pipeline='basic_async_v1' must only return pairs with wave_n==2 and warp_tile_n==16."""
        tile = (64, 64, 32)
        pairs = get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32", pipeline="basic_async_v1")
        for wave, (wt_m, wt_n, wt_k) in pairs:
            _wm, wn, _wk = wave
            self.assertEqual(wn, 2, f"async pipeline: expected wave_n=2, got {wn}")
            self.assertEqual(wt_n, 16, f"async pipeline: expected warp_tile_n=16, got {wt_n}")

    def test_no_duplicates(self):
        """The returned list must not contain duplicate pairs."""
        for tile in list(TILE_TO_WAVE_WARP.keys()):
            pairs = get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32")
            self.assertEqual(
                len(pairs), len(set(pairs)),
                f"tile={tile}: duplicate pairs in result",
            )

    def test_gfx950_returns_superset_of_gfx942(self):
        """gfx950 supports more wave combos, so its pairs must be a superset of the gfx942 pairs."""
        tile = (128, 64, 32)
        pairs_942 = set(get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32", arch="gfx942"))
        pairs_950 = set(get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32", arch="gfx950"))
        self.assertTrue(
            pairs_942.issubset(pairs_950),
            "gfx942 pairs should be a subset of gfx950 pairs",
        )


# =============================================================================
# TestGetValidVecSizes
# =============================================================================
class TestGetValidVecSizes(unittest.TestCase):
    """Tests for get_valid_vec_sizes()."""

    # --- Spot checks ---

    def test_spot_check_128_64_32_wave221_wt_32328(self):
        """Reference vecs for (128,64,32) wave=(2,2,1) warp=(32,32,8) bf16."""
        vecs = get_valid_vec_sizes(128, 64, 32, 2, 2, 1, 32, 32, 8, "bf16_bf16_fp32")
        # Reference: [(1,1,8), (4,4,4), (8,8,8)] from _TILE_WTILK_TO_VECS[(128,64,32,8)]
        ref = _TILE_WTILK_TO_VECS[(128, 64, 32, 8)]
        math_set = set(vecs)
        for v in ref:
            self.assertIn(v, math_set, f"Reference vec {v} missing for (128,64,32) wave=(2,2,1)")

    def test_spot_check_64_64_32_wave111_wt_32328(self):
        """Reference vecs for (64,64,32) wave=(1,1,1) warp=(32,32,8) bf16."""
        vecs = get_valid_vec_sizes(64, 64, 32, 1, 1, 1, 32, 32, 8, "bf16_bf16_fp32")
        ref = _TILE_WTILK_TO_VECS[(64, 64, 32, 8)]
        math_set = set(vecs)
        for v in ref:
            self.assertIn(v, math_set, f"Reference vec {v} missing for (64,64,32) wave=(1,1,1)")

    # --- Structural validity ---

    def test_pixel_budget_respected(self):
        """vec_a/vec_b must be compatible with their pixel budget.

        Compatible means the vector divides the per-thread pixel budget OR is an
        exact multiple of it (the wide load decomposes into v/pixels sub-loads).
        """
        configs = [
            (128, 64, 32, 2, 2, 1, 32, 32, 8, "bf16_bf16_fp32"),
            (64, 64, 32, 1, 1, 1, 32, 32, 8, "bf16_bf16_fp32"),
            (256, 128, 32, 2, 2, 1, 32, 32, 8, "bf16_bf16_fp32"),
        ]
        for cfg in configs:
            tm, tn, tk, wm, wn, wk, wt_m, wt_n, wt_k, dk = cfg
            block_size = WARP_SIZE * wm * wn * wk
            pixels_a = tm * tk // block_size
            pixels_b = tn * tk // block_size
            for va, vb, vc in get_valid_vec_sizes(*cfg):
                self.assertTrue(
                    pixels_a % va == 0 or va % pixels_a == 0,
                    f"cfg={cfg}: pixels_a={pixels_a} incompatible with vec_a={va}",
                )
                self.assertTrue(
                    pixels_b % vb == 0 or vb % pixels_b == 0,
                    f"cfg={cfg}: pixels_b={pixels_b} incompatible with vec_b={vb}",
                )

    def test_fp32_vec_c_8_admitted(self):
        """fp32 tiles admit vec_c=8 (relaxed 32-byte ceiling). Category (A)."""
        vecs = get_valid_vec_sizes(128, 128, 32, 2, 2, 1, 32, 32, 8, "fp32_fp32_fp32")
        self.assertIn(
            (4, 4, 8), set(vecs),
            "fp32 (128,128,32) should admit vec triple (4,4,8)",
        )

    def test_bf16_asymmetric_small_pixel_vec8(self):
        """Asymmetric small-pixel tile admits a vec=8 triple. Category (B)."""
        # (16,256,64) wave=(1,4,1): pixels_a = 16*64/256 = 4 < 8, but 8 % 4 == 0.
        vecs = get_valid_vec_sizes(16, 256, 64, 1, 4, 1, 16, 16, 16, "bf16_bf16_fp32")
        self.assertTrue(
            any(8 in (va, vb) for va, vb, vc in vecs),
            "bf16 (16,256,64) wave=(1,4,1) should admit a vec=8 triple",
        )

    def test_bf16_vec_c_capped_at_16(self):
        """bf16 vec_c never exceeds 16 (LDS 256-bit ceiling still enforced)."""
        for va, vb, vc in get_valid_vec_sizes(128, 128, 32, 2, 2, 1, 32, 32, 8, "bf16_bf16_fp32"):
            self.assertLessEqual(vc, 16, f"bf16: vec_c={vc} > 16")

    def test_lds_validity_bf16(self):
        """All returned vecs must satisfy the power-of-2 bit-width constraint for bf16.

        Standard LDS allows 8-128 bits; some bwd_data global loads allow 256 bits.
        The constraint is: bits = vec * sizeof * 8 must be a power of 2 up to 256.
        """
        sizeof_bf16 = 2.0
        for va, vb, vc in get_valid_vec_sizes(128, 64, 32, 2, 2, 1, 32, 32, 8, "bf16_bf16_fp32"):
            for v in (va, vb, vc):
                bits = int(v * sizeof_bf16 * 8)
                self.assertGreater(bits, 0, f"vec={v}: bits must be > 0")
                self.assertLessEqual(bits, 256, f"vec={v}: bits={bits} > 256")
                self.assertEqual(bits & (bits - 1), 0, f"vec={v}: bits={bits} not power of 2")

    def test_dtype_max_vec_fp32(self):
        """For fp32, vec_a and vec_b must be <= 8 (32 bytes / 4 bytes, allowing 2×16-byte loads)."""
        for va, vb, vc in get_valid_vec_sizes(128, 64, 32, 2, 2, 1, 32, 32, 4, "fp32_fp32_fp32"):
            self.assertLessEqual(va, 8, f"fp32: vec_a={va} > 8")
            self.assertLessEqual(vb, 8, f"fp32: vec_b={vb} > 8")

    def test_dtype_max_vec_bf16(self):
        """For bf16, vec_a and vec_b must be <= 16 (32 bytes / 2 bytes per element)."""
        for va, vb, vc in get_valid_vec_sizes(128, 64, 32, 2, 2, 1, 32, 32, 8, "bf16_bf16_fp32"):
            self.assertLessEqual(va, 16, f"bf16: vec_a={va} > 16")
            self.assertLessEqual(vb, 16, f"bf16: vec_b={vb} > 16")

    def test_vec_c_divisibility(self):
        """tile_n must be divisible by vec_c (output is vectorized along N)."""
        tile_n = 64
        wave_n = 2
        warp_tile_n = 16
        for va, vb, vc in get_valid_vec_sizes(
            128, tile_n, 32, 2, wave_n, 1, 16, warp_tile_n, 16, "bf16_bf16_fp32"
        ):
            self.assertEqual(
                tile_n % vc, 0,
                f"tile_n={tile_n} not divisible by vec_c={vc}",
            )

    def test_empty_result_for_non_divisible_block(self):
        """If tile_m * tile_k is not divisible by block_size, return []."""
        # tile_m=13 (prime) → pixels_a won't be integer → expect empty
        result = get_valid_vec_sizes(13, 16, 32, 1, 1, 1, 16, 16, 16, "bf16_bf16_fp32")
        self.assertEqual(result, [])


# =============================================================================
# TestGetVecSizesForWaveWarp (wrapper)
# =============================================================================
class TestGetVecSizesForWaveWarp(unittest.TestCase):
    """Tests for the get_vec_sizes_for_wave_warp() convenience wrapper."""

    def test_spot_check_128_64_32_wt8(self):
        """Reference vecs for (128,64,32,8) in _TILE_WTILK_TO_VECS must all appear."""
        ref = _TILE_WTILK_TO_VECS[(128, 64, 32, 8)]
        math_set = set(get_vec_sizes_for_wave_warp(128, 64, 32, 8, "bf16_bf16_fp32"))
        for v in ref:
            self.assertIn(
                v, math_set,
                f"Reference vec {v} missing from get_vec_sizes_for_wave_warp(128,64,32,8)",
            )

    def test_spot_check_64_64_32_wt8(self):
        """Reference vecs for (64,64,32,8) in _TILE_WTILK_TO_VECS must all appear."""
        ref = _TILE_WTILK_TO_VECS[(64, 64, 32, 8)]
        math_set = set(get_vec_sizes_for_wave_warp(64, 64, 32, 8, "bf16_bf16_fp32"))
        for v in ref:
            self.assertIn(
                v, math_set,
                f"Reference vec {v} missing from get_vec_sizes_for_wave_warp(64,64,32,8)",
            )

    def test_result_is_sorted(self):
        """Returned list must be in sorted order."""
        result = get_vec_sizes_for_wave_warp(128, 64, 32, 8, "bf16_bf16_fp32")
        self.assertEqual(result, sorted(result))

    def test_no_duplicates(self):
        """Returned list must not contain duplicates."""
        result = get_vec_sizes_for_wave_warp(128, 64, 32, 8, "bf16_bf16_fp32")
        self.assertEqual(len(result), len(set(result)))


# =============================================================================
# TestMathVsReferenceStatistics — coverage tests (the key correctness tests)
# =============================================================================
# All dtype_keys to try when verifying the reference tables. The
# warp_tile_k → dtype_keys mapping itself (derived from
# warp_gemm_dispatcher.hpp) is provided by dtype_keys_for_warp_tile_k();
# this list is the fallback when that mapping has no entry.
_DTYPE_KEYS_ALL = ["bf16_bf16_fp32", "fp32_fp32_fp32", "fp16_fp16_fp32", "fp8_fp8_fp32"]


class TestMathVsReferenceStatistics(unittest.TestCase):
    """Comprehensive coverage: math must generate >= all reference entries."""

    def test_wave_warp_no_false_negatives(self):
        """TILE_TO_WAVE_WARP: every reference pair must be in the math output.

        The math is queried with multiple dtype_keys and the union must cover
        every reference pair. This guarantees zero false negatives.
        """
        missing = []
        for tile, ref_pairs in TILE_TO_WAVE_WARP.items():
            math_pairs: set = set()
            for dk in _DTYPE_KEYS_ALL:
                math_pairs |= set(get_valid_wave_warp_pairs(*tile, dk))
            for pair in ref_pairs:
                if pair not in math_pairs:
                    missing.append((tile, pair))
        if missing:
            lines = [f" tile={t} pair={p}" for t, p in missing[:20]]
            self.fail(
                f"{len(missing)} reference wave/warp pairs not generated by math "
                f"(showing first 20):\n" + "\n".join(lines)
            )

    def test_vec_no_false_negatives(self):
        """_TILE_WTILK_TO_VECS: every reference vec must be in the math output.

        For each (tile_m, tile_n, tile_k, warp_tile_k) key, the math is queried
        with all plausible dtype_keys (inferred from warp_tile_k) and the union
        must contain every reference vec triple.
        """
        missing = []
        for (tm, tn, tk, wtk), ref_vecs in _TILE_WTILK_TO_VECS.items():
            dtype_keys = dtype_keys_for_warp_tile_k(wtk)
            if not dtype_keys:
                dtype_keys = _DTYPE_KEYS_ALL  # fallback: try all
            math_vecs: set = set()
            for dk in dtype_keys:
                math_vecs |= set(get_vec_sizes_for_wave_warp(tm, tn, tk, wtk, dk))
            for v in ref_vecs:
                if v not in math_vecs:
                    missing.append(((tm, tn, tk, wtk), v))
        if missing:
            lines = [f" key={k} vec={v}" for k, v in missing[:30]]
            self.fail(
                f"{len(missing)} reference vec triples not generated by math "
                f"(showing first 30):\n" + "\n".join(lines)
            )

    def test_extra_wave_warp_pairs_are_structurally_valid(self):
        """Pairs the math generates but reference doesn't have must still be valid.

        Extra pairs are not failures: they represent valid configs not yet in
        the JSON profiler files. But they must satisfy structural constraints.
        """
        valid_waves = {tuple(c) for c in WARP_SUPPORTED_COMBINATIONS["gfx942"]}
        invalid_extras = []
        for tile in TILE_TO_WAVE_WARP:
            ref_set = set(TILE_TO_WAVE_WARP[tile])
            math_pairs = get_valid_wave_warp_pairs(*tile, "bf16_bf16_fp32")
            extras = [p for p in math_pairs if p not in ref_set]
            for wave, (wt_m, wt_n, wt_k) in extras:
                tm, tn, _tk = tile
                # Wave must be arch-valid
                if wave not in valid_waves:
                    invalid_extras.append((tile, wave, "not in arch wave combos"))
                # Tile must be divisible by warp tile
                elif tm % wt_m != 0 or tn % wt_n != 0:
                    invalid_extras.append((tile, (wt_m, wt_n, wt_k), "divisibility violated"))
        if invalid_extras:
            lines = [f" {e}" for e in invalid_extras[:10]]
            self.fail(
                f"{len(invalid_extras)} extra pairs are structurally invalid:\n"
                + "\n".join(lines)
            )

    def test_coverage_rate_wave_warp(self):
        """Print coverage statistics for TILE_TO_WAVE_WARP and require 100% coverage."""
        total = 0
        covered = 0
        for tile, ref_pairs in TILE_TO_WAVE_WARP.items():
            math_pairs: set = set()
            for dk in _DTYPE_KEYS_ALL:
                math_pairs |= set(get_valid_wave_warp_pairs(*tile, dk))
            for pair in ref_pairs:
                total += 1
                if pair in math_pairs:
                    covered += 1
        rate = covered / total * 100 if total else 0
        print(f"\n[wave/warp coverage] {covered}/{total} = {rate:.1f}%")
        self.assertEqual(covered, total, f"Coverage {rate:.1f}% < 100%")

    def test_coverage_rate_vec(self):
        """Print coverage statistics for _TILE_WTILK_TO_VECS and require 100% coverage."""
        total = 0
        covered = 0
        for (tm, tn, tk, wtk), ref_vecs in _TILE_WTILK_TO_VECS.items():
            dtype_keys = dtype_keys_for_warp_tile_k(wtk) or _DTYPE_KEYS_ALL
            math_vecs: set = set()
            for dk in dtype_keys:
                math_vecs |= set(get_vec_sizes_for_wave_warp(tm, tn, tk, wtk, dk))
            for v in ref_vecs:
                total += 1
                if v in math_vecs:
                    covered += 1
        rate = covered / total * 100 if total else 0
        print(f"\n[vec coverage] {covered}/{total} = {rate:.1f}%")
        self.assertEqual(covered, total, f"Vec coverage {rate:.1f}% < 100%")


if __name__ == "__main__":
    unittest.main(verbosity=2)