mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
[rocm-libraries] ROCm/rocm-libraries#6511 (commit 867bece)
[CK_TILE] Adding steps in Stream-K Tile Engine ## Motivation This PR adds step functionality to the Stream-K instance generator in Tile Engine in order to quickly generate instance configurations within a certain max/min range. To complement this, the Stream-K Tile Engine validation file has been updated for more rigorous validation of generated instances. ## Technical Details - Added _generate_values helper to support min/max/step range-based tile config generation, matching Universal GEMM - Added validate_gemm, validate_whole_wg_cover_configuration, validate_cshuffle_epilogue_distribution, and other supporting functions to the Stream-K validation utils, aligning with the validation already present in the Universal GEMM ## Test Plan Tested using the generation in CK Tile Engine ## Test Result All instances were generated and validated correctly. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
74f8c0a9c1
commit
7d7d293e5d
@@ -127,6 +127,25 @@ class GemmKernelBuilder:
|
||||
# New format - generate combinations from individual parameter values
|
||||
tile_config = self.config["tile_config"]
|
||||
|
||||
if tile_config.get("tile_m").get("values") is None:
|
||||
tile_config.get("tile_m")["values"] = self._generate_values(
|
||||
tile_config.get("tile_m").get("min"),
|
||||
tile_config.get("tile_m").get("max"),
|
||||
tile_config.get("tile_m").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_n").get("values") is None:
|
||||
tile_config.get("tile_n")["values"] = self._generate_values(
|
||||
tile_config.get("tile_n").get("min"),
|
||||
tile_config.get("tile_n").get("max"),
|
||||
tile_config.get("tile_n").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_k").get("values") is None:
|
||||
tile_config.get("tile_k")["values"] = self._generate_values(
|
||||
tile_config.get("tile_k").get("min"),
|
||||
tile_config.get("tile_k").get("max"),
|
||||
tile_config.get("tile_k").get("step"),
|
||||
)
|
||||
|
||||
# Get all possible values for each parameter
|
||||
tile_m_values = tile_config.get("tile_m", {}).get("values", [256])
|
||||
tile_n_values = tile_config.get("tile_n", {}).get("values", [256])
|
||||
@@ -180,6 +199,15 @@ class GemmKernelBuilder:
|
||||
# Fallback to default
|
||||
return []
|
||||
|
||||
def _generate_values(self, min_val, max_val, step):
|
||||
"""Generate a list of values from min to max with the given step"""
|
||||
values = []
|
||||
val = min_val
|
||||
while val <= max_val:
|
||||
values.append(val)
|
||||
val += step
|
||||
return values
|
||||
|
||||
def _validate_tile_config(
|
||||
self,
|
||||
tile_m,
|
||||
@@ -191,7 +219,7 @@ class GemmKernelBuilder:
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
pipeline="mem", # Default pipeline for validation
|
||||
pipeline="compv3", # Default pipeline for validation
|
||||
fast_mode=False, # Add fast mode option
|
||||
):
|
||||
"""Validate that tile configuration is reasonable"""
|
||||
@@ -239,6 +267,7 @@ class GemmKernelBuilder:
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
pipeline,
|
||||
self.layout
|
||||
)
|
||||
|
||||
def _generate_trait_combinations(self):
|
||||
|
||||
@@ -26,6 +26,16 @@ ELEMENT_SIZE_MAP = {
|
||||
"fp64": 8,
|
||||
}
|
||||
|
||||
def get_warp_size_for_gpu(gpu_target: str) -> int:
    """Return the number of threads per wavefront for ``gpu_target``.

    Targets whose name starts with ``gfx9`` (CDNA) use WAVE64 and yield 64.
    Every other target, including RDNA's gfx10xx, gfx11xx and gfx12xx,
    yields 32 (WAVE32).
    """
    is_cdna = gpu_target.startswith("gfx9")
    return 64 if is_cdna else 32
|
||||
|
||||
# Supported warp tile combinations for different GPU architectures and data types
|
||||
WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": {
|
||||
@@ -34,7 +44,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
@@ -42,7 +51,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
@@ -54,7 +62,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
@@ -62,7 +69,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
@@ -75,7 +81,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
@@ -83,7 +88,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
@@ -333,6 +337,7 @@ def is_tile_config_valid(
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
pipeline: str,
|
||||
layout: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -390,6 +395,17 @@ def is_tile_config_valid(
|
||||
logging.debug(f"LDS validation failed: {lds_error}")
|
||||
return False
|
||||
|
||||
gemm_valid, gemm_error = validate_gemm(
|
||||
tile_m, tile_n, tile_k,
|
||||
warp_m, warp_n, warp_k,
|
||||
warp_tile_m, warp_tile_n, warp_tile_k,
|
||||
a_datatype, b_datatype, c_datatype,
|
||||
pipeline, layout,
|
||||
)
|
||||
if not gemm_valid:
|
||||
logging.debug(f"GEMM validation failed: {gemm_error}")
|
||||
return False
|
||||
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
|
||||
warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
|
||||
@@ -399,3 +415,302 @@ def is_tile_config_valid(
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def validate_whole_wg_cover_configuration(
    tile_m,
    tile_n,
    tile_k,
    warp_m,
    warp_n,
    warp_k,
    layout,
    a_datatype,
    b_datatype,
    gpu_target: str = "gfx90a",
) -> Tuple[bool, str]:
    """Validate the whole-workgroup cover configuration for matrices A and B.

    The block size is ``warp_m * warp_n * warp_k`` wavefronts of the width
    reported by ``get_warp_size_for_gpu(gpu_target)``. For each operand the
    global vector load size is obtained from ``get_global_vector_load_size``
    and the resulting pattern is checked with ``wg_cover_core_validation``:

    * When K is the X dimension of the load (A with layout ``"r"``, B with
      layout ``"c"``) a single ``(XPerTile=tile_k, YPerTile=tile_mn)`` check
      is run.
    * Otherwise (A with ``"c"``, B with ``"r"``) the "distribution" pattern
      ``(tile_k, tile_mn)`` is checked first, followed by
      ``(tile_mn, tile_k)``.

    Args:
        tile_m, tile_n, tile_k: Block tile extents.
        warp_m, warp_n, warp_k: Number of warps along each dimension.
        layout: String whose first character is A's layout and whose second
            character is B's layout; each must be ``"r"`` or ``"c"``.
        a_datatype, b_datatype: Data type names forwarded to
            ``get_global_vector_load_size``.
        gpu_target: GPU target name used to pick the wavefront width.

    Returns:
        ``(True, "")`` when every check passes, otherwise ``(False, error)``
        with the message of the first failing check. A layout character other
        than ``"r"``/``"c"`` is reported as a failure rather than being
        validated with a zero vector load size.
    """
    warp_size = get_warp_size_for_gpu(gpu_target)
    block_size = warp_m * warp_n * warp_k * warp_size

    # (label, layout character, data type, M-or-N extent, layout character
    # for which K is the X dimension of the load)
    operands = (
        ("A", layout[0], a_datatype, tile_m, "r"),
        ("B", layout[1], b_datatype, tile_n, "c"),
    )

    for label, order, datatype, tile_mn, k_major_order in operands:
        if order not in ("r", "c"):
            return False, f"Unsupported layout '{order}' for Matrix {label}"

        if order == k_major_order:
            vector_load_size = get_global_vector_load_size(
                block_size, tile_k, datatype, tile_mn, tile_k
            )
            checks = ((f"Matrix {label}", tile_k, tile_mn),)
        else:
            vector_load_size = get_global_vector_load_size(
                block_size, tile_k, datatype, tile_mn, tile_mn
            )
            # The distribution pattern is checked before the swapped pattern.
            checks = (
                (f"Matrix {label} distribution", tile_k, tile_mn),
                (f"Matrix {label}", tile_mn, tile_k),
            )

        for subject, x_per_tile, y_per_tile in checks:
            cover_valid, cover_error = wg_cover_core_validation(
                x_per_tile, y_per_tile, block_size, vector_load_size, warp_size
            )
            if not cover_valid:
                logging.debug(
                    f"whole workgroup cover failed for {subject}: {cover_error}"
                )
                return False, cover_error

    return True, ""
|
||||
|
||||
def wg_cover_core_validation(
    XPerTile: int,
    YPerTile: int,
    BlockSize: int,
    vector_load_size: int,
    warp_size: int,
) -> Tuple[bool, str]:
    """Check that a 2D tile load pattern covers a whole wavefront.

    Integer arithmetic is used throughout:

        num_warps  = BlockSize // warp_size
        LargestVec = (XPerTile * YPerTile) // (num_warps * warp_size)
        X1         = min(vector_load_size, LargestVec)
        X0         = XPerTile // X1
        Y1         = warp_size // X0

    The configuration is valid when ``X0 * Y1 == warp_size``.

    Args:
        XPerTile: Tile extent along the X dimension.
        YPerTile: Tile extent along the Y dimension.
        BlockSize: Number of threads in the workgroup.
        vector_load_size: Elements loaded per vector access.
        warp_size: Threads per wavefront.

    Returns:
        ``(True, "")`` when the pattern is valid, otherwise ``(False, reason)``.
        Non-positive sizes and tiles with fewer elements than threads are
        reported as invalid instead of raising ``ZeroDivisionError``.
    """
    if vector_load_size <= 0 or warp_size <= 0:
        return False, "vector_load_size and warp_size must be positive"

    if XPerTile % vector_load_size != 0:
        return False, "XPerTile is not divisible by vector_load_size"

    num_warps = BlockSize // warp_size
    num_threads = num_warps * warp_size
    if num_threads <= 0:
        return False, "BlockSize must hold at least one full warp"

    # Floor division keeps every quantity an integer; true division would
    # turn LargestVec (and therefore X1 and X0) into fractions.
    LargestVec = (XPerTile * YPerTile) // num_threads
    if LargestVec <= 0:
        return False, "tile has fewer elements than threads in the workgroup"

    X1 = LargestVec if vector_load_size > LargestVec else vector_load_size
    X0 = XPerTile // X1
    Y1 = warp_size // X0 if X0 > 0 else 0

    if X0 * Y1 != warp_size:
        return False, "X0 * Y1 != warp_size"

    return True, ""
|
||||
|
||||
|
||||
def validate_cshuffle_epilogue_distribution(
    tile_m: int,
    tile_n: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_size: int,
    c_datatype: str,
) -> Tuple[bool, str]:
    """
    Validate that the CShuffleEpilogue tile distribution pattern is valid.

    This mirrors the static_assert in static_encoding_pattern.hpp:
    static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");

    The quantities used by this check are:
    - BlockSize = warp_m * warp_n * warp_k * warp_size
    - XPerTile = tile_n
    - YPerTile = tile_m // warp_m
    - VecSize = 16 // element size of c_datatype (element size defaults to 2
      when c_datatype is not in ELEMENT_SIZE_MAP)
    - LargestVec = (XPerTile * YPerTile) // BlockSize, clamped to at least 1
    - X1 = min(VecSize, LargestVec), clamped to at least 1
    - X0 = min(warp_size, XPerTile // X1)
    - Y1 = warp_size // X0 (0 when X0 is not positive)

    warp_tile_m and warp_tile_n are accepted for signature compatibility but
    are not used by the check.

    Returns:
        (True, "") when X0 * Y1 == warp_size, otherwise (False, reason).
        A zero BlockSize (any warp count or warp_size equal to 0) is reported
        as invalid before any division is attempted.
    """
    NumWarps = warp_m * warp_n * warp_k
    BlockSize = NumWarps * warp_size

    # Must run before the divisions below: a zero warp_m or warp_size would
    # otherwise raise ZeroDivisionError instead of reporting a bad config.
    if BlockSize == 0:
        return False, "Invalid BlockSize or warp_size"

    elem_size = ELEMENT_SIZE_MAP.get(c_datatype, 2)
    VecSize = 16 // elem_size

    XPerTile = tile_n
    YPerTile = tile_m // warp_m

    if XPerTile <= 0 or YPerTile <= 0:
        return False, f"Invalid tile dimensions: XPerTile={XPerTile}, YPerTile={YPerTile}"

    # BlockSize equals num_warps * warp_size, so it is used directly here.
    LargestVec = (XPerTile * YPerTile) // BlockSize
    if LargestVec <= 0:
        LargestVec = 1

    X1 = min(VecSize, LargestVec)
    if X1 <= 0:
        X1 = 1

    X0 = min(warp_size, XPerTile // X1)
    Y1 = warp_size // X0 if X0 > 0 else 0

    if X0 * Y1 != warp_size:
        return (
            False,
            f"CShuffleEpilogue distribution invalid: X0({X0}) * Y1({Y1}) = {X0 * Y1} != warp_size({warp_size}). "
            f"XPerTile={XPerTile}, YPerTile={YPerTile}, VecSize={VecSize}, BlockSize={BlockSize}"
        )

    return True, ""
|
||||
|
||||
|
||||
def get_global_vector_load_size(
    BlockSize: int,
    KPerBlock: int,
    DataType: str,
    MNPerBlock: int,
    XPerTile: int,
) -> int:
    """Pick the widest vector load that fits the tile and per-thread workload.

    Candidate access widths of 16, 8, 4 and 2 bytes are tried from widest to
    narrowest. A candidate is chosen when its element count
    (``width / element_size(DataType)``) divides both ``XPerTile`` and the
    per-thread element count ``MNPerBlock * KPerBlock / BlockSize``. The 4-
    and 2-byte candidates are considered only when the element size is at
    least that many bytes. A 32-byte candidate exists for packed types, but
    ``PackedSize`` is fixed at 1 here, so it is never selected.

    Returns:
        The element count of the first matching candidate, or ``PackedSize``
        (1) when none matches.
    """
    per_thread = MNPerBlock * KPerBlock / BlockSize
    packed_size = 1
    elem_bytes = element_size(DataType)

    # (access width in bytes, whether the candidate may be considered)
    candidates = (
        (32, packed_size == 2),
        (16, True),
        (8, True),
        (4, elem_bytes >= packed_size * 4),
        (2, elem_bytes >= packed_size * 2),
    )

    for width_bytes, allowed in candidates:
        if not allowed:
            continue
        vec_elems = packed_size * width_bytes / elem_bytes
        if XPerTile % vec_elems == 0 and per_thread % vec_elems == 0:
            return int(vec_elems)

    return packed_size
|
||||
|
||||
def validate_gemm(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    pipeline: str,
    layout: str,
    trait_name: str = None,
) -> Tuple[bool, str]:
    """Run the GEMM-level checks for one tile configuration.

    Two checks are run in order, stopping at the first failure:

    1. ``validate_whole_wg_cover_configuration`` for the A/B loads.
    2. ``validate_cshuffle_epilogue_distribution`` for the output tile.

    The GPU target is taken from ``get_gpu_name_by_id(0)`` and the wavefront
    width from ``get_warp_size_for_gpu``. ``warp_tile_k``, ``pipeline`` and
    ``trait_name`` are accepted but not used by either check.

    Returns:
        ``(True, "")`` when both checks pass, otherwise ``(False, error)``
        with the failing check's message.
    """
    # NOTE(review): presumably device 0 is the GPU the kernels are built
    # for; confirm against get_gpu_name_by_id.
    gpu_target = get_gpu_name_by_id(0)
    warp_size = get_warp_size_for_gpu(gpu_target)

    # Validate whole workgroup cover configuration
    wg_cover_ok, wg_cover_msg = validate_whole_wg_cover_configuration(
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        layout,
        a_datatype,
        b_datatype,
        gpu_target,
    )
    if not wg_cover_ok:
        logging.debug(
            f"Whole workgroup cover configuration validation failed: {wg_cover_msg}"
        )
        return False, wg_cover_msg

    # Validate CShuffleEpilogue distribution pattern (for cshuffle epilogue)
    # This validation ensures the tile distribution pattern is valid for the output tile
    epilogue_ok, epilogue_msg = validate_cshuffle_epilogue_distribution(
        tile_m,
        tile_n,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_size,
        c_datatype,
    )
    if not epilogue_ok:
        logging.debug(f"CShuffleEpilogue validation failed: {epilogue_msg}")
        return False, epilogue_msg

    return True, ""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user