mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Ck tile engine gemm (#2982)
* Partial Progress : CK Tile Engine GEMM * Partial Progress : CK Tile Engine GEMM * Partial Progress : Working GEMM Code * Partial Progress : Working GEMM Code * Changinf jenkins to remove preshuffle * Partial Progress : CK TILE ENGINE GEMM Debugging * Partial Progress : Removing changes that are not GEMM * Partial Progress : Validation of full block size in GEMM * Changes in Jenkins to run only fp16 and bf16 * Addressing Review Comments * Partial Progress : Addressing CI issues * Partial Progress - Runing GEMM for fp16,bf16 and rcr * Clang * Adding fp8 and bf8 * Adding fp8 and bf8 * Adding additional architrcture * Limited datatypes and layouts * Adding k_block_per_cu in test config * Changes to faling CI errors * Changes to faling CI errors * Validation for GEMM * Adding Layout support * Adding Validations * Adding layout in jenkins * Update on Jenkins * Distribution validation for GEMM * Resolving merge conflicts * Solving merge conflicts
This commit is contained in:
committed by
GitHub
parent
b11f53a484
commit
7fc0a38e90
@@ -65,7 +65,7 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
EXCLUDE_FROM_ALL
|
||||
${GEMM_SOURCE_DIR}/benchmark_gemm_single.cpp
|
||||
${GEMM_SOURCE_DIR}/gemm_benchmark_single.cpp
|
||||
${instance_header}
|
||||
)
|
||||
|
||||
@@ -103,9 +103,9 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
|
||||
list(GET trait_parts 1 epilogue)
|
||||
list(GET trait_parts 2 scheduler)
|
||||
|
||||
add_dependencies(benchmark_gemm_${pipeline} ${target_name})
|
||||
add_dependencies(benchmark_gemm_${epilogue} ${target_name})
|
||||
add_dependencies(benchmark_gemm_${scheduler} ${target_name})
|
||||
add_dependencies(benchmark_gemm_${pipeline}_pipeline ${target_name})
|
||||
add_dependencies(benchmark_gemm_${epilogue}_epilogue ${target_name})
|
||||
add_dependencies(benchmark_gemm_${scheduler}_scheduler ${target_name})
|
||||
endfunction()
|
||||
|
||||
# Function to build individual GEMM targets
|
||||
@@ -286,15 +286,15 @@ else()
|
||||
set(GEMM_SCHEDULERS "intrawave;interwave")
|
||||
|
||||
foreach(pipeline IN LISTS GEMM_PIPELINES)
|
||||
add_custom_target(benchmark_gemm_${pipeline})
|
||||
add_custom_target(benchmark_gemm_${pipeline}_pipeline)
|
||||
endforeach()
|
||||
|
||||
foreach(epilogue IN LISTS GEMM_EPILOGUES)
|
||||
add_custom_target(benchmark_gemm_${epilogue})
|
||||
add_custom_target(benchmark_gemm_${epilogue}_epilogue)
|
||||
endforeach()
|
||||
|
||||
foreach(scheduler IN LISTS GEMM_SCHEDULERS)
|
||||
add_custom_target(benchmark_gemm_${scheduler})
|
||||
add_custom_target(benchmark_gemm_${scheduler}_scheduler)
|
||||
endforeach()
|
||||
|
||||
# Build individual targets for each datatype/layout combination
|
||||
|
||||
@@ -187,7 +187,7 @@ python gemm_instance_builder.py \
|
||||
--datatype fp16 \
|
||||
--layout rcr \
|
||||
--config_json configs/user_provided_config.json \
|
||||
--gen_individual
|
||||
--gen_all_individual
|
||||
```
|
||||
|
||||
#### gemm_instance_builder_parallel.py
|
||||
|
||||
@@ -23,6 +23,31 @@ ELEMENT_SIZE_MAP = {
|
||||
"fp64": 8,
|
||||
}
|
||||
|
||||
WARP_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx942": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx950": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx1201": [
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[4, 2, 1],
|
||||
],
|
||||
}
|
||||
|
||||
# [TODO] Handle this while moving code to commons
|
||||
# Supported warp tile combinations for different GPU architectures and data types
|
||||
WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": {
|
||||
@@ -290,6 +315,7 @@ def is_tile_config_valid(
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
pipeline: str,
|
||||
layout: str,
|
||||
gpu_target: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
@@ -348,6 +374,24 @@ def is_tile_config_valid(
|
||||
logging.debug(f"LDS validation failed: {lds_error}")
|
||||
return False
|
||||
|
||||
# Validate whole workgroup cover configuration
|
||||
wr_cover_valid, wg_cover_error = validate_whole_wg_cover_configuration(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
layout,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
)
|
||||
if not wr_cover_valid:
|
||||
logging.debug(
|
||||
f"Whole workgroup cover configuration validation failed: {wg_cover_error}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
|
||||
warp_tile_m,
|
||||
@@ -363,3 +407,209 @@ def is_tile_config_valid(
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# [TODO] Handle this while moving code to commons Add more datatype to this function if needed
|
||||
def get_dtype_string(datatype: str) -> str:
|
||||
"""Get C++ type string for datatype"""
|
||||
dtype_map = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float",
|
||||
"fp64": "double",
|
||||
}
|
||||
return dtype_map.get(datatype, "float")
|
||||
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
|
||||
def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
|
||||
"""
|
||||
code = str(layout_code).strip().lower()
|
||||
|
||||
a_layout = LAYOUT_MAP[code[0]]
|
||||
b_layout = LAYOUT_MAP[code[1]]
|
||||
c_layout = LAYOUT_MAP[code[2]]
|
||||
return a_layout, b_layout, c_layout
|
||||
|
||||
|
||||
def validate_whole_wg_cover_configuration(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
layout,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
) -> Tuple[bool, str]:
|
||||
# Validate whole workgroup cover configuration
|
||||
|
||||
warp_size = 64
|
||||
NumWarps = warp_m * warp_n * warp_k
|
||||
BlockSize = NumWarps * warp_size
|
||||
|
||||
XPerTile = 0
|
||||
YPerTile = 0
|
||||
vector_load_size = 0
|
||||
|
||||
# A matrix validation
|
||||
if layout[0] == "r":
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_m
|
||||
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, a_datatype, tile_m, tile_k
|
||||
)
|
||||
|
||||
elif layout[0] == "c":
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, a_datatype, tile_m, tile_m
|
||||
)
|
||||
|
||||
# Validate distribution
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_m
|
||||
|
||||
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
|
||||
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
|
||||
)
|
||||
|
||||
if not wg_cover_core_valid:
|
||||
print("I am here 1")
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
|
||||
)
|
||||
return False, wg_cover_core_error
|
||||
|
||||
XPerTile = tile_m
|
||||
YPerTile = tile_k
|
||||
|
||||
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
|
||||
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
|
||||
)
|
||||
|
||||
if not wg_cover_core_valid:
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix A: {wg_cover_core_error}"
|
||||
)
|
||||
return False, wg_cover_core_error
|
||||
|
||||
# B matrix validation
|
||||
if layout[1] == "r":
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, b_datatype, tile_n, tile_n
|
||||
)
|
||||
|
||||
# Validate distribution
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_n
|
||||
|
||||
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
|
||||
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
|
||||
)
|
||||
|
||||
if not wg_cover_core_valid:
|
||||
print("I am here 3")
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
|
||||
)
|
||||
return False, wg_cover_core_error
|
||||
|
||||
XPerTile = tile_n
|
||||
YPerTile = tile_k
|
||||
|
||||
elif layout[1] == "c":
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_n
|
||||
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, b_datatype, tile_n, tile_k
|
||||
)
|
||||
|
||||
wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
|
||||
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
|
||||
)
|
||||
if not wg_cover_core_valid:
|
||||
print("I am here 4")
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}"
|
||||
)
|
||||
return False, wg_cover_core_error
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def wg_cover_core_validation(
|
||||
XPerTile: int,
|
||||
YPerTile: int,
|
||||
BlockSize: int,
|
||||
vector_load_size: int,
|
||||
warp_size: int,
|
||||
) -> Tuple[bool, str]:
|
||||
if XPerTile % vector_load_size != 0:
|
||||
return False
|
||||
|
||||
num_warps = BlockSize / warp_size
|
||||
LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
|
||||
|
||||
X1 = LargestVec if vector_load_size > LargestVec else vector_load_size
|
||||
X0 = XPerTile / X1
|
||||
Y1 = warp_size // X0
|
||||
|
||||
if X0 * Y1 != warp_size:
|
||||
return False, ""
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def get_global_vector_load_size(
|
||||
BlockSize: int,
|
||||
KPerBlock: int,
|
||||
DataType: str,
|
||||
MNPerBlock: int,
|
||||
XPerTile: int,
|
||||
) -> int:
|
||||
elements_per_thread = MNPerBlock * KPerBlock / BlockSize
|
||||
PackedSize = 1
|
||||
|
||||
if (
|
||||
XPerTile % (PackedSize * 32 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0
|
||||
and PackedSize == 2
|
||||
):
|
||||
return PackedSize * 32 / element_size(DataType)
|
||||
elif (
|
||||
XPerTile % (PackedSize * 16 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 16 / element_size(DataType)) == 0
|
||||
):
|
||||
return int(PackedSize * 16 / element_size(DataType))
|
||||
|
||||
elif (
|
||||
XPerTile % (PackedSize * 8 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 8 / element_size(DataType)) == 0
|
||||
):
|
||||
return int(PackedSize * 8 / element_size(DataType))
|
||||
elif (
|
||||
element_size(DataType) >= PackedSize * 4
|
||||
and XPerTile % (PackedSize * 4 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 4 / element_size(DataType)) == 0
|
||||
):
|
||||
return int(PackedSize * 4 / element_size(DataType))
|
||||
elif (
|
||||
element_size(DataType) >= PackedSize * 2
|
||||
and XPerTile % (PackedSize * 2 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 2 / element_size(DataType)) == 0
|
||||
):
|
||||
return int(PackedSize * 2 / element_size(DataType))
|
||||
else:
|
||||
return PackedSize
|
||||
@@ -1,105 +0,0 @@
|
||||
{
|
||||
"problem": {
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
4,
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle",
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
{
|
||||
"problem": {
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128 ]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,4 @@
|
||||
{
|
||||
"problem": {
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
@@ -101,5 +99,6 @@
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 1
|
||||
}
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
{
|
||||
"problem": {
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256,
|
||||
128,
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
256,
|
||||
128,
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
256,
|
||||
128,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle",
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,31 +1,28 @@
|
||||
{
|
||||
"problem": {
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128,
|
||||
256
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
192
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
128
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
1
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
@@ -35,36 +32,33 @@
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16, 32
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16, 32
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16, 32
|
||||
8
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"mem"
|
||||
"compv4"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default",
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
@@ -85,9 +79,9 @@
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 1
|
||||
}
|
||||
@@ -273,48 +273,6 @@ class GemmBenchmark:
|
||||
print(f"Error reading JSON file {json_file}: {e}")
|
||||
return None
|
||||
|
||||
def parse_benchmark_output(self, output: str) -> Optional[Dict]:
|
||||
"""Parse the benchmark output format - extract JSON directly"""
|
||||
try:
|
||||
# Find JSON block between asterisk markers
|
||||
lines = output.split("\n")
|
||||
json_start = -1
|
||||
json_end = -1
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith("{"):
|
||||
json_start = i
|
||||
elif line.strip().endswith("}") and json_start != -1:
|
||||
json_end = i
|
||||
break
|
||||
|
||||
if json_start != -1 and json_end != -1:
|
||||
json_text = "\n".join(lines[json_start : json_end + 1])
|
||||
data = json.loads(json_text)
|
||||
|
||||
# Return the complete JSON data as-is, just add some convenience fields
|
||||
result = data.copy()
|
||||
if "perf_result" in data:
|
||||
perf = data["perf_result"]
|
||||
# Add convenience fields for backward compatibility
|
||||
result["time_ms"] = perf.get("latency(ms)", 0)
|
||||
result["tflops"] = perf.get("tflops(TFlops)", 0)
|
||||
result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0)
|
||||
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
if self.verbose:
|
||||
print(f"Failed to parse JSON: {e}")
|
||||
print(f"Output was: {output[:200]}...")
|
||||
return None
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print(f"Error parsing output: {e}")
|
||||
return None
|
||||
|
||||
def benchmark_problem_size(
|
||||
self,
|
||||
kernels: List[Path],
|
||||
|
||||
@@ -30,9 +30,9 @@ inline auto create_args(int argc, char* argv[])
|
||||
.insert("stride_c", "0", "The stride value for tensor C. Default is 0.")
|
||||
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
|
||||
.insert("verify",
|
||||
"0",
|
||||
"2",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
|
||||
"for validation on GPU. Default is 0, no validation.")
|
||||
"for validation on GPU. Default is 2, GPU validation.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Whether output kernel instance information or not. Possible values are true or "
|
||||
@@ -75,7 +75,7 @@ inline auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser)
|
||||
void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines ADataType, BDataType, AccDataType, CDataType
|
||||
@@ -149,7 +149,7 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
benchmark_gemm_single(parser);
|
||||
benchmark_single(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
//[TODO] This can be moved to commons
|
||||
// DataTypeTraits for all supported types
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
@@ -97,49 +98,3 @@ struct KernelTraits
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to extract traits from kernel name
|
||||
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
|
||||
{
|
||||
KernelTraits traits;
|
||||
|
||||
// Extract pipeline
|
||||
if(kernel_name.find("compv3") != std::string::npos)
|
||||
{
|
||||
traits.pipeline = "compv3";
|
||||
}
|
||||
else if(kernel_name.find("compv4") != std::string::npos)
|
||||
{
|
||||
traits.pipeline = "compv4";
|
||||
}
|
||||
else if(kernel_name.find("mem") != std::string::npos)
|
||||
{
|
||||
traits.pipeline = "mem";
|
||||
}
|
||||
|
||||
// Extract scheduler
|
||||
if(kernel_name.find("interwave") != std::string::npos)
|
||||
{
|
||||
traits.scheduler = "interwave";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.scheduler = "intrawave";
|
||||
}
|
||||
|
||||
// Extract epilogue
|
||||
if(kernel_name.find("default") != std::string::npos &&
|
||||
kernel_name.find("default_") == std::string::npos)
|
||||
{
|
||||
traits.epilogue = "default";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.epilogue = "cshuffle";
|
||||
}
|
||||
|
||||
// Padding flags would need to be extracted from the kernel configuration
|
||||
// For now, we'll leave them as false
|
||||
|
||||
return traits;
|
||||
}
|
||||
|
||||
@@ -8,8 +8,12 @@ import multiprocessing
|
||||
import concurrent.futures
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Optional
|
||||
from validation_utils import is_tile_config_valid, is_trait_combination_valid
|
||||
from commons.validation_utils import (
|
||||
is_tile_config_valid,
|
||||
is_trait_combination_valid,
|
||||
get_dtype_string,
|
||||
get_abc_layouts,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@@ -29,149 +33,150 @@ class GemmKernelBuilder:
|
||||
if config_json and os.path.exists(config_json):
|
||||
with open(config_json, "r") as f:
|
||||
self.config = json.load(f)
|
||||
else:
|
||||
self.config = self._get_default_config()
|
||||
|
||||
def _get_default_config(self):
|
||||
"""Return default configuration if no config file is provided"""
|
||||
# Define base tile configurations that work for all layouts
|
||||
base_fp16_configs = [
|
||||
{
|
||||
"tile_m": 256,
|
||||
"tile_n": 256,
|
||||
"tile_k": 32,
|
||||
"warp_m": 2,
|
||||
"warp_n": 2,
|
||||
"warp_k": 1,
|
||||
"warp_tile_m": 32,
|
||||
"warp_tile_n": 32,
|
||||
"warp_tile_k": 32,
|
||||
},
|
||||
{
|
||||
"tile_m": 256,
|
||||
"tile_n": 128,
|
||||
"tile_k": 32,
|
||||
"warp_m": 2,
|
||||
"warp_n": 2,
|
||||
"warp_k": 1,
|
||||
"warp_tile_m": 32,
|
||||
"warp_tile_n": 32,
|
||||
"warp_tile_k": 16,
|
||||
},
|
||||
]
|
||||
def write_kernel_list(self):
|
||||
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
|
||||
# Get configurations using comprehensive validation
|
||||
tile_configs = self._get_tile_configs(fast_mode=False)
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
|
||||
base_fp8_configs = [
|
||||
{
|
||||
"tile_m": 256,
|
||||
"tile_n": 256,
|
||||
"tile_k": 32,
|
||||
"warp_m": 4,
|
||||
"warp_n": 1,
|
||||
"warp_k": 1,
|
||||
"warp_tile_m": 32,
|
||||
"warp_tile_n": 32,
|
||||
"warp_tile_k": 32,
|
||||
},
|
||||
{
|
||||
"tile_m": 256,
|
||||
"tile_n": 128,
|
||||
"tile_k": 32,
|
||||
"warp_m": 1,
|
||||
"warp_n": 4,
|
||||
"warp_k": 1,
|
||||
"warp_tile_m": 16,
|
||||
"warp_tile_n": 16,
|
||||
"warp_tile_k": 32,
|
||||
},
|
||||
]
|
||||
kernel_list = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
(
|
||||
pipeline,
|
||||
epilogue,
|
||||
scheduler,
|
||||
pad_m,
|
||||
pad_n,
|
||||
pad_k,
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
# Create configurations for all supported layouts
|
||||
all_layouts = ["rcr", "rrr", "ccr", "crr"]
|
||||
tile_configs = {}
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
||||
|
||||
for datatype, base_configs in [
|
||||
("fp16", base_fp16_configs),
|
||||
("fp8", base_fp8_configs),
|
||||
]:
|
||||
tile_configs[datatype] = {}
|
||||
for layout in all_layouts:
|
||||
tile_configs[datatype][layout] = base_configs
|
||||
# Create tile configuration string
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
return {
|
||||
"tile_configs": tile_configs,
|
||||
"traits": {
|
||||
"pipelines": ["mem", "compv3", "compv4"],
|
||||
"epilogues": ["default", "cshuffle"],
|
||||
"schedulers": ["intrawave", "interwave"],
|
||||
},
|
||||
"structured_sparsity": ["false"],
|
||||
"padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]},
|
||||
"persistent": ["false"],
|
||||
}
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
kernel_list.append(
|
||||
{
|
||||
"name": kernel_name,
|
||||
"tile_config": tile_config,
|
||||
"trait_combo": trait_combo,
|
||||
}
|
||||
)
|
||||
|
||||
# Write kernel count
|
||||
with open(self.working_path / "gemm_kernel_count.txt", "w") as f:
|
||||
f.write(str(len(kernel_list)))
|
||||
|
||||
# Write kernel list
|
||||
with open(self.working_path / "gemm_kernel_list.txt", "w") as f:
|
||||
for kernel in kernel_list:
|
||||
# Format: kernel_name|tile_config|trait_combo
|
||||
tile_config = kernel["tile_config"]
|
||||
trait_combo = kernel["trait_combo"]
|
||||
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
trait_str = (
|
||||
f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
|
||||
+ "_".join(str(x) for x in trait_combo[3:])
|
||||
)
|
||||
|
||||
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
|
||||
|
||||
print(f"Listed {len(kernel_list)} kernel configurations")
|
||||
|
||||
def _get_tile_configs(self, fast_mode=False):
|
||||
"""Get tile configurations for the current datatype and layout"""
|
||||
if "tile_configs" in self.config:
|
||||
# Old format
|
||||
return (
|
||||
self.config["tile_configs"].get(self.datatype, {}).get(self.layout, [])
|
||||
tile_config = self.config["tile_config"]
|
||||
|
||||
# Generate values in the config if default range is given
|
||||
if tile_config.get("tile_m").get("values") is None:
|
||||
tile_config.get("tile_m")["values"] = self._generate_values(
|
||||
tile_config.get("tile_m").get("min"),
|
||||
tile_config.get("tile_m").get("max"),
|
||||
tile_config.get("tile_m").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_n").get("values") is None:
|
||||
tile_config.get("tile_n")["values"] = self._generate_values(
|
||||
tile_config.get("tile_n").get("min"),
|
||||
tile_config.get("tile_n").get("max"),
|
||||
tile_config.get("tile_n").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_k").get("values") is None:
|
||||
tile_config.get("tile_k")["values"] = self._generate_values(
|
||||
tile_config.get("tile_k").get("min"),
|
||||
tile_config.get("tile_k").get("max"),
|
||||
tile_config.get("tile_k").get("step"),
|
||||
)
|
||||
elif "tile_config" in self.config:
|
||||
# New format - generate combinations from individual parameter values
|
||||
tile_config = self.config["tile_config"]
|
||||
|
||||
# Get all possible values for each parameter
|
||||
tile_m_values = tile_config.get("tile_m", {}).get("values", [256])
|
||||
tile_n_values = tile_config.get("tile_n", {}).get("values", [256])
|
||||
tile_k_values = tile_config.get("tile_k", {}).get("values", [32])
|
||||
warp_m_values = tile_config.get("warp_m", {}).get("values", [2])
|
||||
warp_n_values = tile_config.get("warp_n", {}).get("values", [2])
|
||||
warp_k_values = tile_config.get("warp_k", {}).get("values", [1])
|
||||
warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32])
|
||||
warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32])
|
||||
warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32])
|
||||
# Get all possible values for each parameter
|
||||
tile_m_values = tile_config.get("tile_m").get("values")
|
||||
tile_n_values = tile_config.get("tile_n").get("values")
|
||||
tile_k_values = tile_config.get("tile_k").get("values")
|
||||
warp_m_values = tile_config.get("warp_m").get("values")
|
||||
warp_n_values = tile_config.get("warp_n").get("values")
|
||||
warp_k_values = tile_config.get("warp_k").get("values")
|
||||
warp_tile_m_values = tile_config.get("warp_tile_m").get("values")
|
||||
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
|
||||
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
|
||||
|
||||
# Generate all combinations
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
for tile_n in tile_n_values:
|
||||
for tile_k in tile_k_values:
|
||||
for warp_m in warp_m_values:
|
||||
for warp_n in warp_n_values:
|
||||
for warp_k in warp_k_values:
|
||||
for warp_tile_m in warp_tile_m_values:
|
||||
for warp_tile_n in warp_tile_n_values:
|
||||
for warp_tile_k in warp_tile_k_values:
|
||||
# Validate configuration
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
fast_mode=fast_mode,
|
||||
):
|
||||
configs.append(
|
||||
{
|
||||
"tile_m": tile_m,
|
||||
"tile_n": tile_n,
|
||||
"tile_k": tile_k,
|
||||
"warp_m": warp_m,
|
||||
"warp_n": warp_n,
|
||||
"warp_k": warp_k,
|
||||
"warp_tile_m": warp_tile_m,
|
||||
"warp_tile_n": warp_tile_n,
|
||||
"warp_tile_k": warp_tile_k,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
else:
|
||||
# Fallback to default
|
||||
return []
|
||||
# Generate all combinations
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
for tile_n in tile_n_values:
|
||||
for tile_k in tile_k_values:
|
||||
for warp_m in warp_m_values:
|
||||
for warp_n in warp_n_values:
|
||||
for warp_k in warp_k_values:
|
||||
for warp_tile_m in warp_tile_m_values:
|
||||
for warp_tile_n in warp_tile_n_values:
|
||||
for warp_tile_k in warp_tile_k_values:
|
||||
# Validate configuration
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
fast_mode=fast_mode,
|
||||
):
|
||||
configs.append(
|
||||
{
|
||||
"tile_m": tile_m,
|
||||
"tile_n": tile_n,
|
||||
"tile_k": tile_k,
|
||||
"warp_m": warp_m,
|
||||
"warp_n": warp_n,
|
||||
"warp_k": warp_k,
|
||||
"warp_tile_m": warp_tile_m,
|
||||
"warp_tile_n": warp_tile_n,
|
||||
"warp_tile_k": warp_tile_k,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
def _generate_values(self, min_val, max_val, step):
|
||||
"""Generate a list of values from min to max with the given step"""
|
||||
values = []
|
||||
val = min_val
|
||||
while val <= max_val:
|
||||
values.append(val)
|
||||
val += step
|
||||
return values
|
||||
|
||||
def _validate_tile_config(
|
||||
self,
|
||||
@@ -184,7 +189,7 @@ class GemmKernelBuilder:
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
pipeline="mem", # Default pipeline for validation
|
||||
pipeline="compv4", # Default pipeline for validation
|
||||
fast_mode=False, # Add fast mode option
|
||||
):
|
||||
"""Validate that tile configuration is reasonable"""
|
||||
@@ -213,6 +218,8 @@ class GemmKernelBuilder:
|
||||
b_datatype = self.datatype
|
||||
c_datatype = self.datatype
|
||||
|
||||
layout = self.layout
|
||||
|
||||
# Special handling for certain data types
|
||||
if self.datatype in ["fp8", "bf8"]:
|
||||
c_datatype = "fp16"
|
||||
@@ -232,125 +239,50 @@ class GemmKernelBuilder:
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
pipeline,
|
||||
layout,
|
||||
self.gpu_target,
|
||||
)
|
||||
|
||||
def _generate_trait_combinations(self):
|
||||
"""Generate all combinations of traits"""
|
||||
if "traits" in self.config:
|
||||
# Old format
|
||||
traits = self.config["traits"]
|
||||
pipelines = traits["pipelines"]
|
||||
epilogues = traits["epilogues"]
|
||||
schedulers = traits["schedulers"]
|
||||
|
||||
padding = self.config["padding"]
|
||||
persistent = self.config["persistent"]
|
||||
trait_config = self.config["trait_config"]
|
||||
|
||||
all_combinations = list(
|
||||
itertools.product(
|
||||
pipelines,
|
||||
epilogues,
|
||||
schedulers,
|
||||
padding["pad_m"],
|
||||
padding["pad_n"],
|
||||
padding["pad_k"],
|
||||
persistent,
|
||||
pipelines = trait_config.get("pipeline").get("values")
|
||||
epilogues = trait_config.get("epilogue").get("values")
|
||||
schedulers = trait_config.get("scheduler").get("values")
|
||||
pad_m_values = trait_config.get("pad_m").get("values")
|
||||
pad_n_values = trait_config.get("pad_n").get("values")
|
||||
pad_k_values = trait_config.get("pad_k").get("values")
|
||||
persistent_values = trait_config.get("persistent").get("values")
|
||||
|
||||
all_combinations = list(
|
||||
itertools.product(
|
||||
pipelines,
|
||||
epilogues,
|
||||
schedulers,
|
||||
pad_m_values,
|
||||
pad_n_values,
|
||||
pad_k_values,
|
||||
persistent_values,
|
||||
)
|
||||
)
|
||||
|
||||
# Filter out unsupported trait combinations
|
||||
combinations = []
|
||||
for combo in all_combinations:
|
||||
pipeline, epilogue, scheduler = combo[:3]
|
||||
if is_trait_combination_valid(pipeline, epilogue, scheduler):
|
||||
combinations.append(combo)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
|
||||
)
|
||||
)
|
||||
|
||||
# Filter out unsupported trait combinations
|
||||
combinations = []
|
||||
for combo in all_combinations:
|
||||
pipeline, epilogue, scheduler = combo[:3]
|
||||
if is_trait_combination_valid(pipeline, epilogue, scheduler):
|
||||
combinations.append(combo)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
|
||||
)
|
||||
|
||||
elif "trait_config" in self.config:
|
||||
# New format
|
||||
trait_config = self.config["trait_config"]
|
||||
|
||||
pipelines = trait_config.get("pipeline", {}).get("values", ["mem"])
|
||||
epilogues = trait_config.get("epilogue", {}).get("values", ["default"])
|
||||
schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"])
|
||||
pad_m_values = trait_config.get("pad_m", {}).get("values", [False])
|
||||
pad_n_values = trait_config.get("pad_n", {}).get("values", [False])
|
||||
pad_k_values = trait_config.get("pad_k", {}).get("values", [False])
|
||||
persistent_values = trait_config.get("persistent", {}).get(
|
||||
"values", [False]
|
||||
)
|
||||
|
||||
all_combinations = list(
|
||||
itertools.product(
|
||||
pipelines,
|
||||
epilogues,
|
||||
schedulers,
|
||||
pad_m_values,
|
||||
pad_n_values,
|
||||
pad_k_values,
|
||||
persistent_values,
|
||||
)
|
||||
)
|
||||
|
||||
# Filter out unsupported trait combinations
|
||||
combinations = []
|
||||
for combo in all_combinations:
|
||||
pipeline, epilogue, scheduler = combo[:3]
|
||||
if is_trait_combination_valid(pipeline, epilogue, scheduler):
|
||||
combinations.append(combo)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
|
||||
)
|
||||
else:
|
||||
# Fallback to minimal default
|
||||
combinations = [("mem", "default", "intrawave", False, False, False, False)]
|
||||
|
||||
return combinations
|
||||
|
||||
def _get_dtype_string(self):
|
||||
"""Get C++ type string for datatype"""
|
||||
dtype_map = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float",
|
||||
"fp64": "double",
|
||||
}
|
||||
return dtype_map.get(self.datatype, "float")
|
||||
|
||||
_LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
def _get_abc_layouts(self, layout_code: Optional[str] = None):
|
||||
"""
|
||||
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
|
||||
If layout_code is None, use self.layout.
|
||||
"""
|
||||
if layout_code is None:
|
||||
# fall back to the instance field
|
||||
layout_code = getattr(self, "layout", "")
|
||||
|
||||
code = str(layout_code).strip().lower()
|
||||
|
||||
if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code):
|
||||
raise ValueError(
|
||||
f"Invalid layout '{layout_code}'. "
|
||||
"Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)."
|
||||
)
|
||||
|
||||
a_layout = self._LAYOUT_MAP[code[0]]
|
||||
b_layout = self._LAYOUT_MAP[code[1]]
|
||||
c_layout = self._LAYOUT_MAP[code[2]]
|
||||
return a_layout, b_layout, c_layout
|
||||
|
||||
def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
|
||||
def _generate_kernel_instance(
|
||||
self, tile_config, trait_combo, k_block_per_cu, is_header=True
|
||||
):
|
||||
"""Generate a single kernel instance"""
|
||||
(
|
||||
pipeline,
|
||||
@@ -383,6 +315,13 @@ class GemmKernelBuilder:
|
||||
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
|
||||
}
|
||||
|
||||
# Map pipeline names to base pipeline for hot loop detection
|
||||
base_pipeline_map = {
|
||||
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
|
||||
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
}
|
||||
|
||||
# Map scheduler names to the correct enum values
|
||||
scheduler_type_map = {
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
@@ -392,23 +331,14 @@ class GemmKernelBuilder:
|
||||
|
||||
# Determine accumulator type based on datatype
|
||||
acc_type = "float"
|
||||
if self.datatype in ["int8", "int4"]:
|
||||
acc_type = "ck_tile::int32_t"
|
||||
|
||||
# Determine output type
|
||||
c_type = self._get_dtype_string()
|
||||
c_type = self.datatype
|
||||
if self.datatype in ["fp8", "bf8"]:
|
||||
c_type = "ck_tile::fp16_t"
|
||||
c_type = "fp16"
|
||||
|
||||
# Determine layouts based on self.layout
|
||||
a_layout, b_layout, c_layout = self._get_abc_layouts()
|
||||
|
||||
# Map pipeline names to base pipeline for hot loop detection
|
||||
base_pipeline_map = {
|
||||
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
|
||||
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
}
|
||||
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
|
||||
|
||||
# Generate kernel instance code using the correct API
|
||||
pragma_line = "#pragma once\n" if is_header else ""
|
||||
@@ -425,10 +355,10 @@ class GemmKernelBuilder:
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
using ADataType = {self._get_dtype_string()};
|
||||
using BDataType = {self._get_dtype_string()};
|
||||
using ADataType = {get_dtype_string(self.datatype)};
|
||||
using BDataType = {get_dtype_string(self.datatype)};
|
||||
using AccDataType = {acc_type};
|
||||
using CDataType = {c_type};
|
||||
using CDataType = {get_dtype_string(c_type)};
|
||||
|
||||
using ALayout = {a_layout};
|
||||
using BLayout = {b_layout};
|
||||
@@ -484,7 +414,7 @@ struct SelectedKernel {{
|
||||
Traits>;
|
||||
|
||||
// Base pipeline for hot loop detection
|
||||
using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseGemmPipelineAgBgCrMem")}<GemmPipelineProblem>;
|
||||
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
|
||||
|
||||
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
const ck_tile::index_t k_grain = args.k_batch * TileK;
|
||||
@@ -498,7 +428,7 @@ struct SelectedKernel {{
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Intrawave")};
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
|
||||
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
@@ -514,7 +444,7 @@ struct SelectedKernel {{
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}<UniversalGemmProblem>;
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
|
||||
|
||||
// Epilogue
|
||||
"""
|
||||
@@ -589,7 +519,7 @@ struct SelectedKernel {{
|
||||
}}
|
||||
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
@@ -616,9 +546,13 @@ struct SelectedKernel {{
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
return kernel_name, instance_code
|
||||
|
||||
def run(self, num_workers=None):
|
||||
"""Run the builder to generate individual kernel files"""
|
||||
# Generate individual kernel files
|
||||
self.generate_individual(num_workers)
|
||||
|
||||
def generate_individual(self, num_workers=None):
|
||||
"""Generate individual kernel files for separate compilation with parallel processing"""
|
||||
if num_workers is None:
|
||||
@@ -628,6 +562,7 @@ struct SelectedKernel {{
|
||||
|
||||
tile_configs = self._get_tile_configs()
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
k_block_per_cu = self.config.get("k_block_per_cu")
|
||||
|
||||
# Prepare work items for parallel processing
|
||||
work_items = []
|
||||
@@ -637,6 +572,7 @@ struct SelectedKernel {{
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
k_block_per_cu,
|
||||
self.working_path,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
@@ -723,83 +659,17 @@ struct SelectedKernel {{
|
||||
with open(self.working_path / "gemm_individual_targets.cmake", "w") as f:
|
||||
f.write(cmake_code)
|
||||
|
||||
def write_kernel_list(self):
|
||||
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
|
||||
# Get configurations using comprehensive validation
|
||||
tile_configs = self._get_tile_configs(fast_mode=False)
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
|
||||
kernel_list = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
(
|
||||
pipeline,
|
||||
epilogue,
|
||||
scheduler,
|
||||
pad_m,
|
||||
pad_n,
|
||||
pad_k,
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
||||
|
||||
# Create tile configuration string
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
kernel_list.append(
|
||||
{
|
||||
"name": kernel_name,
|
||||
"tile_config": tile_config,
|
||||
"trait_combo": trait_combo,
|
||||
}
|
||||
)
|
||||
|
||||
# Write kernel count
|
||||
with open(self.working_path / "gemm_kernel_count.txt", "w") as f:
|
||||
f.write(str(len(kernel_list)))
|
||||
|
||||
# Write kernel list
|
||||
with open(self.working_path / "gemm_kernel_list.txt", "w") as f:
|
||||
for kernel in kernel_list:
|
||||
# Format: kernel_name|tile_config|trait_combo
|
||||
tile_config = kernel["tile_config"]
|
||||
trait_combo = kernel["trait_combo"]
|
||||
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
trait_str = (
|
||||
f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
|
||||
+ "_".join(str(x) for x in trait_combo[3:])
|
||||
)
|
||||
|
||||
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
|
||||
|
||||
print(f"Listed {len(kernel_list)} kernel configurations")
|
||||
|
||||
def run(self, num_workers=None):
|
||||
"""Run the builder to generate individual kernel files"""
|
||||
# Generate individual kernel files
|
||||
self.generate_individual(num_workers)
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
tile_config, trait_combo, working_path, datatype, layout = work_item
|
||||
tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GemmKernelBuilder(working_path, datatype, layout)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo
|
||||
tile_config, trait_combo, k_block_per_cu
|
||||
)
|
||||
|
||||
# Create simplified filename without the "gemm_" prefix
|
||||
@@ -832,7 +702,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
required=True,
|
||||
choices=["fp16", "fp8", "bf16", "fp32", "fp64"],
|
||||
choices=["fp16", "fp8", "bf16", "bf8"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -846,7 +716,9 @@ def main():
|
||||
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_individual", action="store_true", help="Generate individual kernel files"
|
||||
"--gen_all_individual",
|
||||
action="store_true",
|
||||
help="Generate individual kernel files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_single", action="store_true", help="Generate a single kernel file"
|
||||
@@ -866,13 +738,27 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
|
||||
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
|
||||
)
|
||||
|
||||
layout_parts = args.layout.lower()
|
||||
assert len(layout_parts) == 3, (
|
||||
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
|
||||
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
# Create builder
|
||||
builder = GemmKernelBuilder(
|
||||
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
# Fast listing mode - just write kernel list without generating files
|
||||
builder.write_kernel_list()
|
||||
elif args.gen_single:
|
||||
# Generate a single kernel file
|
||||
@@ -911,9 +797,11 @@ def main():
|
||||
trait_parts[6] == "True", # persistent
|
||||
)
|
||||
|
||||
k_block_per_cu = builder.config.get("k_block_per_cu")
|
||||
|
||||
# Generate the kernel
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo
|
||||
tile_config, trait_combo, k_block_per_cu
|
||||
)
|
||||
|
||||
# Write the file
|
||||
@@ -927,12 +815,12 @@ def main():
|
||||
|
||||
print(f"Generated {header_file}")
|
||||
|
||||
elif args.gen_individual:
|
||||
elif args.gen_all_individual:
|
||||
# Generate all individual kernel files
|
||||
builder.run(args.num_workers)
|
||||
else:
|
||||
parser.error(
|
||||
"Must specify one of: --list_kernels, --gen_individual, or --gen_single"
|
||||
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "benchmark_gemm.hpp"
|
||||
#include "gemm_benchmark.hpp"
|
||||
|
||||
class GemmProfiler
|
||||
{
|
||||
|
||||
@@ -1,231 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Handles loading, parsing, and validation of JSON configuration parameters.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, Tuple, Type, Dict
|
||||
import json
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumConfigParam:
|
||||
"""Represents an enumeration-type configuration parameter"""
|
||||
|
||||
values: List[Union[int, str, bool]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeConfigParam:
|
||||
"""Represents a numeric range-type configuration parameter"""
|
||||
|
||||
min: int
|
||||
max: int
|
||||
step: int
|
||||
exclude: Optional[List[int]]
|
||||
|
||||
def generate_candidates(self) -> List[int]:
|
||||
"""Generates valid candidates after applying range constraints"""
|
||||
|
||||
if self.min > self.max:
|
||||
raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
|
||||
if self.step <= 0:
|
||||
raise ValueError(f"Step must be positive, got {self.step}")
|
||||
|
||||
candidates = list(range(self.min, self.max + 1, self.step))
|
||||
|
||||
if hasattr(self, "exclude") and self.exclude:
|
||||
if not isinstance(self.exclude, list):
|
||||
raise TypeError("exclude must be list type")
|
||||
exclude_set = set(self.exclude)
|
||||
candidates = [x for x in candidates if x not in exclude_set]
|
||||
|
||||
if not candidates:
|
||||
raise ValueError(
|
||||
f"No valid candidates for range [{self.min}-{self.max}] "
|
||||
f"with step {self.step} and excludes {self.exclude}"
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProblemConfig:
|
||||
"""configuration class for problem parameter."""
|
||||
|
||||
datatypes: Tuple[EnumConfigParam, ...]
|
||||
layouts: Tuple[EnumConfigParam, ...]
|
||||
|
||||
@property
|
||||
def datatype_map(self) -> Dict[str, str]:
|
||||
"""Get datatype as a key-value map."""
|
||||
return {
|
||||
"matrix_a": self.datatypes[0].values[0],
|
||||
"matrix_b": self.datatypes[1].values[0],
|
||||
"matrix_c": self.datatypes[2].values[0],
|
||||
}
|
||||
|
||||
@property
|
||||
def layout_map(self) -> Dict[str, str]:
|
||||
"""Get layout as a key-value map."""
|
||||
return {
|
||||
"matrix_a": self.layouts[0].values[0],
|
||||
"matrix_b": self.layouts[1].values[0],
|
||||
"matrix_c": self.layouts[2].values[0],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Configuration class for tile parameter."""
|
||||
|
||||
tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Configuration class for kernel traits."""
|
||||
|
||||
pipeline: EnumConfigParam
|
||||
scheduler: EnumConfigParam
|
||||
epilogue: EnumConfigParam
|
||||
pad_m: EnumConfigParam
|
||||
pad_n: EnumConfigParam
|
||||
pad_k: EnumConfigParam
|
||||
persistent: EnumConfigParam
|
||||
|
||||
|
||||
@dataclass
|
||||
class GemmConfig:
|
||||
"""Main configuration class for GEMM operations"""
|
||||
|
||||
problem: ProblemConfig
|
||||
tile_config: TileConfig
|
||||
trait_config: TraitConfig
|
||||
|
||||
@classmethod
|
||||
def from_json(
|
||||
cls: Type["GemmConfig"], filepath: str, datatype: str, layout: str
|
||||
) -> "GemmConfig":
|
||||
"""JSON configuration loader with validation controls"""
|
||||
config_path = Path(filepath)
|
||||
|
||||
try:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file {filepath} not found")
|
||||
|
||||
with config_path.open("r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
a_type = datatype
|
||||
b_type = datatype
|
||||
c_type = datatype
|
||||
if b_type == "int4":
|
||||
a_type = "fp16"
|
||||
if b_type in ["bf8", "fp8", "int4"]:
|
||||
c_type = "fp16"
|
||||
|
||||
layout_parts = layout.lower()
|
||||
assert len(layout_parts) == 3, (
|
||||
f"Invalid layout string: {layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ("r", "c"), (
|
||||
f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[1] in ("r", "c"), (
|
||||
f"Invalid matrix_a layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
a_layout = layout_parts[0]
|
||||
b_layout = layout_parts[1]
|
||||
c_layout = layout_parts[2]
|
||||
|
||||
# Parse problem config
|
||||
# TODO: Not reading datatype information from json file.
|
||||
problem = ProblemConfig(
|
||||
datatypes=(
|
||||
EnumConfigParam(values=[a_type]),
|
||||
EnumConfigParam(values=[b_type]),
|
||||
EnumConfigParam(values=[c_type]),
|
||||
),
|
||||
layouts=(
|
||||
EnumConfigParam(values=[a_layout]),
|
||||
EnumConfigParam(values=[b_layout]),
|
||||
EnumConfigParam(values=[c_layout]),
|
||||
),
|
||||
)
|
||||
|
||||
# Parse tile config
|
||||
def create_param(param_dict):
|
||||
if "values" in param_dict:
|
||||
return EnumConfigParam(values=param_dict["values"])
|
||||
else:
|
||||
return RangeConfigParam(
|
||||
min=param_dict["min"],
|
||||
max=param_dict["max"],
|
||||
step=param_dict["step"],
|
||||
exclude=param_dict.get("exclude", []),
|
||||
)
|
||||
|
||||
tile_config = TileConfig(
|
||||
tile_m=create_param(config_dict["tile_config"]["tile_m"]),
|
||||
tile_n=create_param(config_dict["tile_config"]["tile_n"]),
|
||||
tile_k=create_param(config_dict["tile_config"]["tile_k"]),
|
||||
warp_m=create_param(config_dict["tile_config"]["warp_m"]),
|
||||
warp_n=create_param(config_dict["tile_config"]["warp_n"]),
|
||||
warp_k=create_param(config_dict["tile_config"]["warp_k"]),
|
||||
warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
|
||||
warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
|
||||
warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
|
||||
)
|
||||
|
||||
# Parse trait config
|
||||
trait_config = TraitConfig(
|
||||
pipeline=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pipeline"]["values"]
|
||||
),
|
||||
scheduler=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["scheduler"]["values"]
|
||||
),
|
||||
epilogue=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["epilogue"]["values"]
|
||||
),
|
||||
pad_m=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_m"]["values"]
|
||||
),
|
||||
pad_n=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_n"]["values"]
|
||||
),
|
||||
pad_k=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_k"]["values"]
|
||||
),
|
||||
persistent=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["persistent"]["values"]
|
||||
),
|
||||
)
|
||||
|
||||
return cls(
|
||||
problem=problem, tile_config=tile_config, trait_config=trait_config
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {str(e)}")
|
||||
except KeyError as e:
|
||||
raise KeyError(f"Missing required configuration field: {str(e)}")
|
||||
Reference in New Issue
Block a user