mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Add pooling to ckTileEngine part4 fix supported configurations
This commit is contained in:
@@ -186,15 +186,27 @@ class PoolKernelBuilder:
|
||||
if warp_tile_m <= 0 or warp_tile_n <= 0:
|
||||
return False
|
||||
|
||||
# Check block_m is divisible by warp_m
|
||||
if block_m % warp_m != 0:
|
||||
return False
|
||||
if block_n % warp_n != 0:
|
||||
return False
|
||||
|
||||
# Check thread tile fits in warp tile
|
||||
if warp_tile_m % thread_tile_m != 0:
|
||||
return False
|
||||
if warp_tile_n % thread_tile_n != 0:
|
||||
return False
|
||||
|
||||
# Check threads per warp constraint
|
||||
threads_per_warp = (warp_tile_m // thread_tile_m) * (warp_tile_n // thread_tile_n)
|
||||
if threads_per_warp > warp_size:
|
||||
# Critical constraint from pool_shape.hpp:
|
||||
# (Warp_M * Warp_N / ThreadTile_M / ThreadTile_N) % warp_size == 0
|
||||
# This means threads_per_warp must be a multiple of warp_size (typically equal to it)
|
||||
threads_per_warp = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
|
||||
if threads_per_warp % warp_size != 0:
|
||||
return False
|
||||
|
||||
# threads_per_warp should not be too large (usually exactly warp_size)
|
||||
if threads_per_warp > warp_size * 4:
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -268,9 +280,10 @@ constexpr const char* KERNEL_NAME = "{kernel_name}";
|
||||
constexpr const char* BLOCK_SHAPE_NAME = "{block_str}";
|
||||
constexpr const char* REDUCE_OP_NAME = "{self.reduce_op}";
|
||||
|
||||
// Flags
|
||||
// Flags and dimensions
|
||||
constexpr bool OUTPUT_INDEX = {"true" if output_index else "false"};
|
||||
constexpr bool PROPAGATE_NAN = {"true" if propagate_nan else "false"};
|
||||
constexpr int POOL_DIM = {pool_dim};
|
||||
|
||||
// Block configuration
|
||||
using BlockWarps = ck_tile::sequence<{block_config['warp_m']}, {block_config['warp_n']}>;
|
||||
|
||||
Reference in New Issue
Block a user