mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[CK_TILE] Add pooling in tile_engine ## Motivation <!-- Explain the purpose of this PR and the goals it aims to achieve. --> Add pooling in ck tile engine ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
488 lines
15 KiB
Python
488 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Validation utilities for pooling tile_engine configurations.
|
|
|
|
Validates tile configurations, trait combinations, and datatype support for
|
|
pooling kernels. Modelled after gemm_validation_utils.py — each constraint
|
|
from the CK PoolShape / PoolKernel static_asserts is mirrored here so that
|
|
invalid configs are rejected at code-generation time rather than at compile
|
|
or runtime.
|
|
"""
|
|
|
|
import logging
|
|
from typing import List, Tuple
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
# Hardware constants
# ---------------------------------------------------------------------------

# Default warp size (wave64 for CDNA architectures); RDNA targets override
# this via get_warp_size_for_gpu().
WARP_SIZE = 64
MAX_BLOCK_SIZE = 1024 # Maximum threads per workgroup on AMD GPUs
MAX_LDS_BYTES = 65536 # 64 KB LDS per workgroup
# NOTE(review): MAX_LDS_BYTES is not referenced by any validator in this
# module — presumably reserved for a future LDS-usage check; confirm.
|
|
|
|
def get_warp_size_for_gpu(gpu_target: str) -> int:
    """Return the wavefront width for *gpu_target*.

    CDNA architectures (gfx9xx) execute in WAVE64 mode (64 threads per
    wavefront), while RDNA architectures (gfx10xx, gfx11xx, gfx12xx) use
    WAVE32 (32 threads per wavefront).
    """
    is_cdna = gpu_target.startswith("gfx9")
    return 64 if is_cdna else 32
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Datatype helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Byte width of one element per logical datatype name.
# (int4 is sub-byte, hence the fractional size.)
ELEMENT_SIZE_MAP = {
    "fp8": 1,
    "bf8": 1,
    "int8": 1,
    "fp16": 2,
    "bf16": 2,
    "int4": 0.5,
    "int32": 4,
    "fp32": 4,
    "fp64": 8,
}

# Logical datatype name -> C++ type emitted into generated kernels.
DTYPE_STRING_MAP = {
    "fp8": "ck_tile::fp8_t",
    "bf8": "ck_tile::bf8_t",
    "fp16": "ck_tile::fp16_t",
    "bf16": "ck_tile::bf16_t",
    "fp32": "float",
    "fp64": "double",
}

SUPPORTED_DATATYPES = list(DTYPE_STRING_MAP)

# ---------------------------------------------------------------------------
# Reduce-op helpers
# ---------------------------------------------------------------------------

# Pooling mode -> C++ ReduceOp enumerator ("avg" accumulates with Add and
# divides later in the kernel).
REDUCE_OP_STRING_MAP = {
    "max": "ck_tile::ReduceOp::Max",
    "min": "ck_tile::ReduceOp::Min",
    "avg": "ck_tile::ReduceOp::Add",
}

SUPPORTED_REDUCE_OPS = list(REDUCE_OP_STRING_MAP)

SUPPORTED_POOLING_DIMS = ("2d", "3d")

# ---------------------------------------------------------------------------
# Public helper functions (used by the instance builder)
# ---------------------------------------------------------------------------


def element_size(datatype: str) -> float:
    """Return the byte-width of a single element for *datatype*.

    Raises
    ------
    ValueError
        If *datatype* (case-insensitive) is not in ELEMENT_SIZE_MAP.
    """
    key = datatype.lower()
    try:
        return ELEMENT_SIZE_MAP[key]
    except KeyError:
        raise ValueError(
            f"Unsupported data type: '{key}'. "
            f"Supported: {list(ELEMENT_SIZE_MAP.keys())}"
        ) from None


def get_dtype_string(datatype: str) -> str:
    """Return the C++ type string (e.g. ``ck_tile::fp16_t``) for *datatype*.

    Falls back to ``"float"`` for unknown names.
    """
    if datatype in DTYPE_STRING_MAP:
        return DTYPE_STRING_MAP[datatype]
    return "float"


def get_reduce_op_string(reduce_op: str) -> str:
    """Return the C++ ReduceOp enumerator string for *reduce_op*.

    Falls back to ``"ck_tile::ReduceOp::Max"`` for unknown names.
    """
    if reduce_op in REDUCE_OP_STRING_MAP:
        return REDUCE_OP_STRING_MAP[reduce_op]
    return "ck_tile::ReduceOp::Max"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Individual tile-config validators
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def validate_positivity(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """All tile parameters must be positive integers.

    Returns ``(True, "")`` on success, otherwise ``(False, reason)`` naming
    the first offending parameter.
    """
    named_values = (
        ("block_m", block_m),
        ("block_n", block_n),
        ("warp_m", warp_m),
        ("warp_n", warp_n),
        ("warp_tile_m", warp_tile_m),
        ("warp_tile_n", warp_tile_n),
        ("thread_tile_m", thread_tile_m),
        ("thread_tile_n", thread_tile_n),
    )
    for name, value in named_values:
        if value <= 0:
            return False, f"{name} ({value}) must be > 0"
    return True, ""
|
|
|
|
|
|
def validate_power_of_two(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """All tile parameters should be powers of two for correct GPU addressing.

    Non-positive values are deliberately ignored here — they are reported by
    ``validate_positivity`` instead.
    """

    def _pow2_or_nonpositive(v: int) -> bool:
        # Standard bit trick: a positive power of two has exactly one set bit.
        return v <= 0 or (v & (v - 1)) == 0

    named_values = (
        ("block_m", block_m),
        ("block_n", block_n),
        ("warp_m", warp_m),
        ("warp_n", warp_n),
        ("warp_tile_m", warp_tile_m),
        ("warp_tile_n", warp_tile_n),
        ("thread_tile_m", thread_tile_m),
        ("thread_tile_n", thread_tile_n),
    )
    for name, value in named_values:
        if not _pow2_or_nonpositive(value):
            return False, f"{name} ({value}) is not a power of two"
    return True, ""
|
|
|
|
|
|
def validate_thread_tile_alignment(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert(Warp_M % ThreadTile_M == 0);
        static_assert(Warp_N % ThreadTile_N == 0);
    """
    per_axis = (
        ("m", warp_tile_m, thread_tile_m),
        ("n", warp_tile_n, thread_tile_n),
    )
    for axis, warp_extent, thread_extent in per_axis:
        if warp_extent % thread_extent != 0:
            return (
                False,
                f"warp_tile_{axis} ({warp_extent}) must be divisible by "
                f"thread_tile_{axis} ({thread_extent})",
            )
    return True, ""
