[CK_TILE] Add pooling to ckTileEngine part4 fix suppported configurations

This commit is contained in:
Aleksander Dudek
2025-12-11 12:11:42 +00:00
parent 07c078d5ef
commit f6d2243288
5 changed files with 405 additions and 112 deletions

View File

@@ -186,15 +186,27 @@ class PoolKernelBuilder:
if warp_tile_m <= 0 or warp_tile_n <= 0:
return False
# Check block_m is divisible by warp_m
if block_m % warp_m != 0:
return False
if block_n % warp_n != 0:
return False
# Check thread tile fits in warp tile
if warp_tile_m % thread_tile_m != 0:
return False
if warp_tile_n % thread_tile_n != 0:
return False
# Check threads per warp constraint
threads_per_warp = (warp_tile_m // thread_tile_m) * (warp_tile_n // thread_tile_n)
if threads_per_warp > warp_size:
# Critical constraint from pool_shape.hpp:
# (Warp_M * Warp_N / ThreadTile_M / ThreadTile_N) % warp_size == 0
# This means threads_per_warp must be a multiple of warp_size (typically equal to it)
threads_per_warp = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
if threads_per_warp % warp_size != 0:
return False
# threads_per_warp should not be too large (usually exactly warp_size)
if threads_per_warp > warp_size * 4:
return False
return True
@@ -268,9 +280,10 @@ constexpr const char* KERNEL_NAME = "{kernel_name}";
constexpr const char* BLOCK_SHAPE_NAME = "{block_str}";
constexpr const char* REDUCE_OP_NAME = "{self.reduce_op}";
// Flags
// Flags and dimensions
constexpr bool OUTPUT_INDEX = {"true" if output_index else "false"};
constexpr bool PROPAGATE_NAN = {"true" if propagate_nan else "false"};
constexpr int POOL_DIM = {pool_dim};
// Block configuration
using BlockWarps = ck_tile::sequence<{block_config['warp_m']}, {block_config['warp_n']}>;