mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Ck tile engine preshuffle (#2919)
* Partial Progress : Preshuffle working code for datatype * Partial Progress : Preshuffle Cleanup * Working code for default config with min max step * Partial Progress : PermuteN implemented in validation * Partial Progress : PermuteN changes in Preshuffle * CK Tile Engine Preshuffle Complete * CK TILE ENGINE : Preshuffle Layout validation * CK Tile Engine Preshuffle Validation * Preshuffle Validation check * CK Tile Engine Preshuffle : Fixing Validation Cases * Addressing PR review Comments * Changes in config * Addressing Review Comments * Adding additional architecture in Jenkins * Partial Progress : Selective Datatype and layouts * Limited datatypes and layouts * Addressing CI errors * Datatype updates * Datatype updates * Datatype changes to Preshuffle * Addressing Review Comments * Addressing Review Comments * Datatype changes * Changes to Cmake * Update on Jenkins * Formatting with precommit * Ruff Formatting
This commit is contained in:
committed by
GitHub
parent
6d709dac41
commit
8b185e872e
@@ -32,7 +32,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
@@ -40,7 +39,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
@@ -52,7 +50,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
@@ -60,7 +57,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
@@ -73,7 +69,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
@@ -81,7 +76,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
@@ -122,6 +116,12 @@ def element_size(data_type: str) -> float:
|
||||
|
||||
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
|
||||
"""Check if a trait combination is valid."""
|
||||
if pipeline not in ["preshufflev2"]:
|
||||
raise ValueError("Accepted pipeline values are: ['preshufflev2']")
|
||||
if epilogue not in ["default", "cshuffle"]:
|
||||
return ValueError("Accepted epilogue values are: ['default', 'cshuffle']")
|
||||
if scheduler not in ["default"]:
|
||||
return ValueError("Accepted scheduler values are: ['default']")
|
||||
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ def validate_lds_capacity(
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
|
||||
max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16
|
||||
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
error_msg = (
|
||||
@@ -266,6 +266,35 @@ def is_tile_config_valid(
|
||||
if warp_k * warp_tile_k > tile_k:
|
||||
return False
|
||||
|
||||
# Validate vector load alignment
|
||||
m_iter_per_warp = tile_m / (warp_m * warp_tile_m)
|
||||
vector_valid, vector_error = validate_vector_load_alignment(
|
||||
warp_tile_m,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
m_iter_per_warp,
|
||||
wave_size=64,
|
||||
vector_load_size=16,
|
||||
)
|
||||
if not vector_valid:
|
||||
logging.debug(f"Vector load alignment failed: {vector_error}")
|
||||
return False
|
||||
|
||||
# Validate M0, M1, M2 configuration for matrix A row-major layout
|
||||
m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration(
|
||||
tile_m,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
a_datatype,
|
||||
vector_load_size=16,
|
||||
warp_size=64,
|
||||
)
|
||||
if not m0_m1_m2_valid:
|
||||
logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}")
|
||||
return False
|
||||
|
||||
# Validate warp configuration
|
||||
if not validate_warp_configuration(warp_m, warp_n, warp_k):
|
||||
logging.debug(
|
||||
@@ -318,12 +347,117 @@ def is_tile_config_valid(
|
||||
return True
|
||||
|
||||
|
||||
def validate_vector_load_alignment(
|
||||
wg_m: int,
|
||||
wg_k: int,
|
||||
a_datatype: str,
|
||||
m_iter_per_warp: int,
|
||||
wave_size: int,
|
||||
vector_load_size: int,
|
||||
) -> Tuple[bool, str]:
|
||||
try:
|
||||
# Calculate the memory access pattern size
|
||||
a_element_size = element_size(a_datatype)
|
||||
access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size
|
||||
|
||||
# Check if it's aligned to vector load size
|
||||
if access_size % vector_load_size != 0:
|
||||
error_msg = (
|
||||
f"Vector load alignment violation: "
|
||||
f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) "
|
||||
f"% {vector_load_size} = {access_size % vector_load_size} != 0. "
|
||||
f"Access size: {access_size} bytes"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Error in vector load validation: {str(e)}"
|
||||
|
||||
|
||||
def validate_m0_m1_m2_configuration(
|
||||
tile_m: int,
|
||||
tile_k: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
a_datatype: str,
|
||||
vector_load_size: int = 16,
|
||||
warp_size: int = 64,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Validate M0, M1, M2 configuration for matrix A row-major layout.
|
||||
This ensures proper memory access pattern alignment.
|
||||
"""
|
||||
try:
|
||||
# Validation for A as row-major
|
||||
MPerBlock = tile_m
|
||||
|
||||
# Calculate K1 using element size
|
||||
K1 = vector_load_size / element_size(a_datatype)
|
||||
|
||||
# Check if K1 is valid (must be integer)
|
||||
if K1 != int(K1):
|
||||
return (
|
||||
False,
|
||||
f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})",
|
||||
)
|
||||
K1 = int(K1)
|
||||
|
||||
# Calculate K0
|
||||
if tile_k % K1 != 0:
|
||||
return False, f"tile_k({tile_k}) must be divisible by K1({K1})"
|
||||
K0 = tile_k // K1
|
||||
|
||||
# Calculate M2
|
||||
if warp_size % K0 != 0:
|
||||
return False, f"warp_size({warp_size}) must be divisible by K0({K0})"
|
||||
M2 = warp_size // K0
|
||||
|
||||
# Calculate number of warps and block size
|
||||
NumWarps = warp_m * warp_n * warp_k
|
||||
BlockSize = NumWarps * warp_size
|
||||
|
||||
# Calculate M0 (assuming get_warp_size() returns warp_size)
|
||||
M0 = BlockSize // warp_size # This should equal NumWarps
|
||||
|
||||
# Calculate M1
|
||||
if (M2 * M0) == 0:
|
||||
return False, f"M2({M2}) * M0({M0}) cannot be zero"
|
||||
|
||||
if MPerBlock % (M2 * M0) != 0:
|
||||
return (
|
||||
False,
|
||||
f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}",
|
||||
)
|
||||
M1 = MPerBlock // (M2 * M0)
|
||||
|
||||
# Validate the assertion: M0 * M1 * M2 == MPerBlock
|
||||
calculated_m_per_block = M0 * M1 * M2
|
||||
if calculated_m_per_block != MPerBlock:
|
||||
error_msg = (
|
||||
f"Incorrect M0, M1, M2 configuration! "
|
||||
f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). "
|
||||
f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
except ZeroDivisionError as e:
|
||||
return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}"
|
||||
except Exception as e:
|
||||
return False, f"Error in M0/M1/M2 validation: {str(e)}"
|
||||
|
||||
|
||||
# [TODO] Handle this while moving code to commons Add more datatype to this function if needed
|
||||
def get_dtype_string(datatype: str) -> str:
|
||||
"""Get C++ type string for datatype"""
|
||||
dtype_map = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float",
|
||||
"fp64": "double",
|
||||
|
||||
Reference in New Issue
Block a user