|
|
|
|
|
|
def validate_warp_thread_distribution(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert((Warp_M * Warp_N / ThreadTile_M / ThreadTile_N)
                      % get_warp_size() == 0);
    """
    warp_area = warp_tile_m * warp_tile_n
    thread_area = thread_tile_m * thread_tile_n
    # Number of threads needed to cover the warp tile at the given
    # per-thread tile size (floor division mirrors the C++ integer math).
    threads_per_warp = warp_area // thread_area
    if threads_per_warp % warp_size == 0:
        return True, ""
    return (
        False,
        f"(warp_tile_m * warp_tile_n) / (thread_tile_m * thread_tile_n) = "
        f"{threads_per_warp} is not a multiple of warp_size ({warp_size})",
    )
|
|
|
|
|
|
def _compute_warp_size_scale_factors(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[int, int]:
    """
    Reproduce the WarpSizeScaleFactor_M / _N logic from pool_shape.hpp.

    The surplus of threads-per-warp over the hardware warp size is applied
    as a scale factor along the axis that has more lanes per warp.
    """
    threads_per_warp = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
    surplus = threads_per_warp // warp_size

    m_is_wider = (warp_tile_m // thread_tile_m) > (warp_tile_n // thread_tile_n)
    return (surplus, 1) if m_is_wider else (1, surplus)
|
|
|
|
|
|
def validate_block_tile_coverage(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert((Block_M * WarpSizeScaleFactor_M) %
                      (WarpPerBlock_M * Warp_M) == 0);
        static_assert((Block_N * WarpSizeScaleFactor_N) %
                      (WarpPerBlock_N * Warp_N) == 0);
    """
    sf_m, sf_n = _compute_warp_size_scale_factors(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )

    per_axis = (
        ("m", block_m, sf_m, warp_m, warp_tile_m),
        ("n", block_n, sf_n, warp_n, warp_tile_n),
    )
    for axis, blk, sf, warps_per_block, warp_tile in per_axis:
        if (blk * sf) % (warps_per_block * warp_tile) != 0:
            return (
                False,
                f"block_{axis}*ScaleFactor_{axis.upper()} "
                f"({blk}*{sf}={blk * sf}) must be "
                f"divisible by warp_{axis}*warp_tile_{axis} "
                f"({warps_per_block}*{warp_tile}"
                f"={warps_per_block * warp_tile})",
            )
    return True, ""
|
|
|
|
|
|
def validate_block_size(
    warp_m: int,
    warp_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """BlockSize = warp_size * warp_m * warp_n must be <= MAX_BLOCK_SIZE."""
    total_threads = warp_size * warp_m * warp_n
    if total_threads <= MAX_BLOCK_SIZE:
        return True, ""
    return (
        False,
        f"BlockSize ({total_threads} = {warp_size}*{warp_m}*{warp_n}) "
        f"exceeds maximum ({MAX_BLOCK_SIZE})",
    )
|
|
|
|
|
|
def validate_vector_load_alignment(
    block_m: int,
    thread_tile_m: int,
    in_datatype: str,
) -> Tuple[bool, str]:
    """
    The M-dimension thread-tile determines the contiguous vector load width.
    It must produce a load whose byte-width divides 16 bytes (max global
    vector load width on AMD GPUs) and is at least 1 element wide.

    NOTE(review): ``block_m`` is accepted for signature symmetry with the
    other validators but is currently unused in this check — confirm.
    """
    elem_bytes = element_size(in_datatype)
    load_bytes = thread_tile_m * elem_bytes

    if load_bytes > 16:
        return (
            False,
            f"thread_tile_m ({thread_tile_m}) * element_size({in_datatype}, "
            f"{elem_bytes}B) = {load_bytes}B exceeds 16B max vector load",
        )
    # Accept widths that divide 16B (or are a multiple of 16B, which at this
    # point can only be exactly 16B).
    if 16 % load_bytes == 0 or load_bytes % 16 == 0:
        return True, ""
    return (
        False,
        f"Vector load width ({load_bytes}B) is not a divisor of 16B",
    )
|
|
|
|
|
|
def validate_repeat_factors(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Repeat_M and Repeat_N from pool_shape.hpp must be >= 1. They are the
    number of tile iterations each warp performs within the block.

    ``warp_size`` defaults to WARP_SIZE (wave64) for backward compatibility;
    callers validating for a specific GPU should pass its actual warp size so
    the scale factors here agree with ``validate_block_tile_coverage``.
    """
    sf_m, sf_n = _compute_warp_size_scale_factors(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )
    repeat_m = (block_m * sf_m) // (warp_m * warp_tile_m)
    repeat_n = (block_n * sf_n) // (warp_n * warp_tile_n)
    if repeat_m < 1:
        return False, f"Repeat_M ({repeat_m}) must be >= 1"
    if repeat_n < 1:
        return False, f"Repeat_N ({repeat_n}) must be >= 1"
    return True, ""


# ---------------------------------------------------------------------------
# Comprehensive tile-config validation (entry point)
# ---------------------------------------------------------------------------


def is_tile_config_valid(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    in_datatype: str,
    out_datatype: str,
    fast_mode: bool = False,
    gpu_target: str = "gfx90a",
) -> bool:
    """
    Comprehensive pooling tile configuration validation.

    When *fast_mode* is True only cheap sanity checks are performed (useful
    for the ``--list_kernels`` path). Full mode mirrors every
    ``static_assert`` in ``pool_shape.hpp``.

    Parameters
    ----------
    block_m, block_n : Block tile dimensions (M = output elems, N = window).
    warp_m, warp_n : Warps per block along each dimension.
    warp_tile_m, warp_tile_n : Tile processed per warp.
    thread_tile_m, thread_tile_n : Contiguous elements per thread.
    in_datatype : Input element type (e.g. ``"fp16"``).
    out_datatype : Output element type. NOTE(review): not consulted by any
        check in this function — kept for interface stability; confirm
        whether an output-side check is intended.
    fast_mode : Skip expensive checks when True.
    gpu_target : GPU architecture string used to select wave32 vs wave64.
    """
    all_params = (
        block_m, block_n, warp_m, warp_n,
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n,
    )

    # --- Positivity (always) ---
    ok, err = validate_positivity(*all_params)
    if not ok:
        logger.debug(f"Positivity check failed: {err}")
        return False

    # --- Thread-tile alignment (always) ---
    ok, err = validate_thread_tile_alignment(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n
    )
    if not ok:
        logger.debug(f"Thread tile alignment failed: {err}")
        return False

    if fast_mode:
        return True

    # Warp size depends on the GPU target (wave64 on CDNA, wave32 on RDNA).
    warp_size = get_warp_size_for_gpu(gpu_target)

    # --- Power-of-two ---
    ok, err = validate_power_of_two(*all_params)
    if not ok:
        logger.debug(f"Power-of-two check failed: {err}")
        return False

    # --- Warp-thread distribution ---
    ok, err = validate_warp_thread_distribution(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )
    if not ok:
        logger.debug(f"Warp thread distribution failed: {err}")
        return False

    # --- Block-tile coverage ---
    ok, err = validate_block_tile_coverage(*all_params, warp_size=warp_size)
    if not ok:
        logger.debug(f"Block tile coverage failed: {err}")
        return False

    # --- Block size ---
    ok, err = validate_block_size(warp_m, warp_n, warp_size)
    if not ok:
        logger.debug(f"Block size check failed: {err}")
        return False

    # --- Repeat factors ---
    # BUGFIX: previously called without warp_size, silently falling back to
    # the wave64 default. On wave32 (RDNA) targets that made the repeat
    # factors disagree with the scale factors used by the block-tile
    # coverage check above.
    ok, err = validate_repeat_factors(*all_params, warp_size=warp_size)
    if not ok:
        logger.debug(f"Repeat factor check failed: {err}")
        return False

    # --- Vector load alignment ---
    ok, err = validate_vector_load_alignment(block_m, thread_tile_m, in_datatype)
    if not ok:
        logger.debug(f"Vector load alignment failed: {err}")
        return False

    return True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Trait-combination validation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def is_trait_combination_valid(
    reduce_op: str,
    output_index: bool,
    propagate_nan: bool,
    pooling_dim: str,
) -> bool:
    """
    Validate a pooling trait combination.

    Parameters
    ----------
    reduce_op : ``"max"``, ``"min"``, or ``"avg"``.
    output_index : Whether to output indices of the selected elements.
    propagate_nan: Whether to propagate NaN values through the reduction.
    pooling_dim : ``"2d"`` or ``"3d"``.
    """
    # Determine the first rejection reason, if any; an empty string means
    # the combination is acceptable.
    reason = ""
    if reduce_op not in SUPPORTED_REDUCE_OPS:
        reason = f"Unsupported reduce_op: '{reduce_op}'"
    elif pooling_dim not in SUPPORTED_POOLING_DIMS:
        reason = f"Invalid pooling dimension: '{pooling_dim}'"
    elif output_index and reduce_op != "max":
        # output_index only makes sense for max pooling (CK constraint)
        reason = (
            f"output_index=True is only supported for 'max' pooling, "
            f"not '{reduce_op}'"
        )

    if reason:
        logger.debug(reason)
        return False
    return True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Datatype validation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def is_datatype_supported(datatype: str) -> bool:
    """Return True if *datatype* is a known pooling datatype.

    NOTE(review): membership is tested against ELEMENT_SIZE_MAP, which is a
    superset of DTYPE_STRING_MAP (it also contains int4/int8/int32 that have
    no C++ mapping here) — confirm that is intentional.
    """
    normalized = datatype.lower()
    return normalized in ELEMENT_SIZE_MAP
|