mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK TILE ENGINE] GEMM Multi D Restructure (#3121)
* Renaming old code * Adding GEMM code with new Architecture * Partial Progress : Errors * Partial Progress : Working code * Changes to element wise function * Removing Debugging statements * Working GEMM Multi D code * Removing Stale Code * Address Copilot review comments * Address Copilot review comments * Changes to validation file * Changes to common code snippets * Creating common folder * Removing duplicate files * Pointing to right common file * Pointing to right common file * Pointing to right common file * Changing to VERBOSE * Changing CMAKE messages to verbose * Updating Cmake with right layout datatype configs * Working code for GEMM Multi D
This commit is contained in:
committed by
GitHub
parent
04efd282cf
commit
a33d98f8e2
102
tile_engine/ops/commons/test_benchmark.sh
Executable file
102
tile_engine/ops/commons/test_benchmark.sh
Executable file
@@ -0,0 +1,102 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Test script for tile engine GEMM benchmarks
|
||||
# This script demonstrates how to run the new individual benchmark executables
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Find the build directory
|
||||
if [ -z "$1" ]; then
|
||||
# Try to find build directory automatically
|
||||
BUILD_DIR=$(find /root/workspace/composable_kernel -name "test_gemm_fix" -type d 2>/dev/null | head -1)
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
echo -e "${RED}Error: Could not find build directory. Please provide it as first argument.${NC}"
|
||||
echo "Usage: $0 <build_directory>"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
BUILD_DIR="$1"
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}Using build directory: $BUILD_DIR${NC}"
|
||||
|
||||
# Check if bin directory exists
|
||||
if [ ! -d "$BUILD_DIR/bin" ]; then
|
||||
echo -e "${RED}Error: bin directory not found in $BUILD_DIR${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Find all benchmark executables
|
||||
echo -e "${YELLOW}Finding benchmark executables...${NC}"
|
||||
BENCHMARKS=$(find "$BUILD_DIR/bin" -name "benchmark_gemm_*" -type f 2>/dev/null)
|
||||
|
||||
if [ -z "$BENCHMARKS" ]; then
|
||||
echo -e "${RED}No benchmark executables found in $BUILD_DIR/bin${NC}"
|
||||
echo "Please build some benchmarks first with:"
|
||||
echo " cd $BUILD_DIR"
|
||||
echo " make benchmark_gemm_<kernel_name>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Count benchmarks
|
||||
NUM_BENCHMARKS=$(echo "$BENCHMARKS" | wc -l)
|
||||
echo -e "${GREEN}Found $NUM_BENCHMARKS benchmark executable(s)${NC}"
|
||||
|
||||
# Test sizes
|
||||
SIZES=(512 1024 2048)
|
||||
|
||||
# Results file
|
||||
RESULTS_FILE="benchmark_results_$(date +%Y%m%d_%H%M%S).csv"
|
||||
|
||||
echo -e "${YELLOW}Running benchmarks...${NC}"
|
||||
echo "Results will be saved to: $RESULTS_FILE"
|
||||
|
||||
# Run each benchmark
|
||||
COUNTER=0
|
||||
for BENCH in $BENCHMARKS; do
|
||||
COUNTER=$((COUNTER + 1))
|
||||
BENCH_NAME=$(basename "$BENCH")
|
||||
echo -e "\n${GREEN}[$COUNTER/$NUM_BENCHMARKS] Running: $BENCH_NAME${NC}"
|
||||
|
||||
for SIZE in "${SIZES[@]}"; do
|
||||
echo -e " Testing size: ${SIZE}x${SIZE}x${SIZE}"
|
||||
|
||||
# Run with verification
|
||||
"$BENCH" -m=$SIZE -n=$SIZE -k=$SIZE -verify=2 -warmup=10 -repeat=20 \
|
||||
-csv_filename="$RESULTS_FILE" -csv_format=simple \
|
||||
2>&1 | grep -E "(Time:|Performance:|Verification:|Error)"
|
||||
|
||||
if [ ${PIPESTATUS[0]} -ne 0 ]; then
|
||||
echo -e " ${RED}Benchmark failed!${NC}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
echo -e "\n${GREEN}Benchmark testing complete!${NC}"
|
||||
echo "Results saved to: $RESULTS_FILE"
|
||||
|
||||
# Show summary if CSV file exists
|
||||
if [ -f "$RESULTS_FILE" ]; then
|
||||
echo -e "\n${YELLOW}Summary of results:${NC}"
|
||||
echo "Number of tests: $(tail -n +2 "$RESULTS_FILE" | wc -l)"
|
||||
echo "Successful tests: $(grep -c "true" "$RESULTS_FILE")"
|
||||
echo "Failed tests: $(grep -c "false" "$RESULTS_FILE")"
|
||||
fi
|
||||
|
||||
# Example of running a specific benchmark with different options
|
||||
echo -e "\n${YELLOW}Example commands for manual testing:${NC}"
|
||||
echo "# Basic run:"
|
||||
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024"
|
||||
echo ""
|
||||
echo "# With CPU verification:"
|
||||
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -verify=1"
|
||||
echo ""
|
||||
echo "# JSON output for parsing:"
|
||||
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -json_output=true"
|
||||
echo ""
|
||||
echo "# Performance testing with TFLOPS metric:"
|
||||
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=4096 -n=4096 -k=4096 -warmup=100 -repeat=200 -metric=1"
|
||||
141
tile_engine/ops/commons/test_validation.py
Normal file
141
tile_engine/ops/commons/test_validation.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script to verify that the validation logic is working correctly.
|
||||
"""
|
||||
|
||||
from validation_utils import (
|
||||
is_tile_config_valid,
|
||||
is_trait_combination_valid,
|
||||
validate_warp_tile_combination,
|
||||
)
|
||||
|
||||
|
||||
def test_warp_tile_validation():
|
||||
"""Test warp tile combination validation"""
|
||||
print("Testing warp tile combination validation...")
|
||||
|
||||
# Get GPU name
|
||||
gpu_name = "gfx90a"
|
||||
|
||||
# Test cases for fp16
|
||||
test_cases = [
|
||||
# (warp_tile_m, warp_tile_n, warp_tile_k, expected_valid)
|
||||
([4, 64, 8], False), # Invalid - not in supported list
|
||||
([4, 64, 16], True), # Valid
|
||||
([32, 32, 8], True), # Valid
|
||||
([16, 16, 16], True), # Valid
|
||||
([32, 32, 16], True), # Valid
|
||||
([16, 16, 32], True), # Valid
|
||||
([64, 4, 16], True), # Valid
|
||||
([128, 128, 128], False), # Invalid - too large
|
||||
]
|
||||
|
||||
print("\nTesting fp16 warp tile combinations:")
|
||||
for (warp_tile_m, warp_tile_n, warp_tile_k), expected in test_cases:
|
||||
valid, msg = validate_warp_tile_combination(
|
||||
warp_tile_m, warp_tile_n, warp_tile_k, "fp16", "fp16", "fp16", gpu_name
|
||||
)
|
||||
status = "PASS" if valid == expected else "FAIL"
|
||||
print(f" [{warp_tile_m}, {warp_tile_n}, {warp_tile_k}]: {valid} - {status}")
|
||||
if not valid and msg:
|
||||
print(f" Reason: {msg}")
|
||||
|
||||
|
||||
def test_trait_combinations():
|
||||
"""Test trait combination validation"""
|
||||
print("\n\nTesting trait combination validation...")
|
||||
|
||||
test_cases = [
|
||||
# (pipeline, epilogue, scheduler, expected_valid)
|
||||
("mem", "default", "intrawave", True),
|
||||
("mem", "cshuffle", "intrawave", True),
|
||||
("compv3", "default", "interwave", False), # Invalid combination
|
||||
("compv3", "cshuffle", "interwave", False), # Invalid combination
|
||||
("compv4", "default", "interwave", False), # Invalid combination
|
||||
("compv4", "cshuffle", "interwave", False), # Invalid combination
|
||||
("compv3", "default", "intrawave", True),
|
||||
("compv4", "cshuffle", "intrawave", True),
|
||||
]
|
||||
|
||||
print("\nTesting trait combinations:")
|
||||
for pipeline, epilogue, scheduler, expected in test_cases:
|
||||
valid = is_trait_combination_valid(pipeline, epilogue, scheduler)
|
||||
status = "PASS" if valid == expected else "FAIL"
|
||||
print(f" {pipeline}-{epilogue}-{scheduler}: {valid} - {status}")
|
||||
|
||||
|
||||
def test_full_tile_config_validation():
|
||||
"""Test full tile configuration validation"""
|
||||
print("\n\nTesting full tile configuration validation...")
|
||||
|
||||
# Test case that was failing in the build
|
||||
tile_m, tile_n, tile_k = 256, 256, 32
|
||||
warp_m, warp_n, warp_k = 1, 4, 1
|
||||
warp_tile_m, warp_tile_n, warp_tile_k = 4, 64, 8
|
||||
|
||||
print("\nTesting problematic configuration:")
|
||||
print(f" Tile: {tile_m}x{tile_n}x{tile_k}")
|
||||
print(f" Warp: {warp_m}x{warp_n}x{warp_k}")
|
||||
print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}")
|
||||
|
||||
valid = is_tile_config_valid(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
"fp16",
|
||||
"fp16",
|
||||
"fp16",
|
||||
"mem",
|
||||
)
|
||||
|
||||
print(f" Valid: {valid}")
|
||||
print(" Expected: False (warp tile [4, 64, 8] is not supported for fp16)")
|
||||
|
||||
# Test a valid configuration
|
||||
warp_tile_k = 16 # Change to valid value
|
||||
print("\nTesting corrected configuration:")
|
||||
print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}")
|
||||
|
||||
valid = is_tile_config_valid(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
"fp16",
|
||||
"fp16",
|
||||
"fp16",
|
||||
"mem",
|
||||
)
|
||||
|
||||
print(f" Valid: {valid}")
|
||||
print(" Expected: True")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("=" * 60)
|
||||
print("GEMM Validation Test Suite")
|
||||
print("=" * 60)
|
||||
|
||||
test_warp_tile_validation()
|
||||
test_trait_combinations()
|
||||
test_full_tile_config_validation()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test suite completed")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
602
tile_engine/ops/commons/validation_utils.py
Normal file
602
tile_engine/ops/commons/validation_utils.py
Normal file
@@ -0,0 +1,602 @@
|
||||
#!/usr/bin/env python
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
"""
|
||||
Validation utilities for GEMM kernel generation.
|
||||
Extracted from tile_engine_develop for consistency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
|
||||
# Element size mapping for different data types
|
||||
ELEMENT_SIZE_MAP = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"int8": 1,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int4": 0.5,
|
||||
"int32": 4,
|
||||
"fp32": 4,
|
||||
"fp64": 8,
|
||||
}
|
||||
|
||||
WARP_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx942": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx950": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx1201": [
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[4, 2, 1],
|
||||
],
|
||||
}
|
||||
|
||||
# [TODO] Handle this while moving code to commons
|
||||
# Supported warp tile combinations for different GPU architectures and data types
|
||||
WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_bf8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
"gfx1201": { # Check how to handle for GEMM and Multi D
|
||||
"fp16_fp16_fp16": [
|
||||
[16, 16, 16],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Unsupported trait combinations
|
||||
TRAIT_UNSUPPORTED_COMBINATIONS = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
}
|
||||
|
||||
|
||||
def element_size(data_type: str) -> float:
|
||||
"""Calculate the size (in bytes) of a single element for given data type."""
|
||||
data_type = data_type.lower()
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
|
||||
"""Check if a trait combination is valid."""
|
||||
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
|
||||
|
||||
def validate_warp_configuration(
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
gpu_name: str,
|
||||
) -> bool:
|
||||
"""Validate warp configuration."""
|
||||
|
||||
current_combination = [warp_m, warp_n, warp_k]
|
||||
|
||||
allowed_combinations = WARP_SUPPORTED_COMBINATIONS.get(gpu_name, {})
|
||||
if not allowed_combinations:
|
||||
# If GPU not recognized, try to be permissive but log warning
|
||||
logging.warning(f"No warp_[m/n/k] combinations found for GPU: {gpu_name}")
|
||||
return True
|
||||
|
||||
# Check if current combination is in the allowed list
|
||||
if current_combination not in allowed_combinations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_dimension_alignment(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
tile_k: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""Check if tile dimensions are properly aligned with warp dimensions."""
|
||||
alignment_issues = []
|
||||
|
||||
if tile_m % (warp_m * warp_tile_m) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}"
|
||||
)
|
||||
if tile_n % (warp_n * warp_tile_n) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}"
|
||||
)
|
||||
if tile_k % (warp_k * warp_tile_k) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}"
|
||||
)
|
||||
|
||||
return len(alignment_issues) == 0, alignment_issues
|
||||
|
||||
|
||||
def validate_lds_capacity(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
tile_k: int,
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
pipeline: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate LDS capacity requirements."""
|
||||
matrix_a_size = (tile_m * tile_k) * element_size(a_datatype)
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
|
||||
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
error_msg = (
|
||||
f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
|
||||
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
|
||||
f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
|
||||
f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_warp_tile_combination(
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
gpu_name: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate warp tile combination against GPU-specific supported combinations."""
|
||||
|
||||
# Construct the key for looking up supported combinations
|
||||
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
# Check if we have GPU-specific combinations
|
||||
gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {})
|
||||
if not gpu_warp_tile_combinations:
|
||||
# If GPU not recognized, try to be permissive but log warning
|
||||
logging.warning(f"No warp tile combinations found for GPU: {gpu_name}")
|
||||
return True, ""
|
||||
|
||||
# Check if we have combinations for this data type combination
|
||||
allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, [])
|
||||
if not allowed_combinations:
|
||||
# For data type combinations not in the list, be permissive
|
||||
logging.debug(
|
||||
f"No warp tile combinations found for data types: {warp_tile_key}"
|
||||
)
|
||||
return True, ""
|
||||
|
||||
# Check if current combination is in the allowed list
|
||||
if current_combination not in allowed_combinations:
|
||||
error_msg = (
|
||||
f"Invalid warp tile combination: {current_combination} not in allowed list. "
|
||||
f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def is_tile_config_valid(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
tile_k: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
pipeline: str,
|
||||
layout: str,
|
||||
gpu_target: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Comprehensive tile configuration validation.
|
||||
Returns True if configuration is valid, False otherwise.
|
||||
"""
|
||||
# Basic sanity checks
|
||||
if tile_m <= 0 or tile_n <= 0 or tile_k <= 0:
|
||||
return False
|
||||
if warp_m <= 0 or warp_n <= 0 or warp_k <= 0:
|
||||
return False
|
||||
if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0:
|
||||
return False
|
||||
|
||||
# Check that warp tiles fit within block tiles
|
||||
if warp_m * warp_tile_m > tile_m:
|
||||
return False
|
||||
if warp_n * warp_tile_n > tile_n:
|
||||
return False
|
||||
if warp_k * warp_tile_k > tile_k:
|
||||
return False
|
||||
|
||||
# Validate warp configuration
|
||||
if not validate_warp_configuration(warp_m, warp_n, warp_k, gpu_target):
|
||||
logging.debug(
|
||||
f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})"
|
||||
)
|
||||
return False
|
||||
|
||||
# Validate dimension alignment
|
||||
is_aligned, alignment_issues = validate_dimension_alignment(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
)
|
||||
if not is_aligned:
|
||||
logging.debug(
|
||||
f"Dimension alignment failed: {', '.join(alignment_issues)}. "
|
||||
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
|
||||
f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Validate LDS capacity
|
||||
lds_valid, lds_error = validate_lds_capacity(
|
||||
tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline
|
||||
)
|
||||
if not lds_valid:
|
||||
logging.debug(f"LDS validation failed: {lds_error}")
|
||||
return False
|
||||
|
||||
# Validate whole workgroup cover configuration
|
||||
wr_cover_valid, wg_cover_error = validate_whole_wg_cover_configuration(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
layout,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
)
|
||||
if not wr_cover_valid:
|
||||
logging.debug(
|
||||
f"Whole workgroup cover configuration validation failed: {wg_cover_error}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
gpu_target,
|
||||
)
|
||||
if not warp_tile_valid:
|
||||
logging.debug(f"Warp tile validation failed: {warp_tile_error}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# [TODO] Handle this while moving code to commons Add more datatype to this function if needed
|
||||
def get_dtype_string(datatype: str) -> str:
|
||||
"""Get C++ type string for datatype"""
|
||||
dtype_map = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float",
|
||||
"fp64": "double",
|
||||
}
|
||||
return dtype_map.get(datatype, "float")
|
||||
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
|
||||
def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
|
||||
"""
|
||||
code = str(layout_code).strip().lower()
|
||||
|
||||
a_layout = LAYOUT_MAP[code[0]]
|
||||
b_layout = LAYOUT_MAP[code[1]]
|
||||
c_layout = LAYOUT_MAP[code[2]]
|
||||
return a_layout, b_layout, c_layout
|
||||
|
||||
|
||||
def get_abcd_layouts(layout_code: str) -> Tuple[str, str, str, List[str]]:
|
||||
"""
|
||||
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcrr', 'ccrr', 'crrr', 'rrrr'.
|
||||
"""
|
||||
code = str(layout_code).strip().lower()
|
||||
|
||||
a_layout = LAYOUT_MAP[code[0]]
|
||||
b_layout = LAYOUT_MAP[code[1]]
|
||||
c_layout = LAYOUT_MAP[code[2]]
|
||||
d0_layout = LAYOUT_MAP[code[3]]
|
||||
d1_layout = LAYOUT_MAP[code[3]]
|
||||
return a_layout, b_layout, c_layout, [d0_layout, d1_layout]
|
||||
|
||||
|
||||
def validate_whole_wg_cover_configuration(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
layout,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
) -> Tuple[bool, str]:
|
||||
# Validate whole workgroup cover configuration
|
||||
|
||||
warp_size = 64
|
||||
NumWarps = warp_m * warp_n * warp_k
|
||||
BlockSize = NumWarps * warp_size
|
||||
|
||||
XPerTile = 0
|
||||
YPerTile = 0
|
||||
vector_load_size = 0
|
||||
|
||||
# A matrix validation
|
||||
if layout[0] == "r":
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, a_datatype, tile_m, tile_k
|
||||
)
|
||||
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_m
|
||||
|
||||
elif layout[0] == "c":
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, a_datatype, tile_m, tile_m
|
||||
)
|
||||
|
||||
# Validate distribution
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_m
|
||||
|
||||
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
|
||||
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
|
||||
)
|
||||
|
||||
if not wg_cover_core_valid:
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
|
||||
)
|
||||
return False, wg_cover_core_error
|
||||
|
||||
XPerTile = tile_m
|
||||
YPerTile = tile_k
|
||||
|
||||
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
|
||||
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
|
||||
)
|
||||
|
||||
if not wg_cover_core_valid:
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix A: {wg_cover_core_error}"
|
||||
)
|
||||
return False, wg_cover_core_error
|
||||
|
||||
# B matrix validation
|
||||
if layout[1] == "r":
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, b_datatype, tile_n, tile_n
|
||||
)
|
||||
|
||||
# Validate distribution
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_n
|
||||
|
||||
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
|
||||
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
|
||||
)
|
||||
|
||||
if not wg_cover_core_valid:
|
||||
print("I am here 3")
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}"
|
||||
)
|
||||
return False, wg_cover_core_error
|
||||
|
||||
XPerTile = tile_n
|
||||
YPerTile = tile_k
|
||||
|
||||
elif layout[1] == "c":
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_n
|
||||
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, b_datatype, tile_n, tile_k
|
||||
)
|
||||
|
||||
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
|
||||
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
|
||||
)
|
||||
if not wg_cover_core_valid:
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}"
|
||||
)
|
||||
return False, wg_cover_core_error
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def wg_cover_core_validation(
|
||||
XPerTile: int,
|
||||
YPerTile: int,
|
||||
BlockSize: int,
|
||||
vector_load_size: int,
|
||||
warp_size: int,
|
||||
) -> Tuple[bool, str]:
|
||||
if XPerTile % vector_load_size != 0:
|
||||
return False, "XPerTile is not divisible by vector_load_size"
|
||||
|
||||
num_warps = BlockSize / warp_size
|
||||
LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
|
||||
|
||||
X1 = LargestVec if vector_load_size > LargestVec else vector_load_size
|
||||
X0 = XPerTile / X1
|
||||
Y1 = warp_size // X0
|
||||
|
||||
if X0 * Y1 != warp_size:
|
||||
return False, "X0 * Y1 != warp_size"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def get_global_vector_load_size(
|
||||
BlockSize: int,
|
||||
KPerBlock: int,
|
||||
DataType: str,
|
||||
MNPerBlock: int,
|
||||
XPerTile: int,
|
||||
) -> int:
|
||||
elements_per_thread = MNPerBlock * KPerBlock / BlockSize
|
||||
PackedSize = 1
|
||||
|
||||
if (
|
||||
PackedSize == 2
|
||||
and XPerTile % (PackedSize * 32 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0
|
||||
):
|
||||
return PackedSize * 32 / element_size(DataType)
|
||||
elif (
|
||||
XPerTile % (PackedSize * 16 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 16 / element_size(DataType)) == 0
|
||||
):
|
||||
return int(PackedSize * 16 / element_size(DataType))
|
||||
|
||||
elif (
|
||||
XPerTile % (PackedSize * 8 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 8 / element_size(DataType)) == 0
|
||||
):
|
||||
return int(PackedSize * 8 / element_size(DataType))
|
||||
elif (
|
||||
element_size(DataType) >= PackedSize * 4
|
||||
and XPerTile % (PackedSize * 4 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 4 / element_size(DataType)) == 0
|
||||
):
|
||||
return int(PackedSize * 4 / element_size(DataType))
|
||||
elif (
|
||||
element_size(DataType) >= PackedSize * 2
|
||||
and XPerTile % (PackedSize * 2 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 2 / element_size(DataType)) == 0
|
||||
):
|
||||
return int(PackedSize * 2 / element_size(DataType))
|
||||
else:
|
||||
return PackedSize
|
||||
Reference in New Issue
Block a user