Files
composable_kernel/script/splitk.py
2025-06-27 12:52:24 +00:00

159 lines
6.3 KiB
Python

#!/usr/bin/env python3
import math
import argparse
import sys
import numpy as np
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1.

    Problem and tile sizes, CU counts and occupancy are all used as counts or
    divisors downstream, so zero or negative values are rejected up front
    with a normal argparse usage error.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_cli_args():
    """Parse and validate command line arguments.

    Returns:
        argparse.Namespace with ``gemm_m``, ``gemm_n``, ``gemm_k``,
        ``max_occupancy``, ``blk_m``, ``blk_n``, ``blk_k`` and ``ncus``,
        each a positive int.

    Exits with status 1 if unrecognised arguments are present, and with
    argparse's usage-error status for missing or invalid values.
    """
    parser = argparse.ArgumentParser(description="Analyze best Split-K values.")
    parser.add_argument("--gemm-m", type=_positive_int, dest="gemm_m", required=True,
                        help="GEMM problem M-dimension.")
    parser.add_argument("--gemm-n", type=_positive_int, dest="gemm_n", required=True,
                        help="GEMM problem N-dimension.")
    parser.add_argument("--gemm-k", type=_positive_int, dest="gemm_k", required=True,
                        help="GEMM problem K-dimension.")
    parser.add_argument("--max-occupancy", type=_positive_int, required=True,
                        help="Maximum number of simultaneously active workgroups on a single CU.")
    parser.add_argument("--blk-m", type=_positive_int, default=64,
                        help="Block size for M dimension (default: 64).")
    parser.add_argument("--blk-n", type=_positive_int, default=64,
                        help="Block size for N dimension (default: 64).")
    parser.add_argument("--blk-k", type=_positive_int, default=32,
                        help="Block size for K dimension (default: 32).")
    parser.add_argument("--ncus", type=_positive_int, default=304,
                        help="Number of compute units (default: 304).")
    # parse_known_args + explicit check keeps the historical exit status (1)
    # for unrecognised options instead of argparse's own usage error.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print(f"Unknown arguments: {unknown_args}", file=sys.stderr)
        sys.exit(1)
    return args
def find_optimal_split_k(grid_k, grid_size, nproc, max_split_k=None, print_analysis=False):
    """
    Rank candidate split-K values for a tiled GEMM with a weighted heuristic.

    Every ``split_k`` in ``1..max_split_k`` is scored on four components, each
    in [0, 1] with higher meaning better, combined with weights summing to 1:

    * load balance (0.5): fraction of the ``nproc`` slots occupied by the last
      wave when ``split_k * grid_size`` blocks are dispatched in waves of
      ``nproc`` (1.0 when the blocks fill whole waves exactly).
    * cache locality (0.3): ``1 / k_tiles_per_split``.
    * tail (0.1): ``1 - tail_loop_cost``, where ``tail_loop_cost`` is 0 when
      ``split_k`` divides ``grid_k`` and ``1 / (grid_k % split_k)`` otherwise.
    * sync overhead (0.1): ``1 / (1 + split_k)``.

    Parameters:
        grid_k (int): Number of tiles in the K dimension (K/BLK_K).
        grid_size (int): Total number of output tiles (M/BLK_M * N/BLK_N).
        nproc (int): Number of concurrent workgroup slots; must be positive.
        max_split_k (int, optional): Largest split-K to consider. Defaults to grid_k.
        print_analysis (bool): Whether to print a table of the top 10 candidates.

    Returns:
        dict: ``optimal_split_k`` (1 if there are no candidates), ``all_results``
        (one dict per candidate, sorted by descending ``combined_score``; ties
        keep ascending split_k order) and ``top_candidates`` (the first five).

    Raises:
        ValueError: if ``nproc`` is not positive.
    """
    if nproc <= 0:
        raise ValueError(f"nproc must be positive, got {nproc}")
    if max_split_k is None:
        max_split_k = grid_k

    results = []
    for split_k in range(1, max_split_k + 1):
        k_tiles_per_split = (grid_k + split_k - 1) // split_k
        k_remainder = grid_k % split_k
        # Tail loop cost - split_k may not divide grid_k evenly, leaving a tail.
        # This is a penalty in [0, 1]; zero for a perfect division.
        tail_loop_cost = 0 if k_remainder == 0 else 1.0 / k_remainder
        # Load balance: how full is the final wave of workgroups? A last wave
        # holding only a few blocks leaves most of the processors idle.
        total_blocks = split_k * grid_size
        last_wave_blocks = total_blocks % nproc
        load_balance_score = 1.0 if last_wave_blocks == 0 else last_wave_blocks / nproc
        # Smaller k_tiles_per_split is treated as better cache reuse within each split.
        cache_locality_score = 1.0 / k_tiles_per_split if k_tiles_per_split > 0 else 0
        # Synchronization overhead increases with more splits.
        sync_overhead_score = 1.0 / (1.0 + split_k)
        # Every term is "higher is better"; the tail penalty therefore enters
        # as (1 - cost) rather than as a bonus.
        combined_score = (
            0.5 * load_balance_score
            + 0.3 * cache_locality_score
            + 0.1 * (1.0 - tail_loop_cost)
            + 0.1 * sync_overhead_score
        )
        results.append({
            'split_k': split_k,
            'k_tiles_per_split': k_tiles_per_split,
            'load_balance_score': load_balance_score,
            'cache_locality_score': cache_locality_score,
            'sync_overhead_score': sync_overhead_score,
            'tail_loop_cost': tail_loop_cost,
            'combined_score': combined_score,
            'is_perfect_division': k_remainder == 0,
        })

    # Best first; the sort is stable, so equal scores keep ascending split_k order.
    results.sort(key=lambda entry: entry['combined_score'], reverse=True)

    if print_analysis:
        print(f"\nSplit-K Analysis for grid_k={grid_k}, nproc={nproc}")
        print("=" * 80)
        print(f"{'Split-K':<8} {'Balance':<8} {'Cache':<8} {'Sync':<8} {'Tail':<8} {'Score':<8}")
        print("-" * 80)
        for result in results[:10]:  # Show top 10
            print(
                f"{result['split_k']:<8} "
                f"{result['load_balance_score']:<8.3f} "
                f"{result['cache_locality_score']:<8.3f} "
                f"{result['sync_overhead_score']:<8.3f} "
                f"{result['tail_loop_cost']:<8.3f} "
                f"{result['combined_score']:<8.3f}"
            )

    return {
        'optimal_split_k': results[0]['split_k'] if results else 1,
        'all_results': results,
        'top_candidates': results[:5],
    }
