Files
composable_kernel/tile_engine/ops/gemm/test_validation.py
Thomas Ning 705804d9bf Restructure the Tile Engine to have faster build time and clear config report (#2747)
* Making edits to identify individual compilation issues.

* Minor fix for blob txt files not being created.

* Fixing compilation issues.

* Fixing ordering bug.

* Adding python profiling functionality.

* Setting individual build as default.

* Setting gpu target filtering for tile engine to gfx90a, gfx942 and gfx950.

* update the default running parameters and settings

* Fixing bug with benchmarking, shifting file generation to build instead of config.

* Updating fixes.

* Fixing json output and parsing.

* Disable ccache for tile engine gemm ops because we dont need it.

* Removing duplicate type definition.

* Improving json printing.

* Add the flexibility of different layout and more warp tile support

* Fix extra flag in name of individual kernels.

* Fixing bug with booleans.

* Solve the first patch of the post merge conflict

* Compilation fixes, and cosmetic improvements.

* Yet again compilation fixes after latest changes from develop.

* Fixing python benchmarking script.

---------

Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com>
Co-authored-by: Vidyasagar Ananthan <vanantha@amd.com>
2025-08-30 06:54:18 -07:00

144 lines
4.2 KiB
Python

#!/usr/bin/env python
"""
Test script to verify that the validation logic is working correctly.
"""
from validation_utils import (
is_tile_config_valid,
is_trait_combination_valid,
validate_warp_tile_combination,
get_gpu_name_by_id,
)
def test_warp_tile_validation():
"""Test warp tile combination validation"""
print("Testing warp tile combination validation...")
# Get GPU name
gpu_name = get_gpu_name_by_id(0)
print(f"Detected GPU: {gpu_name}")
# Test cases for fp16
test_cases = [
# (warp_tile_m, warp_tile_n, warp_tile_k, expected_valid)
([4, 64, 8], False), # Invalid - not in supported list
([4, 64, 16], True), # Valid
([32, 32, 8], True), # Valid
([16, 16, 16], True), # Valid
([32, 32, 16], True), # Valid
([16, 16, 32], True), # Valid
([64, 4, 16], True), # Valid
([128, 128, 128], False), # Invalid - too large
]
print("\nTesting fp16 warp tile combinations:")
for (warp_tile_m, warp_tile_n, warp_tile_k), expected in test_cases:
valid, msg = validate_warp_tile_combination(
warp_tile_m, warp_tile_n, warp_tile_k, "fp16", "fp16", "fp16", gpu_name
)
status = "PASS" if valid == expected else "FAIL"
print(f" [{warp_tile_m}, {warp_tile_n}, {warp_tile_k}]: {valid} - {status}")
if not valid and msg:
print(f" Reason: {msg}")
def test_trait_combinations():
"""Test trait combination validation"""
print("\n\nTesting trait combination validation...")
test_cases = [
# (pipeline, epilogue, scheduler, expected_valid)
("mem", "default", "intrawave", True),
("mem", "cshuffle", "intrawave", True),
("compv3", "default", "interwave", False), # Invalid combination
("compv3", "cshuffle", "interwave", False), # Invalid combination
("compv4", "default", "interwave", False), # Invalid combination
("compv4", "cshuffle", "interwave", False), # Invalid combination
("compv3", "default", "intrawave", True),
("compv4", "cshuffle", "intrawave", True),
]
print("\nTesting trait combinations:")
for pipeline, epilogue, scheduler, expected in test_cases:
valid = is_trait_combination_valid(pipeline, epilogue, scheduler)
status = "PASS" if valid == expected else "FAIL"
print(f" {pipeline}-{epilogue}-{scheduler}: {valid} - {status}")
def test_full_tile_config_validation():
"""Test full tile configuration validation"""
print("\n\nTesting full tile configuration validation...")
# Test case that was failing in the build
tile_m, tile_n, tile_k = 256, 256, 32
warp_m, warp_n, warp_k = 1, 4, 1
warp_tile_m, warp_tile_n, warp_tile_k = 4, 64, 8
print("\nTesting problematic configuration:")
print(f" Tile: {tile_m}x{tile_n}x{tile_k}")
print(f" Warp: {warp_m}x{warp_n}x{warp_k}")
print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}")
valid = is_tile_config_valid(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
"fp16",
"fp16",
"fp16",
"mem",
)
print(f" Valid: {valid}")
print(" Expected: False (warp tile [4, 64, 8] is not supported for fp16)")
# Test a valid configuration
warp_tile_k = 16 # Change to valid value
print("\nTesting corrected configuration:")
print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}")
valid = is_tile_config_valid(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
"fp16",
"fp16",
"fp16",
"mem",
)
print(f" Valid: {valid}")
print(" Expected: True")
def main():
"""Run all tests"""
print("=" * 60)
print("GEMM Validation Test Suite")
print("=" * 60)
test_warp_tile_validation()
test_trait_combinations()
test_full_tile_config_validation()
print("\n" + "=" * 60)
print("Test suite completed")
print("=" * 60)
if __name__ == "__main__":
main()