mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
[rocm-libraries] ROCm/rocm-libraries#5722 (commit 55febd2)
[CK Tile] Stream-K gtest Code Gen
## Motivation
Stream-K was using the tile engine infrastructure for smoke tests.
However, tile engine creates a different target per kernel instance,
which has resulted in scalability issues when used in the context of
unit tests. To avoid burdens on cmake configuration and build time, we
have opted to remove our Stream-K tile engine tests. Instead, we use
pure gtests with code gen to generate repetitive .cpp files.
**Note: This appears to change a lot of files because many files are
removed since they are now generated at build time.**
## Technical Details
We originally used Tile Engine to facilitate code gen for unit tests
since we found that pure gtests required the addition of many repetitive
.cpp files of the following form:
```cpp
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8
TYPED_TEST_SUITE(TestCkTileStreamKBf8, KernelTypesStreamKBf8);
#include "test_gemm_streamk_atomic_cases.inc"
#undef TEST_SUITE_NAME
```
Due to issues encountered with tile engine, we instead use pure gtests
to generate the repetitive .cpp files. The code generator parses
`KernelTypesStreamK*` type aliases from the types header using a
two-phase approach:
1. At **configure time**, CMake runs the Python script with
`--list_files` to extract the type alias names from the header
(test_gemm_streamk_types.hpp) and compute the list of .cpp file paths
that will be generated. This lets CMake know the exact set of source
files for each target.
2. At **build time**, `add_custom_command` runs the script again with
`--gen_files` to actually emit the .cpp files into the build directory,
triggered only when the types header or generator script changes.
With these changes, we've removed all Stream-K tile engine tests. There
are now 5 targets for Stream-K GEMM tests:
1. test_ck_tile_streamk_atomic_smoke: smoke tests for Atomic reduction
strategy (pipeline: compv3)
2. test_ck_tile_streamk_linear_smoke: smoke tests for Linear reduction
strategy (pipeline: compv3)
3. test_ck_tile_streamk_tree_smoke: smoke tests for Tree reduction
strategy (pipeline: compv3)
4. test_ck_tile_streamk_pipelines_smoke: smoke tests (smaller set) for
pipelines other than compv3
- Since Stream-K can be thought of as a wrapper around universal GEMM,
we don't need to extensively test each pipeline. So, we opt to run a few
tests for different pipelines. Currently, this just consists of the mem
pipeline, but compv4 is coming soon.
5. test_ck_tile_streamk_extended: extended tests
## Test Plan
I have tests the gtests locally on gfx90a, gfx942, and gfx950.
## Test Result
All local tests pass.
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6d77edc3bd
commit
7cc9bae9d2
215
test/ck_tile/gemm_streamk/generate_test_files.py
Normal file
215
test/ck_tile/gemm_streamk/generate_test_files.py
Normal file
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Generate test .cpp files from KernelTypes definitions in
|
||||
test_gemm_streamk_types.hpp.
|
||||
|
||||
Two modes:
|
||||
--list_files FILE Write the list of output file paths to FILE (one per line)
|
||||
without generating the files. Used at CMake configure time.
|
||||
--gen_files Actually emit the .cpp files into --output_dir.
|
||||
Used at build time via add_custom_command.
|
||||
|
||||
Target selection (--target):
|
||||
extended Kernel types containing 'Atomic' or 'Pipelines'
|
||||
-> includes test_gemm_streamk_extended_cases.inc
|
||||
atomic_smoke Kernel types containing 'Atomic' (not 'Pipelines')
|
||||
-> includes test_gemm_streamk_atomic_cases.inc
|
||||
linear_smoke Kernel types containing 'Linear' (not 'Pipelines')
|
||||
-> includes test_gemm_streamk_reduction_cases.inc
|
||||
tree_smoke Kernel types containing 'Tree' (not 'Pipelines')
|
||||
-> includes test_gemm_streamk_reduction_cases.inc
|
||||
pipelines_smoke Kernel types matching 'Pipelines'
|
||||
-> includes test_gemm_streamk_reduction_cases.inc
|
||||
and test_gemm_streamk_atomic_cases.inc
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Template for every generated .cpp file
|
||||
# --------------------------------------------------------------------------- #
|
||||
CPP_TEMPLATE = """\
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class {class_name} : public TestCkTileStreamK<Tuple>
|
||||
{{
|
||||
}};
|
||||
|
||||
#define TEST_SUITE_NAME {class_name}
|
||||
|
||||
TYPED_TEST_SUITE({class_name}, {type_alias});
|
||||
|
||||
{inc_includes}
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Target definitions: filter predicate and .inc files
|
||||
# --------------------------------------------------------------------------- #
|
||||
TARGETS = {
|
||||
"extended": {
|
||||
"filter": lambda suffix: "Atomic" in suffix or suffix == "Pipelines",
|
||||
"inc_files": ["test_gemm_streamk_extended_cases.inc"],
|
||||
},
|
||||
"atomic_smoke": {
|
||||
"filter": lambda suffix: "Atomic" in suffix and suffix != "Pipelines",
|
||||
"inc_files": ["test_gemm_streamk_atomic_cases.inc"],
|
||||
},
|
||||
"linear_smoke": {
|
||||
"filter": lambda suffix: "Linear" in suffix and suffix != "Pipelines",
|
||||
"inc_files": ["test_gemm_streamk_reduction_cases.inc"],
|
||||
},
|
||||
"tree_smoke": {
|
||||
"filter": lambda suffix: "Tree" in suffix and suffix != "Pipelines",
|
||||
"inc_files": ["test_gemm_streamk_reduction_cases.inc"],
|
||||
},
|
||||
"pipelines_smoke": {
|
||||
"filter": lambda suffix: suffix == "Pipelines",
|
||||
"inc_files": [
|
||||
"test_gemm_streamk_reduction_cases.inc",
|
||||
"test_gemm_streamk_atomic_cases.inc",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Mapping from CamelCase suffix fragments to file-name fragments
|
||||
# --------------------------------------------------------------------------- #
|
||||
KNOWN_TOKENS = [
|
||||
("Fp16", "fp16"),
|
||||
("Bf16", "bf16"),
|
||||
("Fp8", "fp8"),
|
||||
("Bf8", "bf8"),
|
||||
("NonPersistent", "nonpersistent"),
|
||||
("Persistent", "persistent"),
|
||||
("Atomic", "atomic"),
|
||||
("Linear", "linear"),
|
||||
("Tree", "tree"),
|
||||
("CompV3", "compv3"),
|
||||
("Pipelines", "pipelines"),
|
||||
]
|
||||
|
||||
|
||||
def suffix_to_file_tag(suffix: str) -> str:
|
||||
"""Convert a CamelCase suffix like 'Fp16PersistentAtomicCompV3' to
|
||||
'fp16_persistent_atomic_compv3'."""
|
||||
parts: list[str] = []
|
||||
remaining = suffix
|
||||
while remaining:
|
||||
matched = False
|
||||
for token, replacement in KNOWN_TOKENS:
|
||||
if remaining.startswith(token):
|
||||
parts.append(replacement)
|
||||
remaining = remaining[len(token) :]
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
raise ValueError(
|
||||
f"Unrecognised token in KernelTypes suffix: '{remaining}' "
|
||||
f"(from '{suffix}')"
|
||||
)
|
||||
return "_".join(parts)
|
||||
|
||||
|
||||
def parse_types_header(header_path: str, target: str) -> list[dict]:
|
||||
"""Return a list of dicts with keys: type_alias, class_name, file_tag, suffix."""
|
||||
target_def = TARGETS[target]
|
||||
# Pattern matches lines like: using KernelTypesStreamKFp16PersistentAtomicCompV3 = ...
|
||||
pattern = re.compile(r"using\s+(KernelTypesStreamK(\w+))\s*=")
|
||||
entries: list[dict] = []
|
||||
with open(header_path) as f:
|
||||
for line in f:
|
||||
match = pattern.search(line)
|
||||
if match:
|
||||
# If the match is: using KernelTypesStreamKFp16PersistentAtomicCompV3 = ...
|
||||
# type_alias is KernelTypesStreamKFp16PersistentAtomicCompV3
|
||||
# suffix is Fp16PersistentAtomicCompV3
|
||||
type_alias = match.group(1)
|
||||
suffix = match.group(2)
|
||||
if not target_def["filter"](suffix):
|
||||
continue
|
||||
entries.append(
|
||||
{
|
||||
"type_alias": type_alias,
|
||||
"class_name": f"TestCkTileStreamK{suffix}",
|
||||
"file_tag": suffix_to_file_tag(suffix),
|
||||
}
|
||||
)
|
||||
return entries
|
||||
|
||||
|
||||
def output_path(output_dir: str, entry: dict) -> str:
|
||||
return os.path.join(output_dir, f"test_gemm_streamk_{entry['file_tag']}.cpp")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--types_header", required=True, help="Path to test_gemm_streamk_types.hpp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", required=True, help="Directory for generated .cpp files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target",
|
||||
required=True,
|
||||
choices=list(TARGETS.keys()),
|
||||
help="Which target to generate files for",
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"--list_files",
|
||||
metavar="FILE",
|
||||
help="Write output file paths to FILE then exit",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gen_files", action="store_true", help="Generate the .cpp files"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
entries = parse_types_header(args.types_header, args.target)
|
||||
if not entries:
|
||||
print(
|
||||
f"ERROR: no KernelTypesStreamK* definitions found for target "
|
||||
f"'{args.target}' in {args.types_header}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
inc_files = TARGETS[args.target]["inc_files"]
|
||||
inc_includes = "\n".join(f'#include "{f}"' for f in inc_files)
|
||||
|
||||
if args.list_files:
|
||||
os.makedirs(os.path.dirname(args.list_files) or ".", exist_ok=True)
|
||||
with open(args.list_files, "w") as f:
|
||||
for entry in entries:
|
||||
f.write(output_path(args.output_dir, entry) + "\n")
|
||||
else:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
for entry in entries:
|
||||
path = output_path(args.output_dir, entry)
|
||||
content = CPP_TEMPLATE.format(
|
||||
class_name=entry["class_name"],
|
||||
type_alias=entry["type_alias"],
|
||||
inc_includes=inc_includes,
|
||||
)
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user