[rocm-libraries] ROCm/rocm-libraries#6511 (commit 867bece)

[CK_TILE] Adding steps in Stream-K Tile Engine

## Motivation
This PR adds step functionality to the Stream-K instance generator in
Tile Engine in order to quickly generate instance configurations within
a certain max/min range. To complement this, the Stream-K Tile Engine
validation file has been updated for more rigorous validation of
generated instances.

## Technical Details
- Added _generate_values helper to support min/max/step range-based tile
config generation, matching Universal GEMM
- Added validate_gemm, validate_whole_wg_cover_configuration,
validate_cshuffle_epilogue_distribution, and other supporting functions
to the Stream-K validation utils, aligning with the validation already
present in the Universal GEMM

## Test Plan
Tested using the generation in CK Tile Engine

## Test Result
All instances were generated and validated correctly.

## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
arai713
2026-05-26 16:43:05 +00:00
committed by assistant-librarian[bot]
parent 74f8c0a9c1
commit 7d7d293e5d
2 changed files with 351 additions and 7 deletions

View File

@@ -127,6 +127,25 @@ class GemmKernelBuilder:
# New format - generate combinations from individual parameter values
tile_config = self.config["tile_config"]
if tile_config.get("tile_m").get("values") is None:
tile_config.get("tile_m")["values"] = self._generate_values(
tile_config.get("tile_m").get("min"),
tile_config.get("tile_m").get("max"),
tile_config.get("tile_m").get("step"),
)
if tile_config.get("tile_n").get("values") is None:
tile_config.get("tile_n")["values"] = self._generate_values(
tile_config.get("tile_n").get("min"),
tile_config.get("tile_n").get("max"),
tile_config.get("tile_n").get("step"),
)
if tile_config.get("tile_k").get("values") is None:
tile_config.get("tile_k")["values"] = self._generate_values(
tile_config.get("tile_k").get("min"),
tile_config.get("tile_k").get("max"),
tile_config.get("tile_k").get("step"),
)
# Get all possible values for each parameter
tile_m_values = tile_config.get("tile_m", {}).get("values", [256])
tile_n_values = tile_config.get("tile_n", {}).get("values", [256])
@@ -180,6 +199,15 @@ class GemmKernelBuilder:
# Fallback to default
return []
def _generate_values(self, min_val, max_val, step):
"""Generate a list of values from min to max with the given step"""
values = []
val = min_val
while val <= max_val:
values.append(val)
val += step
return values
def _validate_tile_config(
self,
tile_m,
@@ -191,7 +219,7 @@ class GemmKernelBuilder:
warp_tile_m,
warp_tile_n,
warp_tile_k,
pipeline="mem", # Default pipeline for validation
pipeline="compv3", # Default pipeline for validation
fast_mode=False, # Add fast mode option
):
"""Validate that tile configuration is reasonable"""
@@ -239,6 +267,7 @@ class GemmKernelBuilder:
b_datatype,
c_datatype,
pipeline,
self.layout
)
def _generate_trait_combinations(self):

View File

