Files
composable_kernel/tile_engine/ops/commons/test_validation.py

145 lines
4.2 KiB
Python

#!/usr/bin/env python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Test script to verify that the validation logic is working correctly.
"""
from validation_utils import (
is_tile_config_valid,
is_trait_combination_valid,
validate_warp_tile_combination,
)
def test_warp_tile_validation():
    """Check fp16 warp tile shapes against validate_warp_tile_combination.

    Each case is run with "fp16" for all three data-type arguments and the
    hard-coded GPU target "gfx90a". The returned validity flag is printed
    together with PASS/FAIL depending on whether it equals the expected flag;
    when the result is falsy and a message came back, the message is printed
    as well. Mismatches are only reported on stdout - nothing is raised.
    """
    print("Testing warp tile combination validation...")

    # The GPU target is fixed here rather than queried from the system.
    target_gpu = "gfx90a"

    # Each entry: ([warp_tile_m, warp_tile_n, warp_tile_k], expected_valid)
    fp16_cases = [
        ([4, 64, 8], False),  # expected invalid - not in supported list
        ([4, 64, 16], True),
        ([32, 32, 8], True),
        ([16, 16, 16], True),
        ([32, 32, 16], True),
        ([16, 16, 32], True),
        ([64, 4, 16], True),
        ([128, 128, 128], False),  # expected invalid - too large
    ]

    print("\nTesting fp16 warp tile combinations:")
    for shape, should_be_valid in fp16_cases:
        wt_m, wt_n, wt_k = shape
        valid, reason = validate_warp_tile_combination(
            wt_m, wt_n, wt_k, "fp16", "fp16", "fp16", target_gpu
        )

        if valid == should_be_valid:
            verdict = "PASS"
        else:
            verdict = "FAIL"
        print(f" [{wt_m}, {wt_n}, {wt_k}]: {valid} - {verdict}")

        # Only rejected shapes that came back with an explanation get a
        # follow-up line.
        if valid or not reason:
            continue
        print(f" Reason: {reason}")
def test_trait_combinations():
    """Check pipeline/epilogue/scheduler triples with is_trait_combination_valid.

    For every triple the returned value is printed along with PASS/FAIL,
    depending on whether it equals the expected flag. Results are reported on
    stdout only; a mismatch does not raise.
    """
    print("\n\nTesting trait combination validation...")

    # Each entry: (pipeline, epilogue, scheduler, expected_valid)
    trait_cases = (
        ("mem", "default", "intrawave", True),
        ("mem", "cshuffle", "intrawave", True),
        ("compv3", "default", "interwave", False),  # expected invalid
        ("compv3", "cshuffle", "interwave", False),  # expected invalid
        ("compv4", "default", "interwave", False),  # expected invalid
        ("compv4", "cshuffle", "interwave", False),  # expected invalid
        ("compv3", "default", "intrawave", True),
        ("compv4", "cshuffle", "intrawave", True),
    )

    print("\nTesting trait combinations:")
    for case in trait_cases:
        pipeline, epilogue, scheduler, should_be_valid = case
        valid = is_trait_combination_valid(pipeline, epilogue, scheduler)

        if valid == should_be_valid:
            verdict = "PASS"
        else:
            verdict = "FAIL"
        print(f" {pipeline}-{epilogue}-{scheduler}: {valid} - {verdict}")
def test_full_tile_config_validation():
    """Run is_tile_config_valid on a known-problematic config and a corrected one.

    Both runs share the 256x256x32 tile, the 1x4x1 warp layout, "fp16" for
    the three data-type arguments and "mem" as the final argument; only the
    warp tile K value differs (8, then 16). The returned value and the
    expected outcome are printed side by side; no comparison is performed and
    nothing is raised.
    """
    print("\n\nTesting full tile configuration validation...")

    # Configuration that was failing in the build.
    blk_m, blk_n, blk_k = 256, 256, 32
    wave_m, wave_n, wave_k = 1, 4, 1
    wt_m, wt_n = 4, 64

    def report(wt_k):
        # Print the warp tile under test, validate the full configuration,
        # and print whatever the validator returned.
        print(f" WarpTile: {wt_m}x{wt_n}x{wt_k}")
        outcome = is_tile_config_valid(
            blk_m,
            blk_n,
            blk_k,
            wave_m,
            wave_n,
            wave_k,
            wt_m,
            wt_n,
            wt_k,
            "fp16",
            "fp16",
            "fp16",
            "mem",
        )
        print(f" Valid: {outcome}")

    print("\nTesting problematic configuration:")
    print(f" Tile: {blk_m}x{blk_n}x{blk_k}")
    print(f" Warp: {wave_m}x{wave_n}x{wave_k}")
    report(8)
    print(" Expected: False (warp tile [4, 64, 8] is not supported for fp16)")

    # Same configuration with warp tile K raised to a value expected to pass.
    print("\nTesting corrected configuration:")
    report(16)
    print(" Expected: True")
def main():
    """Print a banner, run the three validation test functions, print a footer."""
    rule = "=" * 60

    print(rule)
    print("GEMM Validation Test Suite")
    print(rule)

    # Suites run in a fixed order: warp tiles, trait combinations, then the
    # full tile configuration check.
    suites = (
        test_warp_tile_validation,
        test_trait_combinations,
        test_full_tile_config_validation,
    )
    for run_suite in suites:
        run_suite()

    print("\n" + rule)
    print("Test suite completed")
    print(rule)
if __name__ == "__main__":
    # Run the suite only when executed as a script, not when imported.
    main()