mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
135 lines
4.0 KiB
Python
135 lines
4.0 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Validation utilities for pooling tile_engine configurations.
|
|
Validates tile configurations and trait combinations for pooling kernels.
|
|
"""
|
|
|
|
import logging
|
|
|
|
# Wavefront width assumed by the tile checks in this module (AMD wave64).
WARP_SIZE = 64

# Configure the root logger when this module is imported. The validators in
# this module report rejection reasons via logging.debug(), which is below the
# INFO level set here, so those messages are not shown unless a caller lowers
# the logging level.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
def is_tile_config_valid(
    block_m,
    block_n,
    warp_m,
    warp_n,
    warp_tile_m,
    warp_tile_n,
    thread_tile_m,
    thread_tile_n,
    in_datatype,
    out_datatype,
    *,
    warp_size=None,
    max_block_size=1024,
):
    """
    Validate a pooling tile configuration.

    For pooling, the 2D tile is:
        M = output elements (N*Ho*Wo*C for 2D, N*Do*Ho*Wo*C for 3D)
        N = reduction dimension (window elements: Y*X for 2D, Z*Y*X for 3D)

    BlockShape params:
        BlockWarps = (warp_m, warp_n)
        BlockTile  = (block_m, block_n)
        WarpTile   = (warp_tile_m, warp_tile_n)
        ThreadTile = (thread_tile_m, thread_tile_n)

    Args:
        block_m, block_n: BlockTile extents along M and N.
        warp_m, warp_n: Number of warps per block along M and N.
        warp_tile_m, warp_tile_n: WarpTile extents along M and N.
        thread_tile_m, thread_tile_n: ThreadTile extents along M and N.
        in_datatype, out_datatype: Accepted for interface compatibility; the
            current checks do not inspect them.
        warp_size: Wavefront width used by the checks. ``None`` (the default)
            uses the module-level ``WARP_SIZE``. Override it for targets with a
            different wavefront width (e.g. 32).
        max_block_size: Largest allowed thread count per block, computed as
            ``warp_size * warp_m * warp_n`` (default 1024).

    Returns:
        True if every check passes, otherwise False. The reason for the first
        failing check is logged at DEBUG level.

    Raises:
        ValueError: If ``warp_size`` is not positive (it is used as a divisor).
    """
    if warp_size is None:
        warp_size = WARP_SIZE
    if warp_size <= 0:
        raise ValueError(f"warp_size must be positive, got {warp_size}")

    # Basic positivity checks (also guarantees the divisors below are non-zero).
    tile_params = (
        block_m,
        block_n,
        warp_m,
        warp_n,
        warp_tile_m,
        warp_tile_n,
        thread_tile_m,
        thread_tile_n,
    )
    if any(v <= 0 for v in tile_params):
        logging.debug("All tile parameters must be positive")
        return False

    # WarpTile must be divisible by ThreadTile
    if warp_tile_m % thread_tile_m != 0:
        logging.debug(
            f"warp_tile_m ({warp_tile_m}) must be divisible by thread_tile_m ({thread_tile_m})"
        )
        return False
    if warp_tile_n % thread_tile_n != 0:
        logging.debug(
            f"warp_tile_n ({warp_tile_n}) must be divisible by thread_tile_n ({thread_tile_n})"
        )
        return False

    # WarpTile / ThreadTile product must be a multiple of the warp size
    threads_per_warp = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
    if threads_per_warp % warp_size != 0:
        logging.debug(
            f"warp_tile product / thread_tile product ({threads_per_warp}) "
            f"must be multiple of WARP_SIZE ({warp_size})"
        )
        return False

    # WarpSizeScaleFactor: how many warp-sized groups one WarpTile spans. It is
    # applied to whichever dimension has more threads per warp (N on a tie).
    warp_size_scale_factor = threads_per_warp // warp_size
    if warp_tile_m // thread_tile_m > warp_tile_n // thread_tile_n:
        warp_size_scale_factor_m = warp_size_scale_factor
        warp_size_scale_factor_n = 1
    else:
        warp_size_scale_factor_m = 1
        warp_size_scale_factor_n = warp_size_scale_factor

    # Block dimensions must be properly divisible
    if (block_m * warp_size_scale_factor_m) % (warp_m * warp_tile_m) != 0:
        logging.debug(
            f"block_m*scale ({block_m * warp_size_scale_factor_m}) "
            f"must be divisible by warp_m*warp_tile_m ({warp_m * warp_tile_m})"
        )
        return False
    if (block_n * warp_size_scale_factor_n) % (warp_n * warp_tile_n) != 0:
        logging.debug(
            f"block_n*scale ({block_n * warp_size_scale_factor_n}) "
            f"must be divisible by warp_n*warp_tile_n ({warp_n * warp_tile_n})"
        )
        return False

    # BlockSize = warp_size * warp_m * warp_n; must not exceed the limit
    block_size = warp_size * warp_m * warp_n
    if block_size > max_block_size:
        logging.debug(f"BlockSize ({block_size}) exceeds maximum of {max_block_size}")
        return False

    return True
|
|
|
|
|
|
def is_trait_combination_valid(reduce_op, output_index, propagate_nan, pooling_dim):
    """
    Check whether a set of pooling traits forms a supported combination.

    Parameters:
        reduce_op: "max" or "avg"
        output_index: bool - whether to output indices
        propagate_nan: bool - whether to propagate NaN (accepted but not
            checked by this function)
        pooling_dim: "2d" or "3d"

    Returns:
        True when the combination is accepted, otherwise False. The reason for
        a rejection is logged at DEBUG level.
    """
    # Emitting indices is only allowed together with max pooling.
    index_request_ok = not output_index or reduce_op == "max"
    if not index_request_ok:
        logging.debug("output_index is only supported for max pooling")
        return False

    # Only 2D and 3D pooling are recognized.
    dim_supported = pooling_dim in ("2d", "3d")
    if dim_supported:
        return True

    logging.debug(f"Invalid pooling dimension: {pooling_dim}")
    return False
|