composable_kernel/tile_engine/ops/pooling/pooling_validation_utils.py
aledudek 119712bd90 [rocm-libraries] ROCm/rocm-libraries#4469 (commit 0844cb0)
[CK_TILE] Add pooling in tile_engine

## Motivation

Add pooling support to the CK tile engine.


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Validation utilities for pooling tile_engine configurations.
Validates tile configurations, trait combinations, and datatype support for
pooling kernels. Modelled after gemm_validation_utils.py — each constraint
from the CK PoolShape / PoolKernel static_asserts is mirrored here so that
invalid configs are rejected at code-generation time rather than at compile
or runtime.
"""
import logging
from typing import Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Hardware constants
# ---------------------------------------------------------------------------
# Default warp size (wave64 for CDNA architectures)
WARP_SIZE = 64
MAX_BLOCK_SIZE = 1024 # Maximum threads per workgroup on AMD GPUs
MAX_LDS_BYTES = 65536 # 64 KB LDS per workgroup
def get_warp_size_for_gpu(gpu_target: str) -> int:
"""Get the warp size for a given GPU target.
CDNA architectures (gfx9xx) use WAVE64 (64 threads per wavefront).
RDNA architectures (gfx10xx, gfx11xx, gfx12xx) use WAVE32 (32 threads per wavefront).
"""
if gpu_target.startswith("gfx9"):
return 64 # CDNA - WAVE64
return 32 # RDNA and others - WAVE32
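# Illustrative doctest-style examples (gfx90a is CDNA2 and runs wave64;
# gfx1100 is RDNA3 and runs wave32):
#     >>> get_warp_size_for_gpu("gfx90a")
#     64
#     >>> get_warp_size_for_gpu("gfx1100")
#     32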
# ---------------------------------------------------------------------------
# Datatype helpers
# ---------------------------------------------------------------------------
ELEMENT_SIZE_MAP = {
"fp8": 1,
"bf8": 1,
"int8": 1,
"fp16": 2,
"bf16": 2,
"int4": 0.5,
"int32": 4,
"fp32": 4,
"fp64": 8,
}
DTYPE_STRING_MAP = {
"fp8": "ck_tile::fp8_t",
"bf8": "ck_tile::bf8_t",
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
"fp32": "float",
"fp64": "double",
}
SUPPORTED_DATATYPES = list(DTYPE_STRING_MAP.keys())
# ---------------------------------------------------------------------------
# Reduce-op helpers
# ---------------------------------------------------------------------------
REDUCE_OP_STRING_MAP = {
"max": "ck_tile::ReduceOp::Max",
"min": "ck_tile::ReduceOp::Min",
"avg": "ck_tile::ReduceOp::Add",
}
SUPPORTED_REDUCE_OPS = list(REDUCE_OP_STRING_MAP.keys())
SUPPORTED_POOLING_DIMS = ("2d", "3d")
# ---------------------------------------------------------------------------
# Public helper functions (used by the instance builder)
# ---------------------------------------------------------------------------
def element_size(datatype: str) -> float:
"""Return the byte-width of a single element for *datatype*."""
datatype = datatype.lower()
if datatype not in ELEMENT_SIZE_MAP:
raise ValueError(
f"Unsupported data type: '{datatype}'. "
f"Supported: {list(ELEMENT_SIZE_MAP.keys())}"
)
return ELEMENT_SIZE_MAP[datatype]
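# Illustrative examples:
#     >>> element_size("fp16")
#     2
#     >>> element_size("int4")  # sub-byte type
#     0.5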
def get_dtype_string(datatype: str) -> str:
"""Return the C++ type string (e.g. ``ck_tile::fp16_t``) for *datatype*."""
return DTYPE_STRING_MAP.get(datatype, "float")
def get_reduce_op_string(reduce_op: str) -> str:
"""Return the C++ ReduceOp enumerator string for *reduce_op*."""
return REDUCE_OP_STRING_MAP.get(reduce_op, "ck_tile::ReduceOp::Max")
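# Illustrative examples. Note that both lookup helpers fall back to a default
# ("float" / Max) rather than raising for unknown inputs:
#     >>> get_dtype_string("bf16")
#     'ck_tile::bf16_t'
#     >>> get_reduce_op_string("avg")
#     'ck_tile::ReduceOp::Add'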
# ---------------------------------------------------------------------------
# Individual tile-config validators
# ---------------------------------------------------------------------------
def validate_positivity(
block_m: int,
block_n: int,
warp_m: int,
warp_n: int,
warp_tile_m: int,
warp_tile_n: int,
thread_tile_m: int,
thread_tile_n: int,
) -> Tuple[bool, str]:
"""All tile parameters must be positive integers."""
params = {
"block_m": block_m,
"block_n": block_n,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"thread_tile_m": thread_tile_m,
"thread_tile_n": thread_tile_n,
}
for name, val in params.items():
if val <= 0:
return False, f"{name} ({val}) must be > 0"
return True, ""
def validate_power_of_two(
block_m: int,
block_n: int,
warp_m: int,
warp_n: int,
warp_tile_m: int,
warp_tile_n: int,
thread_tile_m: int,
thread_tile_n: int,
) -> Tuple[bool, str]:
"""All tile parameters should be powers of two for correct GPU addressing."""
params = {
"block_m": block_m,
"block_n": block_n,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"thread_tile_m": thread_tile_m,
"thread_tile_n": thread_tile_n,
}
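    # Bit trick: a positive integer v is a power of two iff v & (v - 1) == 0,
    # because subtracting one flips the lowest set bit and all bits below it.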
for name, val in params.items():
if val > 0 and (val & (val - 1)) != 0:
return False, f"{name} ({val}) is not a power of two"
return True, ""
def validate_thread_tile_alignment(
warp_tile_m: int,
warp_tile_n: int,
thread_tile_m: int,
thread_tile_n: int,
) -> Tuple[bool, str]:
"""
Mirrors pool_shape.hpp:
static_assert(Warp_M % ThreadTile_M == 0);
static_assert(Warp_N % ThreadTile_N == 0);
"""
if warp_tile_m % thread_tile_m != 0:
return (
False,
f"warp_tile_m ({warp_tile_m}) must be divisible by "
f"thread_tile_m ({thread_tile_m})",
)
if warp_tile_n % thread_tile_n != 0:
return (
False,
f"warp_tile_n ({warp_tile_n}) must be divisible by "
f"thread_tile_n ({thread_tile_n})",
)
return True, ""
def validate_warp_thread_distribution(
warp_tile_m: int,
warp_tile_n: int,
thread_tile_m: int,
thread_tile_n: int,
warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
"""
Mirrors pool_shape.hpp:
static_assert((Warp_M * Warp_N / ThreadTile_M / ThreadTile_N)
% get_warp_size() == 0);
"""
threads_per_warp = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
if threads_per_warp % warp_size != 0:
return (
False,
f"(warp_tile_m * warp_tile_n) / (thread_tile_m * thread_tile_n) = "
f"{threads_per_warp} is not a multiple of warp_size ({warp_size})",
)
return True, ""
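# Worked example (illustrative): an 8x8 warp tile with 1x1 thread tiles gives
# (8 * 8) / (1 * 1) = 64 threads, a multiple of wave64, so the check passes on
# CDNA; an 8x4 warp tile gives 32 threads, which fails on wave64 but passes
# when warp_size is 32 (RDNA).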
def _compute_warp_size_scale_factors(
warp_tile_m: int,
warp_tile_n: int,
thread_tile_m: int,
thread_tile_n: int,
warp_size: int = WARP_SIZE,
) -> Tuple[int, int]:
"""
Reproduce the WarpSizeScaleFactor_M / _N logic from pool_shape.hpp.
"""
threads_per_warp = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
scale = threads_per_warp // warp_size
if warp_tile_m // thread_tile_m > warp_tile_n // thread_tile_n:
return scale, 1
return 1, scale
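# Worked example (illustrative): a 16x8 warp tile with 1x1 thread tiles on
# wave64 yields (16 * 8) / (1 * 1) = 128 threads per warp, so scale = 2;
# since 16 // 1 > 8 // 1, the factor is applied to M: (sf_m, sf_n) = (2, 1).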
def validate_block_tile_coverage(
block_m: int,
block_n: int,
warp_m: int,
warp_n: int,
warp_tile_m: int,
warp_tile_n: int,
thread_tile_m: int,
thread_tile_n: int,
warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
"""
Mirrors pool_shape.hpp:
static_assert((Block_M * WarpSizeScaleFactor_M) %
(WarpPerBlock_M * Warp_M) == 0);
static_assert((Block_N * WarpSizeScaleFactor_N) %
(WarpPerBlock_N * Warp_N) == 0);
"""
sf_m, sf_n = _compute_warp_size_scale_factors(
warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
)
if (block_m * sf_m) % (warp_m * warp_tile_m) != 0:
return (
False,
f"block_m*ScaleFactor_M ({block_m}*{sf_m}={block_m * sf_m}) must be "
f"divisible by warp_m*warp_tile_m ({warp_m}*{warp_tile_m}"
f"={warp_m * warp_tile_m})",
)
if (block_n * sf_n) % (warp_n * warp_tile_n) != 0:
return (
False,
f"block_n*ScaleFactor_N ({block_n}*{sf_n}={block_n * sf_n}) must be "
f"divisible by warp_n*warp_tile_n ({warp_n}*{warp_tile_n}"
f"={warp_n * warp_tile_n})",
)
return True, ""
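# Continuing the (sf_m, sf_n) = (2, 1) example above: with block_m = 128,
# warp_m = 2, warp_tile_m = 16 the M check is (128 * 2) % (2 * 16) = 0, and
# with block_n = 64, warp_n = 1, warp_tile_n = 8 the N check is 64 % 8 = 0,
# so the coverage constraint holds.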
def validate_block_size(
warp_m: int,
warp_n: int,
warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
"""BlockSize = warp_size * warp_m * warp_n must be <= MAX_BLOCK_SIZE."""
block_size = warp_size * warp_m * warp_n
if block_size > MAX_BLOCK_SIZE:
return (
False,
f"BlockSize ({block_size} = {warp_size}*{warp_m}*{warp_n}) "
f"exceeds maximum ({MAX_BLOCK_SIZE})",
)
return True, ""
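# For instance, warp_m = warp_n = 4 on wave64 gives 64 * 4 * 4 = 1024 threads,
# exactly the MAX_BLOCK_SIZE cap; warp_m = warp_n = 8 would exceed it.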
def validate_vector_load_alignment(
block_m: int,
thread_tile_m: int,
in_datatype: str,
) -> Tuple[bool, str]:
"""
The M-dimension thread-tile determines the contiguous vector load width.
It must produce a load whose byte-width divides 16 bytes (max global
vector load width on AMD GPUs) and is at least 1 element wide.
"""
elem_bytes = element_size(in_datatype)
load_bytes = thread_tile_m * elem_bytes
if load_bytes > 16:
return (
False,
f"thread_tile_m ({thread_tile_m}) * element_size({in_datatype}, "
f"{elem_bytes}B) = {load_bytes}B exceeds 16B max vector load",
)
    # load_bytes <= 16 is guaranteed above, so it suffices to check that the
    # load width evenly divides the 16-byte maximum.
    if 16 % load_bytes != 0:
return (
False,
f"Vector load width ({load_bytes}B) is not a divisor of 16B",
)
return True, ""
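# Worked example (illustrative): fp16 elements are 2 bytes, so thread_tile_m=8
# loads 16 B (OK, the dwordx4 maximum), thread_tile_m=4 loads 8 B (OK, divides
# 16 B), and thread_tile_m=16 would load 32 B and be rejected.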
def validate_repeat_factors(
block_m: int,
block_n: int,
warp_m: int,
warp_n: int,
warp_tile_m: int,
warp_tile_n: int,
thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
"""
Repeat_M and Repeat_N from pool_shape.hpp must be >= 1. They are the
number of tile iterations each warp performs within the block.
"""
    sf_m, sf_n = _compute_warp_size_scale_factors(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )
repeat_m = (block_m * sf_m) // (warp_m * warp_tile_m)
repeat_n = (block_n * sf_n) // (warp_n * warp_tile_n)
if repeat_m < 1:
return False, f"Repeat_M ({repeat_m}) must be >= 1"
if repeat_n < 1:
return False, f"Repeat_N ({repeat_n}) must be >= 1"
return True, ""
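# Worked example (illustrative): with (sf_m, sf_n) = (2, 1) from the 16x8 /
# 1x1 case above, block_m = 128, warp_m = 2 gives Repeat_M = 256 // 32 = 8,
# and block_n = 64, warp_n = 1 gives Repeat_N = 64 // 8 = 8; both >= 1.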
# ---------------------------------------------------------------------------
# Comprehensive tile-config validation (entry point)
# ---------------------------------------------------------------------------
def is_tile_config_valid(
block_m: int,
block_n: int,
warp_m: int,
warp_n: int,
warp_tile_m: int,
warp_tile_n: int,
thread_tile_m: int,
thread_tile_n: int,
in_datatype: str,
out_datatype: str,
fast_mode: bool = False,
gpu_target: str = "gfx90a",
) -> bool:
"""
Comprehensive pooling tile configuration validation.
When *fast_mode* is True only cheap sanity checks are performed (useful
for the ``--list_kernels`` path). Full mode mirrors every
``static_assert`` in ``pool_shape.hpp``.
Parameters
----------
    block_m, block_n : Block tile dimensions (M = output elements, N = pooling window).
warp_m, warp_n : Warps per block along each dimension.
warp_tile_m, warp_tile_n : Tile processed per warp.
thread_tile_m, thread_tile_n : Contiguous elements per thread.
in_datatype : Input element type (e.g. ``"fp16"``).
out_datatype : Output element type.
    fast_mode : Skip expensive checks when True.
    gpu_target : GPU architecture string (e.g. ``"gfx90a"``); selects the warp size.
"""
all_params = (
block_m, block_n, warp_m, warp_n,
warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n,
)
# --- Positivity (always) ---
ok, err = validate_positivity(*all_params)
if not ok:
logger.debug(f"Positivity check failed: {err}")
return False
# --- Thread-tile alignment (always) ---
ok, err = validate_thread_tile_alignment(
warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n
)
if not ok:
logger.debug(f"Thread tile alignment failed: {err}")
return False
if fast_mode:
return True
# Get the warp size for this GPU target
warp_size = get_warp_size_for_gpu(gpu_target)
# --- Power-of-two ---
ok, err = validate_power_of_two(*all_params)
if not ok:
logger.debug(f"Power-of-two check failed: {err}")
return False
# --- Warp-thread distribution ---
ok, err = validate_warp_thread_distribution(
warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
)
if not ok:
logger.debug(f"Warp thread distribution failed: {err}")
return False
# --- Block-tile coverage ---
ok, err = validate_block_tile_coverage(*all_params, warp_size=warp_size)
if not ok:
logger.debug(f"Block tile coverage failed: {err}")
return False
# --- Block size ---
ok, err = validate_block_size(warp_m, warp_n, warp_size)
if not ok:
logger.debug(f"Block size check failed: {err}")
return False
# --- Repeat factors ---
    ok, err = validate_repeat_factors(*all_params, warp_size=warp_size)
if not ok:
logger.debug(f"Repeat factor check failed: {err}")
return False
# --- Vector load alignment ---
ok, err = validate_vector_load_alignment(block_m, thread_tile_m, in_datatype)
if not ok:
logger.debug(f"Vector load alignment failed: {err}")
return False
return True
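# Minimal usage sketch (hypothetical values): a 64x64 block tile with one
# wave64 warp, 8x8 warp tiles, and 1x1 thread tiles passes every check for
# fp32 on gfx90a, while shrinking the warp tile to 4x4 (only 16 threads)
# fails the warp-thread distribution check:
#     >>> is_tile_config_valid(64, 64, 1, 1, 8, 8, 1, 1, "fp32", "fp32")
#     True
#     >>> is_tile_config_valid(64, 64, 1, 1, 4, 4, 1, 1, "fp32", "fp32")
#     False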
# ---------------------------------------------------------------------------
# Trait-combination validation
# ---------------------------------------------------------------------------
def is_trait_combination_valid(
reduce_op: str,
output_index: bool,
propagate_nan: bool,
pooling_dim: str,
) -> bool:
"""
Validate a pooling trait combination.
Parameters
----------
reduce_op : ``"max"``, ``"min"``, or ``"avg"``.
output_index : Whether to output indices of the selected elements.
    propagate_nan : Whether to propagate NaN values through the reduction.
pooling_dim : ``"2d"`` or ``"3d"``.
"""
if reduce_op not in SUPPORTED_REDUCE_OPS:
logger.debug(f"Unsupported reduce_op: '{reduce_op}'")
return False
if pooling_dim not in SUPPORTED_POOLING_DIMS:
logger.debug(f"Invalid pooling dimension: '{pooling_dim}'")
return False
# output_index only makes sense for max pooling (CK constraint)
if output_index and reduce_op != "max":
logger.debug(
f"output_index=True is only supported for 'max' pooling, "
f"not '{reduce_op}'"
)
return False
return True
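# Illustrative examples:
#     >>> is_trait_combination_valid("max", True, False, "2d")
#     True
#     >>> is_trait_combination_valid("avg", True, False, "2d")  # indices need max
#     False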
# ---------------------------------------------------------------------------
# Datatype validation
# ---------------------------------------------------------------------------
def is_datatype_supported(datatype: str) -> bool:
"""Return True if *datatype* is a known pooling datatype."""
return datatype.lower() in ELEMENT_SIZE_MAP
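
# ---------------------------------------------------------------------------
# Smoke test: a minimal sketch (not part of the generator flow) showing how
# the validators compose when the module is run directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    assert is_tile_config_valid(64, 64, 1, 1, 8, 8, 1, 1, "fp32", "fp32")
    assert not is_trait_combination_valid("avg", True, False, "2d")
    assert is_datatype_supported("bf16")
    print("pooling_validation_utils sanity checks passed")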