# Convenience wrapper: derive the tile grid from the GEMM problem size and rank split-K values.
def find_best_split_k_comprehensive(
    M, N, K, BLK_M, BLK_N, BLK_K,
    nproc,
    max_split_k=None, print_analysis=True
):
    """
    Compute the tile grid of an M x N x K GEMM and rank split-K candidates.

    Ranking is delegated to ``find_optimal_split_k`` (a weighted heuristic);
    no cache simulation or empirical measurement is performed here.

    Parameters:
        M, N, K (int): GEMM problem dimensions.
        BLK_M, BLK_N, BLK_K (int): Tile sizes; must be positive.
        nproc (int): Number of concurrent workgroup slots (e.g. CUs x occupancy).
        max_split_k (int, optional): Largest split-K considered; defaults to
            the number of K tiles.
        print_analysis (bool): Forwarded to ``find_optimal_split_k``.

    Returns:
        dict: ``theoretical_optimal`` and ``recommendation`` (both the top-ranked
        split-K), ``theoretical_analysis`` (the full ranking result) and
        ``grid_dimensions`` as ``(grid_m, grid_n, grid_k)`` tile counts.

    Raises:
        ValueError: if any tile size is not positive.
    """
    for label, blk in (("BLK_M", BLK_M), ("BLK_N", BLK_N), ("BLK_K", BLK_K)):
        if blk <= 0:
            raise ValueError(f"{label} must be positive, got {blk}")

    # -(-a // b) is ceil(a / b) in exact integer arithmetic (no float rounding
    # for dimensions beyond 2**53).
    grid_m = -(-M // BLK_M)
    grid_n = -(-N // BLK_N)
    grid_k = -(-K // BLK_K)
    if max_split_k is None:
        max_split_k = grid_k
    grid_size = grid_m * grid_n

    theoretical_results = find_optimal_split_k(
        grid_k=grid_k,
        grid_size=grid_size,
        nproc=nproc,
        max_split_k=max_split_k,
        print_analysis=print_analysis,
    )
    return {
        'theoretical_optimal': theoretical_results['optimal_split_k'],
        'theoretical_analysis': theoretical_results,
        'recommendation': theoretical_results['optimal_split_k'],
        'grid_dimensions': (grid_m, grid_n, grid_k),
    }
def main():
    """Command-line entry point.

    Reads the GEMM shape and hardware options, sizes the processor pool as
    ``ncus * max_occupancy``, runs the split-K ranking with its table
    printed, then prints a short summary of the recommendation.
    """
    cli = parse_cli_args()
    # Concurrent workgroup slots: compute units times workgroups per CU.
    total_slots = cli.ncus * cli.max_occupancy
    summary = find_best_split_k_comprehensive(
        M=cli.gemm_m,
        N=cli.gemm_n,
        K=cli.gemm_k,
        BLK_M=cli.blk_m,
        BLK_N=cli.blk_n,
        BLK_K=cli.blk_k,
        nproc=total_slots,
        print_analysis=True,
    )
    report = (
        "\nComprehensive Split-K Analysis Results:",
        "=" * 60,
        f"Theoretical Optimal Split-K value: {summary['theoretical_optimal']}",
        f"Recommended Split-K value: {summary['recommendation']}",
        f"Grid Dimensions (M, N, K): {summary['grid_dimensions']}",
    )
    for line in report:
        print(line)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()