@@ -26,6 +26,16 @@ ELEMENT_SIZE_MAP = {
"fp64": 8,
}
def get_warp_size_for_gpu(gpu_target: str) -> int:
"""Get the warp size for a given GPU target.
CDNA architectures (gfx9xx) use WAVE64 (64 threads per wavefront).
RDNA architectures (gfx10xx, gfx11xx, gfx12xx) use WAVE32 (32 threads per wavefront).
"""
if gpu_target.startswith("gfx9"):
return 64 # CDNA - WAVE64
return 32 # RDNA and others - WAVE32
# Supported warp tile combinations for different GPU architectures and data types
WARP_TILE_SUPPORTED_COMBINATIONS = {
"gfx90a": {
@@ -34,7 +44,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
@@ -42,7 +51,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
@@ -54,7 +62,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
@@ -62,7 +69,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
@@ -75,7 +81,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
@@ -83,7 +88,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [
@@ -333,6 +337,7 @@ def is_tile_config_valid(
b_datatype: str,
c_datatype: str,
pipeline: str,
layout: str,
trait_name: str = None,
) -> bool:
"""
@@ -390,6 +395,17 @@ def is_tile_config_valid(
logging.debug(f"LDS validation failed: {lds_error}")
return False
gemm_valid, gemm_error = validate_gemm(
tile_m, tile_n, tile_k,
warp_m, warp_n, warp_k,
warp_tile_m, warp_tile_n, warp_tile_k,
a_datatype, b_datatype, c_datatype,
pipeline, layout,
)
if not gemm_valid:
logging.debug(f"GEMM validation failed: {gemm_error}")
return False
# Validate warp tile combination
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
@@ -399,3 +415,302 @@ def is_tile_config_valid(
return False
return True
def validate_whole_wg_cover_configuration(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
layout,
a_datatype,
b_datatype,
gpu_target: str = "gfx90a",
) -> Tuple[bool, str]:
# Validate whole workgroup cover configuration
warp_size = get_warp_size_for_gpu(gpu_target)
NumWarps = warp_m * warp_n * warp_k
BlockSize = NumWarps * warp_size
XPerTile = 0
YPerTile = 0
vector_load_size = 0
# A matrix validation
if layout[0] == "r":
vector_load_size = get_global_vector_load_size(
BlockSize, tile_k, a_datatype, tile_m, tile_k
)
XPerTile = tile_k
YPerTile = tile_m
elif layout[0] == "c":
vector_load_size = get_global_vector_load_size(
BlockSize, tile_k, a_datatype, tile_m, tile_m
)
# Validate distribution
XPerTile = tile_k
YPerTile = tile_m
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
)
if not wg_cover_core_valid:
logging.debug(
f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
)
return False, wg_cover_core_error
XPerTile = tile_m
YPerTile = tile_k
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
)
if not wg_cover_core_valid:
logging.debug(
f"whole workgroup cover failed for Matrix A: {wg_cover_core_error}"
)
return False, wg_cover_core_error
# B matrix validation
if layout[1] == "r":
vector_load_size = get_global_vector_load_size(
BlockSize, tile_k, b_datatype, tile_n, tile_n
)
# Validate distribution
XPerTile = tile_k
YPerTile = tile_n
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
)
if not wg_cover_core_valid:
logging.debug(
f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}"
)
return False, wg_cover_core_error
XPerTile = tile_n
YPerTile = tile_k
elif layout[1] == "c":
XPerTile = tile_k
YPerTile = tile_n
vector_load_size = get_global_vector_load_size(
BlockSize, tile_k, b_datatype, tile_n, tile_k
)
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
)
if not wg_cover_core_valid:
logging.debug(
f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}"
)
return False, wg_cover_core_error
return True, ""
def wg_cover_core_validation(
XPerTile: int,
YPerTile: int,
BlockSize: int,
vector_load_size: int,
warp_size: int,
) -> Tuple[bool, str]:
if XPerTile % vector_load_size != 0:
return False, "XPerTile is not divisible by vector_load_size"
num_warps = BlockSize / warp_size
LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
X1 = LargestVec if vector_load_size > LargestVec else vector_load_size
X0 = XPerTile / X1
Y1 = warp_size // X0
if X0 * Y1 != warp_size:
return False, "X0 * Y1 != warp_size"
return True, ""
def validate_cshuffle_epilogue_distribution(
tile_m: int,
tile_n: int,
warp_m: int,
warp_n: int,
warp_k: int,
warp_tile_m: int,
warp_tile_n: int,
warp_size: int,
c_datatype: str,
) -> Tuple[bool, str]:
"""
Validate that the CShuffleEpilogue tile distribution pattern is valid.
This mirrors the static_assert in static_encoding_pattern.hpp:
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
The CShuffleEpilogue creates a tile_distribution_encoding_pattern_2d<BlockSize, YPerTile, XPerTile, VecSize, thread_raked>
where:
- BlockSize = warp_m * warp_n * warp_k * warp_size
- YPerTile = MPerIterationShuffle (derived from tile_m / (warp_m * warp_tile_m / some_factor))
- XPerTile = NPerIterationShuffle (derived from tile_n)
- VecSize = vector size based on element size (typically 8 for fp16)
The key constraint is that X0 must evenly divide warp_size, where:
- X0 = min(warp_size, XPerTile / X1)
- X1 = min(VecSize, LargestVec)
- LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
"""
NumWarps = warp_m * warp_n * warp_k
BlockSize = NumWarps * warp_size
elem_size = ELEMENT_SIZE_MAP.get(c_datatype, 2)
VecSize = 16 // elem_size
XPerTile = tile_n
YPerTile = tile_m // warp_m
if XPerTile <= 0 or YPerTile <= 0:
return False, f"Invalid tile dimensions: XPerTile={XPerTile}, YPerTile={YPerTile}"
num_warps = BlockSize // warp_size
if num_warps * warp_size == 0:
return False, "Invalid BlockSize or warp_size"
LargestVec = (XPerTile * YPerTile) // (num_warps * warp_size)
if LargestVec <= 0:
LargestVec = 1
X1 = min(VecSize, LargestVec) if LargestVec > 0 else VecSize
if X1 <= 0:
X1 = 1
X0 = min(warp_size, XPerTile // X1) if X1 > 0 else warp_size
Y1 = warp_size // X0 if X0 > 0 else 0
if X0 * Y1 != warp_size:
return (
False,
f"CShuffleEpilogue distribution invalid: X0({X0}) * Y1({Y1}) = {X0 * Y1} != warp_size({warp_size}). "
f"XPerTile={XPerTile}, YPerTile={YPerTile}, VecSize={VecSize}, BlockSize={BlockSize}"
)
return True, ""
def get_global_vector_load_size(
BlockSize: int,
KPerBlock: int,
DataType: str,
MNPerBlock: int,
XPerTile: int,
) -> int:
elements_per_thread = MNPerBlock * KPerBlock / BlockSize
PackedSize = 1
if (
PackedSize == 2
and XPerTile % (PackedSize * 32 / element_size(DataType)) == 0
and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0
):
return PackedSize * 32 / element_size(DataType)
elif (
XPerTile % (PackedSize * 16 / element_size(DataType)) == 0
and elements_per_thread % (PackedSize * 16 / element_size(DataType)) == 0
):
return int(PackedSize * 16 / element_size(DataType))
elif (
XPerTile % (PackedSize * 8 / element_size(DataType)) == 0
and elements_per_thread % (PackedSize * 8 / element_size(DataType)) == 0
):
return int(PackedSize * 8 / element_size(DataType))
elif (
element_size(DataType) >= PackedSize * 4
and XPerTile % (PackedSize * 4 / element_size(DataType)) == 0
and elements_per_thread % (PackedSize * 4 / element_size(DataType)) == 0
):
return int(PackedSize * 4 / element_size(DataType))
elif (
element_size(DataType) >= PackedSize * 2
and XPerTile % (PackedSize * 2 / element_size(DataType)) == 0
and elements_per_thread % (PackedSize * 2 / element_size(DataType)) == 0
):
return int(PackedSize * 2 / element_size(DataType))
else:
return PackedSize
def validate_gemm(
tile_m: int,
tile_n: int,
tile_k: int,
warp_m: int,
warp_n: int,
warp_k: int,
warp_tile_m: int,
warp_tile_n: int,
warp_tile_k: int,
a_datatype: str,
b_datatype: str,
c_datatype: str,
pipeline: str,
layout: str,
trait_name: str = None,
) -> bool:
# GEMM Validation
gpu_target = get_gpu_name_by_id(0)
warp_size = get_warp_size_for_gpu(gpu_target)
# Validate whole workgroup cover configuration
whole_workgroup_cover_valid, whole_workgroup_cover_error = (
validate_whole_wg_cover_configuration(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
layout,
a_datatype,
b_datatype,
gpu_target,
)
)
if not whole_workgroup_cover_valid:
logging.debug(
f"Whole workgroup cover configuration validation failed: {whole_workgroup_cover_error}"
)
return False, whole_workgroup_cover_error
# Validate CShuffleEpilogue distribution pattern (for cshuffle epilogue)
# This validation ensures the tile distribution pattern is valid for the output tile
cshuffle_valid, cshuffle_error = validate_cshuffle_epilogue_distribution(
tile_m,
tile_n,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_size,
c_datatype,
)
if not cshuffle_valid:
logging.debug(f"CShuffleEpilogue validation failed: {cshuffle_error}")
return False, cshuffle_error
return True, ""