Ck tile engine gemm (#2982)

* Partial Progress : CK Tile Engine GEMM

* Partial Progress : CK Tile Engine GEMM

* Partial Progress : Working GEMM Code

* Partial Progress : Working GEMM Code

* Changing Jenkins to remove preshuffle

* Partial Progress : CK TILE ENGINE GEMM Debugging

* Partial Progress : Removing changes that are not GEMM

* Partial Progress : Validation of full block size in GEMM

* Changes in Jenkins to run only fp16 and bf16

* Addressing Review Comments

* Partial Progress : Addressing CI issues

* Partial Progress - Running GEMM for fp16,bf16 and rcr

* Clang

* Adding fp8 and bf8

* Adding fp8 and bf8

* Adding additional architecture

* Limited datatypes and layouts

* Adding k_block_per_cu in test config

* Changes to failing CI errors

* Changes to failing CI errors

* Validation for GEMM

* Adding Layout support

* Adding Validations

* Adding layout in jenkins

* Update on Jenkins

* Distribution validation for GEMM

* Resolving merge conflicts

* Solving merge conflicts
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-10-27 21:11:13 -05:00
committed by GitHub
parent b11f53a484
commit 7fc0a38e90
18 changed files with 504 additions and 987 deletions

View File

@@ -65,7 +65,7 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
# Create the executable
add_executable(${target_name}
EXCLUDE_FROM_ALL
${GEMM_SOURCE_DIR}/benchmark_gemm_single.cpp
${GEMM_SOURCE_DIR}/gemm_benchmark_single.cpp
${instance_header}
)
@@ -103,9 +103,9 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
list(GET trait_parts 1 epilogue)
list(GET trait_parts 2 scheduler)
add_dependencies(benchmark_gemm_${pipeline} ${target_name})
add_dependencies(benchmark_gemm_${epilogue} ${target_name})
add_dependencies(benchmark_gemm_${scheduler} ${target_name})
add_dependencies(benchmark_gemm_${pipeline}_pipeline ${target_name})
add_dependencies(benchmark_gemm_${epilogue}_epilogue ${target_name})
add_dependencies(benchmark_gemm_${scheduler}_scheduler ${target_name})
endfunction()
# Function to build individual GEMM targets
@@ -286,15 +286,15 @@ else()
set(GEMM_SCHEDULERS "intrawave;interwave")
foreach(pipeline IN LISTS GEMM_PIPELINES)
add_custom_target(benchmark_gemm_${pipeline})
add_custom_target(benchmark_gemm_${pipeline}_pipeline)
endforeach()
foreach(epilogue IN LISTS GEMM_EPILOGUES)
add_custom_target(benchmark_gemm_${epilogue})
add_custom_target(benchmark_gemm_${epilogue}_epilogue)
endforeach()
foreach(scheduler IN LISTS GEMM_SCHEDULERS)
add_custom_target(benchmark_gemm_${scheduler})
add_custom_target(benchmark_gemm_${scheduler}_scheduler)
endforeach()
# Build individual targets for each datatype/layout combination

View File

@@ -187,7 +187,7 @@ python gemm_instance_builder.py \
--datatype fp16 \
--layout rcr \
--config_json configs/user_provided_config.json \
--gen_individual
--gen_all_individual
```
#### gemm_instance_builder_parallel.py

View File

@@ -23,6 +23,31 @@ ELEMENT_SIZE_MAP = {
"fp64": 8,
}
WARP_SUPPORTED_COMBINATIONS = {
"gfx90a": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx942": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx950": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx1201": [
[2, 4, 1],
[1, 8, 1],
[8, 1, 1],
[4, 2, 1],
],
}
# [TODO] Handle this while moving code to commons
# Supported warp tile combinations for different GPU architectures and data types
WARP_TILE_SUPPORTED_COMBINATIONS = {
"gfx90a": {
@@ -290,6 +315,7 @@ def is_tile_config_valid(
b_datatype: str,
c_datatype: str,
pipeline: str,
layout: str,
gpu_target: str,
trait_name: str = None,
) -> bool:
@@ -348,6 +374,24 @@ def is_tile_config_valid(
logging.debug(f"LDS validation failed: {lds_error}")
return False
# Validate whole workgroup cover configuration
wr_cover_valid, wg_cover_error = validate_whole_wg_cover_configuration(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
layout,
a_datatype,
b_datatype,
)
if not wr_cover_valid:
logging.debug(
f"Whole workgroup cover configuration validation failed: {wg_cover_error}"
)
return False
# Validate warp tile combination
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
warp_tile_m,
@@ -363,3 +407,209 @@ def is_tile_config_valid(
return False
return True
# [TODO] Handle this while moving code to commons Add more datatype to this function if needed
def get_dtype_string(datatype: str) -> str:
    """Get C++ type string for datatype.

    Unknown datatypes fall back to "float".
    """
    try:
        return {
            "fp16": "ck_tile::fp16_t",
            "fp8": "ck_tile::fp8_t",
            "bf8": "ck_tile::bf8_t",
            "bf16": "ck_tile::bf16_t",
            "fp32": "float",
            "fp64": "double",
        }[datatype]
    except KeyError:
        # Same fallback as dict.get(datatype, "float") in the original.
        return "float"
LAYOUT_MAP = {
    "r": "ck_tile::tensor_layout::gemm::RowMajor",
    "c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}


def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]:
    """
    Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.

    Raises:
        ValueError: if the code is not exactly three 'r'/'c' characters.
    """
    code = str(layout_code).strip().lower()
    # Validate up front (matching the builder's _get_abc_layouts helper) so
    # callers get a clear ValueError instead of a KeyError/IndexError below.
    if len(code) != 3 or any(ch not in LAYOUT_MAP for ch in code):
        raise ValueError(
            f"Invalid layout '{layout_code}'. "
            "Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)."
        )
    a_layout = LAYOUT_MAP[code[0]]
    b_layout = LAYOUT_MAP[code[1]]
    c_layout = LAYOUT_MAP[code[2]]
    return a_layout, b_layout, c_layout
def validate_whole_wg_cover_configuration(
    tile_m,
    tile_n,
    tile_k,
    warp_m,
    warp_n,
    warp_k,
    layout,
    a_datatype,
    b_datatype,
) -> Tuple[bool, str]:
    """Check that a whole workgroup can cover the A and B tile distributions.

    Returns (ok, error_message); error_message is "" on success.

    NOTE(review): assumes ``layout`` is a 3-letter 'r'/'c' code (e.g. 'rcr').
    An unrecognized first/second letter leaves vector_load_size at 0, which
    the core validation rejects — confirm callers always pass a validated
    layout code.
    """
    warp_size = 64
    NumWarps = warp_m * warp_n * warp_k
    BlockSize = NumWarps * warp_size
    XPerTile = 0
    YPerTile = 0
    vector_load_size = 0
    # --- A matrix validation ---
    if layout[0] == "r":
        XPerTile = tile_k
        YPerTile = tile_m
        vector_load_size = get_global_vector_load_size(
            BlockSize, tile_k, a_datatype, tile_m, tile_k
        )
    elif layout[0] == "c":
        # Column-major A: the contiguous dimension is M, so the vector load
        # width is derived from tile_m instead of tile_k.
        vector_load_size = get_global_vector_load_size(
            BlockSize, tile_k, a_datatype, tile_m, tile_m
        )
        XPerTile = tile_k
        YPerTile = tile_m
    # Validate the (K, M) distribution of A.
    wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
        XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
    )
    if not wg_cover_core_valid:
        logging.debug(
            f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
        )
        return False, wg_cover_core_error
    # Validate the transposed (M, K) distribution of A.
    XPerTile = tile_m
    YPerTile = tile_k
    wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
        XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
    )
    if not wg_cover_core_valid:
        logging.debug(
            f"whole workgroup cover failed for Matrix A: {wg_cover_core_error}"
        )
        return False, wg_cover_core_error
    # --- B matrix validation ---
    if layout[1] == "r":
        vector_load_size = get_global_vector_load_size(
            BlockSize, tile_k, b_datatype, tile_n, tile_n
        )
        # Row-major B needs an extra (K, N) distribution check before the
        # shared (N, K) check below.
        XPerTile = tile_k
        YPerTile = tile_n
        wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
            XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
        )
        if not wg_cover_core_valid:
            # Bug fix: this message said "Matrix A" even though it reports a
            # Matrix B failure.
            logging.debug(
                f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}"
            )
            return False, wg_cover_core_error
        XPerTile = tile_n
        YPerTile = tile_k
    elif layout[1] == "c":
        XPerTile = tile_k
        YPerTile = tile_n
        vector_load_size = get_global_vector_load_size(
            BlockSize, tile_k, b_datatype, tile_n, tile_k
        )
    # Final distribution check for B (both layouts reach here).
    wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
        XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
    )
    if not wg_cover_core_valid:
        logging.debug(
            f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}"
        )
        return False, wg_cover_core_error
    return True, ""
def wg_cover_core_validation(
    XPerTile: int,
    YPerTile: int,
    BlockSize: int,
    vector_load_size: int,
    warp_size: int,
) -> Tuple[bool, str]:
    """Core whole-workgroup-cover check for one tile distribution.

    Returns (ok, error_message); error_message is "" on success.
    """
    # Bug fix: the original returned a bare ``False`` here, which breaks
    # callers that unpack the declared (bool, str) tuple. Also guard against
    # vector_load_size == 0 (modulo by zero).
    if vector_load_size == 0 or XPerTile % vector_load_size != 0:
        return False, (
            f"XPerTile ({XPerTile}) is not divisible by "
            f"vector_load_size ({vector_load_size})"
        )
    num_warps = BlockSize / warp_size
    # Largest per-thread vector that still keeps every thread busy.
    LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
    X1 = LargestVec if vector_load_size > LargestVec else vector_load_size
    X0 = XPerTile / X1
    Y1 = warp_size // X0
    if X0 * Y1 != warp_size:
        return False, (
            f"warp of size {warp_size} cannot evenly cover the "
            f"{XPerTile}x{YPerTile} tile (X0={X0}, Y1={Y1})"
        )
    return True, ""
def get_global_vector_load_size(
    BlockSize: int,
    KPerBlock: int,
    DataType: str,
    MNPerBlock: int,
    XPerTile: int,
) -> int:
    """Pick the widest global vector-load width (in elements) for one tile.

    Tries 32-, 16-, 8-, 4- and 2-byte loads (scaled by element size) in
    decreasing order and returns the first width that evenly divides both
    the contiguous tile dimension (XPerTile) and the per-thread element
    count; falls back to PackedSize (scalar loads).
    """
    elements_per_thread = MNPerBlock * KPerBlock / BlockSize
    PackedSize = 1
    # NOTE(review): PackedSize is hard-coded to 1, so the ``PackedSize == 2``
    # condition makes the 32-byte branch unreachable today; it looks like a
    # placeholder for packed types (e.g. pk_int4) — confirm intent.
    if (
        XPerTile % (PackedSize * 32 / element_size(DataType)) == 0
        and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0
        and PackedSize == 2
    ):
        # Bug fix: cast to int like every other branch (was returning float).
        return int(PackedSize * 32 / element_size(DataType))
    elif (
        XPerTile % (PackedSize * 16 / element_size(DataType)) == 0
        and elements_per_thread % (PackedSize * 16 / element_size(DataType)) == 0
    ):
        return int(PackedSize * 16 / element_size(DataType))
    elif (
        XPerTile % (PackedSize * 8 / element_size(DataType)) == 0
        and elements_per_thread % (PackedSize * 8 / element_size(DataType)) == 0
    ):
        return int(PackedSize * 8 / element_size(DataType))
    elif (
        element_size(DataType) >= PackedSize * 4
        and XPerTile % (PackedSize * 4 / element_size(DataType)) == 0
        and elements_per_thread % (PackedSize * 4 / element_size(DataType)) == 0
    ):
        return int(PackedSize * 4 / element_size(DataType))
    elif (
        element_size(DataType) >= PackedSize * 2
        and XPerTile % (PackedSize * 2 / element_size(DataType)) == 0
        and elements_per_thread % (PackedSize * 2 / element_size(DataType)) == 0
    ):
        return int(PackedSize * 2 / element_size(DataType))
    else:
        return PackedSize

View File

@@ -1,105 +0,0 @@
{
"problem": {
},
"tile_config": {
"tile_m": {
"max": 256,
"min": 64,
"step": 64
},
"tile_n": {
"max": 256,
"min": 64,
"step": 64
},
"tile_k": {
"max": 256,
"min": 64,
"step": 64
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
4,
16,
32
]
},
"warp_tile_n": {
"values": [
16,
32,
64
]
},
"warp_tile_k": {
"values": [
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle",
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
}
}

View File

@@ -1,88 +0,0 @@
{
"problem": {
},
"tile_config": {
"tile_m": {
"values": [
128 ]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
128
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
32
]
},
"warp_tile_n": {
"values": [
32
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
}
}

View File

@@ -1,6 +1,4 @@
{
"problem": {
},
"tile_config": {
"tile_m": {
"max": 256,
@@ -101,5 +99,6 @@
true
]
}
}
},
"k_block_per_cu": 1
}

View File

@@ -1,102 +0,0 @@
{
"problem": {
},
"tile_config": {
"tile_m": {
"values": [
256,
128,
64
]
},
"tile_n": {
"values": [
256,
128,
64
]
},
"tile_k": {
"values": [
256,
128,
64
]
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle",
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
}
}

View File

@@ -1,31 +1,28 @@
{
"problem": {
},
"tile_config": {
"tile_m": {
"values": [
128,
256
64
]
},
"tile_n": {
"values": [
128
192
]
},
"tile_k": {
"values": [
128
64
]
},
"warp_m": {
"values": [
4
2
]
},
"warp_n": {
"values": [
1
2
]
},
"warp_k": {
@@ -35,36 +32,33 @@
},
"warp_tile_m": {
"values": [
16, 32
32
]
},
"warp_tile_n": {
"values": [
16, 32
32
]
},
"warp_tile_k": {
"values": [
16, 32
8
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"mem"
"compv4"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
"intrawave"
]
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
},
@@ -85,9 +79,9 @@
},
"persistent": {
"values": [
false,
true
]
}
}
},
"k_block_per_cu": 1
}

View File

@@ -273,48 +273,6 @@ class GemmBenchmark:
print(f"Error reading JSON file {json_file}: {e}")
return None
def parse_benchmark_output(self, output: str) -> Optional[Dict]:
    """Parse the benchmark output format - extract JSON directly."""
    try:
        # Locate the first line that opens a JSON object and the first
        # subsequent line that closes one.
        rows = output.split("\n")
        start_idx = -1
        end_idx = -1
        for idx, row in enumerate(rows):
            stripped = row.strip()
            if stripped.startswith("{"):
                start_idx = idx
            elif stripped.endswith("}") and start_idx != -1:
                end_idx = idx
                break
        if start_idx == -1 or end_idx == -1:
            return None
        payload = json.loads("\n".join(rows[start_idx : end_idx + 1]))
        # Return the complete JSON data as-is, plus convenience fields for
        # backward compatibility when performance numbers are present.
        result = payload.copy()
        if "perf_result" in payload:
            perf = payload["perf_result"]
            result["time_ms"] = perf.get("latency(ms)", 0)
            result["tflops"] = perf.get("tflops(TFlops)", 0)
            result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0)
        return result
    except json.JSONDecodeError as e:
        if self.verbose:
            print(f"Failed to parse JSON: {e}")
            print(f"Output was: {output[:200]}...")
        return None
    except Exception as e:
        if self.verbose:
            print(f"Error parsing output: {e}")
        return None
def benchmark_problem_size(
self,
kernels: List[Path],

View File

@@ -30,9 +30,9 @@ inline auto create_args(int argc, char* argv[])
.insert("stride_c", "0", "The stride value for tensor C. Default is 0.")
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
.insert("verify",
"0",
"2",
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
"for validation on GPU. Default is 0, no validation.")
"for validation on GPU. Default is 2, GPU validation.")
.insert("log",
"false",
"Whether output kernel instance information or not. Possible values are true or "
@@ -75,7 +75,7 @@ inline auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser)
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
@@ -149,7 +149,7 @@ int main(int argc, char* argv[])
if(!result)
return EXIT_FAILURE;
benchmark_gemm_single(parser);
benchmark_single(parser);
return 0;
}
catch(const std::exception& e)

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
//[TODO] This can be moved to commons
// DataTypeTraits for all supported types
template <typename T>
struct DataTypeTraits;
@@ -97,49 +98,3 @@ struct KernelTraits
{
}
};
// Helper to extract traits from kernel name
// Recovers a KernelTraits record by substring-matching the generated kernel
// name; fields not encoded in the name keep their defaults.
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
{
    KernelTraits traits;
    // Extract pipeline
    // NOTE(review): order matters — "compv3"/"compv4" are tested before
    // "mem", and "mem" would also match as a substring of a longer word;
    // confirm generated names only contain "mem" as the pipeline tag.
    if(kernel_name.find("compv3") != std::string::npos)
    {
        traits.pipeline = "compv3";
    }
    else if(kernel_name.find("compv4") != std::string::npos)
    {
        traits.pipeline = "compv4";
    }
    else if(kernel_name.find("mem") != std::string::npos)
    {
        traits.pipeline = "mem";
    }
    // Extract scheduler
    // "intrawave" is the fallback whenever "interwave" is absent.
    if(kernel_name.find("interwave") != std::string::npos)
    {
        traits.scheduler = "interwave";
    }
    else
    {
        traits.scheduler = "intrawave";
    }
    // Extract epilogue
    // Names containing "default_" are excluded — presumably to avoid
    // misclassifying other "default_*" tokens as the default epilogue;
    // verify against the kernel-name format.
    if(kernel_name.find("default") != std::string::npos &&
       kernel_name.find("default_") == std::string::npos)
    {
        traits.epilogue = "default";
    }
    else
    {
        traits.epilogue = "cshuffle";
    }
    // Padding flags would need to be extracted from the kernel configuration
    // For now, we'll leave them as false
    return traits;
}

View File

@@ -8,8 +8,12 @@ import multiprocessing
import concurrent.futures
from pathlib import Path
import logging
from typing import Optional
from validation_utils import is_tile_config_valid, is_trait_combination_valid
from commons.validation_utils import (
is_tile_config_valid,
is_trait_combination_valid,
get_dtype_string,
get_abc_layouts,
)
logging.basicConfig(level=logging.INFO)
@@ -29,149 +33,150 @@ class GemmKernelBuilder:
if config_json and os.path.exists(config_json):
with open(config_json, "r") as f:
self.config = json.load(f)
else:
self.config = self._get_default_config()
def _get_default_config(self):
"""Return default configuration if no config file is provided"""
# Define base tile configurations that work for all layouts
base_fp16_configs = [
{
"tile_m": 256,
"tile_n": 256,
"tile_k": 32,
"warp_m": 2,
"warp_n": 2,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 32,
},
{
"tile_m": 256,
"tile_n": 128,
"tile_k": 32,
"warp_m": 2,
"warp_n": 2,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 16,
},
]
def write_kernel_list(self):
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
# Get configurations using comprehensive validation
tile_configs = self._get_tile_configs(fast_mode=False)
trait_combos = self._generate_trait_combinations()
base_fp8_configs = [
{
"tile_m": 256,
"tile_n": 256,
"tile_k": 32,
"warp_m": 4,
"warp_n": 1,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 32,
},
{
"tile_m": 256,
"tile_n": 128,
"tile_k": 32,
"warp_m": 1,
"warp_n": 4,
"warp_k": 1,
"warp_tile_m": 16,
"warp_tile_n": 16,
"warp_tile_k": 32,
},
]
kernel_list = []
for tile_config in tile_configs:
for trait_combo in trait_combos:
(
pipeline,
epilogue,
scheduler,
pad_m,
pad_n,
pad_k,
persistent,
) = trait_combo
# Create configurations for all supported layouts
all_layouts = ["rcr", "rrr", "ccr", "crr"]
tile_configs = {}
# Create kernel name with proper boolean capitalization
kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
for datatype, base_configs in [
("fp16", base_fp16_configs),
("fp8", base_fp8_configs),
]:
tile_configs[datatype] = {}
for layout in all_layouts:
tile_configs[datatype][layout] = base_configs
# Create tile configuration string
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
return {
"tile_configs": tile_configs,
"traits": {
"pipelines": ["mem", "compv3", "compv4"],
"epilogues": ["default", "cshuffle"],
"schedulers": ["intrawave", "interwave"],
},
"structured_sparsity": ["false"],
"padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]},
"persistent": ["false"],
}
kernel_name += f"_{tile_str}"
kernel_list.append(
{
"name": kernel_name,
"tile_config": tile_config,
"trait_combo": trait_combo,
}
)
# Write kernel count
with open(self.working_path / "gemm_kernel_count.txt", "w") as f:
f.write(str(len(kernel_list)))
# Write kernel list
with open(self.working_path / "gemm_kernel_list.txt", "w") as f:
for kernel in kernel_list:
# Format: kernel_name|tile_config|trait_combo
tile_config = kernel["tile_config"]
trait_combo = kernel["trait_combo"]
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
trait_str = (
f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
+ "_".join(str(x) for x in trait_combo[3:])
)
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
print(f"Listed {len(kernel_list)} kernel configurations")
def _get_tile_configs(self, fast_mode=False):
"""Get tile configurations for the current datatype and layout"""
if "tile_configs" in self.config:
# Old format
return (
self.config["tile_configs"].get(self.datatype, {}).get(self.layout, [])
tile_config = self.config["tile_config"]
# Generate values in the config if default range is given
if tile_config.get("tile_m").get("values") is None:
tile_config.get("tile_m")["values"] = self._generate_values(
tile_config.get("tile_m").get("min"),
tile_config.get("tile_m").get("max"),
tile_config.get("tile_m").get("step"),
)
if tile_config.get("tile_n").get("values") is None:
tile_config.get("tile_n")["values"] = self._generate_values(
tile_config.get("tile_n").get("min"),
tile_config.get("tile_n").get("max"),
tile_config.get("tile_n").get("step"),
)
if tile_config.get("tile_k").get("values") is None:
tile_config.get("tile_k")["values"] = self._generate_values(
tile_config.get("tile_k").get("min"),
tile_config.get("tile_k").get("max"),
tile_config.get("tile_k").get("step"),
)
elif "tile_config" in self.config:
# New format - generate combinations from individual parameter values
tile_config = self.config["tile_config"]
# Get all possible values for each parameter
tile_m_values = tile_config.get("tile_m", {}).get("values", [256])
tile_n_values = tile_config.get("tile_n", {}).get("values", [256])
tile_k_values = tile_config.get("tile_k", {}).get("values", [32])
warp_m_values = tile_config.get("warp_m", {}).get("values", [2])
warp_n_values = tile_config.get("warp_n", {}).get("values", [2])
warp_k_values = tile_config.get("warp_k", {}).get("values", [1])
warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32])
warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32])
warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32])
# Get all possible values for each parameter
tile_m_values = tile_config.get("tile_m").get("values")
tile_n_values = tile_config.get("tile_n").get("values")
tile_k_values = tile_config.get("tile_k").get("values")
warp_m_values = tile_config.get("warp_m").get("values")
warp_n_values = tile_config.get("warp_n").get("values")
warp_k_values = tile_config.get("warp_k").get("values")
warp_tile_m_values = tile_config.get("warp_tile_m").get("values")
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
# Generate all combinations
configs = []
for tile_m in tile_m_values:
for tile_n in tile_n_values:
for tile_k in tile_k_values:
for warp_m in warp_m_values:
for warp_n in warp_n_values:
for warp_k in warp_k_values:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
fast_mode=fast_mode,
):
configs.append(
{
"tile_m": tile_m,
"tile_n": tile_n,
"tile_k": tile_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"warp_tile_k": warp_tile_k,
}
)
return configs
else:
# Fallback to default
return []
# Generate all combinations
configs = []
for tile_m in tile_m_values:
for tile_n in tile_n_values:
for tile_k in tile_k_values:
for warp_m in warp_m_values:
for warp_n in warp_n_values:
for warp_k in warp_k_values:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
fast_mode=fast_mode,
):
configs.append(
{
"tile_m": tile_m,
"tile_n": tile_n,
"tile_k": tile_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"warp_tile_k": warp_tile_k,
}
)
return configs
def _generate_values(self, min_val, max_val, step):
"""Generate a list of values from min to max with the given step"""
values = []
val = min_val
while val <= max_val:
values.append(val)
val += step
return values
def _validate_tile_config(
self,
@@ -184,7 +189,7 @@ class GemmKernelBuilder:
warp_tile_m,
warp_tile_n,
warp_tile_k,
pipeline="mem", # Default pipeline for validation
pipeline="compv4", # Default pipeline for validation
fast_mode=False, # Add fast mode option
):
"""Validate that tile configuration is reasonable"""
@@ -213,6 +218,8 @@ class GemmKernelBuilder:
b_datatype = self.datatype
c_datatype = self.datatype
layout = self.layout
# Special handling for certain data types
if self.datatype in ["fp8", "bf8"]:
c_datatype = "fp16"
@@ -232,125 +239,50 @@ class GemmKernelBuilder:
b_datatype,
c_datatype,
pipeline,
layout,
self.gpu_target,
)
def _generate_trait_combinations(self):
"""Generate all combinations of traits"""
if "traits" in self.config:
# Old format
traits = self.config["traits"]
pipelines = traits["pipelines"]
epilogues = traits["epilogues"]
schedulers = traits["schedulers"]
padding = self.config["padding"]
persistent = self.config["persistent"]
trait_config = self.config["trait_config"]
all_combinations = list(
itertools.product(
pipelines,
epilogues,
schedulers,
padding["pad_m"],
padding["pad_n"],
padding["pad_k"],
persistent,
pipelines = trait_config.get("pipeline").get("values")
epilogues = trait_config.get("epilogue").get("values")
schedulers = trait_config.get("scheduler").get("values")
pad_m_values = trait_config.get("pad_m").get("values")
pad_n_values = trait_config.get("pad_n").get("values")
pad_k_values = trait_config.get("pad_k").get("values")
persistent_values = trait_config.get("persistent").get("values")
all_combinations = list(
itertools.product(
pipelines,
epilogues,
schedulers,
pad_m_values,
pad_n_values,
pad_k_values,
persistent_values,
)
)
# Filter out unsupported trait combinations
combinations = []
for combo in all_combinations:
pipeline, epilogue, scheduler = combo[:3]
if is_trait_combination_valid(pipeline, epilogue, scheduler):
combinations.append(combo)
else:
logging.debug(
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
)
)
# Filter out unsupported trait combinations
combinations = []
for combo in all_combinations:
pipeline, epilogue, scheduler = combo[:3]
if is_trait_combination_valid(pipeline, epilogue, scheduler):
combinations.append(combo)
else:
logging.debug(
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
)
elif "trait_config" in self.config:
# New format
trait_config = self.config["trait_config"]
pipelines = trait_config.get("pipeline", {}).get("values", ["mem"])
epilogues = trait_config.get("epilogue", {}).get("values", ["default"])
schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"])
pad_m_values = trait_config.get("pad_m", {}).get("values", [False])
pad_n_values = trait_config.get("pad_n", {}).get("values", [False])
pad_k_values = trait_config.get("pad_k", {}).get("values", [False])
persistent_values = trait_config.get("persistent", {}).get(
"values", [False]
)
all_combinations = list(
itertools.product(
pipelines,
epilogues,
schedulers,
pad_m_values,
pad_n_values,
pad_k_values,
persistent_values,
)
)
# Filter out unsupported trait combinations
combinations = []
for combo in all_combinations:
pipeline, epilogue, scheduler = combo[:3]
if is_trait_combination_valid(pipeline, epilogue, scheduler):
combinations.append(combo)
else:
logging.debug(
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
)
else:
# Fallback to minimal default
combinations = [("mem", "default", "intrawave", False, False, False, False)]
return combinations
def _get_dtype_string(self):
"""Get C++ type string for datatype"""
dtype_map = {
"fp16": "ck_tile::fp16_t",
"fp8": "ck_tile::fp8_t",
"bf16": "ck_tile::bf16_t",
"fp32": "float",
"fp64": "double",
}
return dtype_map.get(self.datatype, "float")
_LAYOUT_MAP = {
"r": "ck_tile::tensor_layout::gemm::RowMajor",
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}
def _get_abc_layouts(self, layout_code: Optional[str] = None):
"""
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
If layout_code is None, use self.layout.
"""
if layout_code is None:
# fall back to the instance field
layout_code = getattr(self, "layout", "")
code = str(layout_code).strip().lower()
if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code):
raise ValueError(
f"Invalid layout '{layout_code}'. "
"Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)."
)
a_layout = self._LAYOUT_MAP[code[0]]
b_layout = self._LAYOUT_MAP[code[1]]
c_layout = self._LAYOUT_MAP[code[2]]
return a_layout, b_layout, c_layout
def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
def _generate_kernel_instance(
self, tile_config, trait_combo, k_block_per_cu, is_header=True
):
"""Generate a single kernel instance"""
(
pipeline,
@@ -383,6 +315,13 @@ class GemmKernelBuilder:
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
}
# Map pipeline names to base pipeline for hot loop detection
base_pipeline_map = {
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
}
# Map scheduler names to the correct enum values
scheduler_type_map = {
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
@@ -392,23 +331,14 @@ class GemmKernelBuilder:
# Determine accumulator type based on datatype
acc_type = "float"
if self.datatype in ["int8", "int4"]:
acc_type = "ck_tile::int32_t"
# Determine output type
c_type = self._get_dtype_string()
c_type = self.datatype
if self.datatype in ["fp8", "bf8"]:
c_type = "ck_tile::fp16_t"
c_type = "fp16"
# Determine layouts based on self.layout
a_layout, b_layout, c_layout = self._get_abc_layouts()
# Map pipeline names to base pipeline for hot loop detection
base_pipeline_map = {
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
}
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
# Generate kernel instance code using the correct API
pragma_line = "#pragma once\n" if is_header else ""
@@ -425,10 +355,10 @@ class GemmKernelBuilder:
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
using ADataType = {self._get_dtype_string()};
using BDataType = {self._get_dtype_string()};
using ADataType = {get_dtype_string(self.datatype)};
using BDataType = {get_dtype_string(self.datatype)};
using AccDataType = {acc_type};
using CDataType = {c_type};
using CDataType = {get_dtype_string(c_type)};
using ALayout = {a_layout};
using BLayout = {b_layout};
@@ -484,7 +414,7 @@ struct SelectedKernel {{
Traits>;
// Base pipeline for hot loop detection
using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseGemmPipelineAgBgCrMem")}<GemmPipelineProblem>;
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
const ck_tile::index_t k_grain = args.k_batch * TileK;
@@ -498,7 +428,7 @@ struct SelectedKernel {{
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Intrawave")};
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
@@ -514,7 +444,7 @@ struct SelectedKernel {{
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}<UniversalGemmProblem>;
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
// Epilogue
"""
@@ -589,7 +519,7 @@ struct SelectedKernel {{
}}
// Launch kernel
constexpr int kBlockPerCu = 1;
constexpr int kBlockPerCu = {k_block_per_cu};
ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
@@ -616,9 +546,13 @@ struct SelectedKernel {{
}}
}};
"""
return kernel_name, instance_code
def run(self, num_workers=None):
    """Run the builder to generate individual kernel files.

    Args:
        num_workers: number of parallel workers to use; ``None`` defers the
            choice to ``generate_individual`` (presumably auto-detected from
            CPU count — confirm in that method).
    """
    # Generate individual kernel files
    self.generate_individual(num_workers)
def generate_individual(self, num_workers=None):
"""Generate individual kernel files for separate compilation with parallel processing"""
if num_workers is None:
@@ -628,6 +562,7 @@ struct SelectedKernel {{
tile_configs = self._get_tile_configs()
trait_combos = self._generate_trait_combinations()
k_block_per_cu = self.config.get("k_block_per_cu")
# Prepare work items for parallel processing
work_items = []
@@ -637,6 +572,7 @@ struct SelectedKernel {{
(
tile_config,
trait_combo,
k_block_per_cu,
self.working_path,
self.datatype,
self.layout,
@@ -723,83 +659,17 @@ struct SelectedKernel {{
with open(self.working_path / "gemm_individual_targets.cmake", "w") as f:
f.write(cmake_code)
def write_kernel_list(self):
    """Write kernel list to file for CMake to read (with comprehensive validation).

    Produces two files under ``self.working_path``:
      * ``gemm_kernel_count.txt`` — total number of kernel configurations.
      * ``gemm_kernel_list.txt``  — one line per kernel in the format
        ``kernel_name|tile_str|trait_str``.
    """

    def tile_string(tc):
        # Canonical MxNxK_MxNxK_MxNxK encoding of one tile configuration
        # (block tile, warp partition, warp tile).
        return (
            f"{tc['tile_m']}x{tc['tile_n']}x{tc['tile_k']}_"
            f"{tc['warp_m']}x{tc['warp_n']}x{tc['warp_k']}_"
            f"{tc['warp_tile_m']}x{tc['warp_tile_n']}x{tc['warp_tile_k']}"
        )

    # Get configurations using comprehensive validation
    tile_configs = self._get_tile_configs(fast_mode=False)
    trait_combos = self._generate_trait_combinations()

    kernel_list = []
    for tile_config in tile_configs:
        # Compute the tile encoding once per tile config instead of once
        # per (tile, trait) pair and again when writing the list file.
        tile_str = tile_string(tile_config)
        for trait_combo in trait_combos:
            (
                pipeline,
                epilogue,
                scheduler,
                pad_m,
                pad_n,
                pad_k,
                persistent,
            ) = trait_combo

            # Kernel name with proper boolean capitalization, suffixed with
            # the tile encoding.
            kernel_name = (
                f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_"
                f"{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_"
                f"{str(pad_k).capitalize()}_{str(persistent).capitalize()}_{tile_str}"
            )

            kernel_list.append(
                {
                    "name": kernel_name,
                    "tile_config": tile_config,
                    "trait_combo": trait_combo,
                    "tile_str": tile_str,
                }
            )

    # Write kernel count
    with open(self.working_path / "gemm_kernel_count.txt", "w") as f:
        f.write(str(len(kernel_list)))

    # Write kernel list. Format: kernel_name|tile_config|trait_combo
    with open(self.working_path / "gemm_kernel_list.txt", "w") as f:
        for kernel in kernel_list:
            # Joining the whole 7-tuple is equivalent to the old
            # "first three + underscore + rest" construction.
            trait_str = "_".join(str(x) for x in kernel["trait_combo"])
            f.write(f"{kernel['name']}|{kernel['tile_str']}|{trait_str}\n")

    print(f"Listed {len(kernel_list)} kernel configurations")
def run(self, num_workers=None):
    """Run the builder to generate individual kernel files.

    Args:
        num_workers: number of parallel workers to use; ``None`` defers the
            choice to ``generate_individual`` (presumably auto-detected from
            CPU count — confirm in that method).
    """
    # Generate individual kernel files
    self.generate_individual(num_workers)
def _generate_single_kernel_individual(work_item):
"""Worker function to generate a single individual kernel file"""
tile_config, trait_combo, working_path, datatype, layout = work_item
tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item
# Create a temporary builder instance for this worker
builder = GemmKernelBuilder(working_path, datatype, layout)
try:
kernel_name, instance_code = builder._generate_kernel_instance(
tile_config, trait_combo
tile_config, trait_combo, k_block_per_cu
)
# Create simplified filename without the "gemm_" prefix
@@ -832,7 +702,7 @@ def main():
parser.add_argument(
"--datatype",
required=True,
choices=["fp16", "fp8", "bf16", "fp32", "fp64"],
choices=["fp16", "fp8", "bf16", "bf8"],
help="Data type",
)
parser.add_argument(
@@ -846,7 +716,9 @@ def main():
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
)
parser.add_argument(
"--gen_individual", action="store_true", help="Generate individual kernel files"
"--gen_all_individual",
action="store_true",
help="Generate individual kernel files",
)
parser.add_argument(
"--gen_single", action="store_true", help="Generate a single kernel file"
@@ -866,13 +738,27 @@ def main():
args = parser.parse_args()
assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
)
layout_parts = args.layout.lower()
assert len(layout_parts) == 3, (
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
)
assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
)
assert layout_parts[2] == "r", (
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
)
# Create builder
builder = GemmKernelBuilder(
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
)
if args.list_kernels:
# Fast listing mode - just write kernel list without generating files
builder.write_kernel_list()
elif args.gen_single:
# Generate a single kernel file
@@ -911,9 +797,11 @@ def main():
trait_parts[6] == "True", # persistent
)
k_block_per_cu = builder.config.get("k_block_per_cu")
# Generate the kernel
kernel_name, instance_code = builder._generate_kernel_instance(
tile_config, trait_combo
tile_config, trait_combo, k_block_per_cu
)
# Write the file
@@ -927,12 +815,12 @@ def main():
print(f"Generated {header_file}")
elif args.gen_individual:
elif args.gen_all_individual:
# Generate all individual kernel files
builder.run(args.num_workers)
else:
parser.error(
"Must specify one of: --list_kernels, --gen_individual, or --gen_single"
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
)

View File

@@ -9,7 +9,7 @@
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "benchmark_gemm.hpp"
#include "gemm_benchmark.hpp"
class GemmProfiler
{

View File

@@ -1,231 +0,0 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
Handles loading, parsing, and validation of JSON configuration parameters.
"""
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional, Union, Tuple, Type, Dict
import json
@dataclass
class EnumConfigParam:
    """A configuration parameter whose value is one of an explicit,
    enumerated list of allowed values."""

    # Allowed values; entries may be ints, strings, or booleans.
    values: List[Union[int, str, bool]]
@dataclass
class RangeConfigParam:
    """Represents a numeric range-type configuration parameter."""

    min: int  # inclusive lower bound
    max: int  # inclusive upper bound
    step: int  # positive stride
    # Values removed from the generated range; None or [] means no exclusions.
    # Default added so the Optional annotation matches an actual default.
    exclude: Optional[List[int]] = None

    def generate_candidates(self) -> List[int]:
        """Generate valid candidates after applying range constraints.

        Returns:
            The values ``min, min+step, ...`` up to and including ``max``,
            minus any excluded values.

        Raises:
            ValueError: if ``min > max``, ``step <= 0``, or no candidates
                remain after exclusions.
            TypeError: if ``exclude`` is set but is not a list.
        """
        if self.min > self.max:
            raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
        if self.step <= 0:
            raise ValueError(f"Step must be positive, got {self.step}")
        candidates = list(range(self.min, self.max + 1, self.step))
        # `exclude` is a dataclass field, so the previous hasattr() check was
        # always true; a simple truthiness test covers both None and [].
        if self.exclude:
            if not isinstance(self.exclude, list):
                raise TypeError("exclude must be list type")
            exclude_set = set(self.exclude)
            candidates = [x for x in candidates if x not in exclude_set]
        if not candidates:
            raise ValueError(
                f"No valid candidates for range [{self.min}-{self.max}] "
                f"with step {self.step} and excludes {self.exclude}"
            )
        return candidates
@dataclass
class ProblemConfig:
    """Configuration describing the GEMM problem: per-matrix datatypes and layouts."""

    # Datatype parameters for matrices (A, B, C), in that order.
    datatypes: Tuple[EnumConfigParam, ...]
    # Layout parameters for matrices (A, B, C), in that order.
    layouts: Tuple[EnumConfigParam, ...]

    @property
    def datatype_map(self) -> Dict[str, str]:
        """Get datatype as a key-value map keyed by matrix name."""
        names = ("matrix_a", "matrix_b", "matrix_c")
        return {names[i]: self.datatypes[i].values[0] for i in range(3)}

    @property
    def layout_map(self) -> Dict[str, str]:
        """Get layout as a key-value map keyed by matrix name."""
        names = ("matrix_a", "matrix_b", "matrix_c")
        return {names[i]: self.layouts[i].values[0] for i in range(3)}
@dataclass
class TileConfig:
    """Tile-level configuration: block tile, warp partition, and warp tile
    sizes along the M, N, and K dimensions. Each field is either an explicit
    enumeration or a numeric range of candidate values."""

    # Block-tile dimensions
    tile_m: Union[EnumConfigParam, RangeConfigParam]
    tile_n: Union[EnumConfigParam, RangeConfigParam]
    tile_k: Union[EnumConfigParam, RangeConfigParam]
    # Warp partition within the block tile
    warp_m: Union[EnumConfigParam, RangeConfigParam]
    warp_n: Union[EnumConfigParam, RangeConfigParam]
    warp_k: Union[EnumConfigParam, RangeConfigParam]
    # Per-warp tile dimensions
    warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
    warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
    warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
@dataclass
class TraitConfig:
    """Kernel-trait configuration: pipeline/scheduler/epilogue selections and
    padding / persistence switches, each as an enumerated set of candidates."""

    pipeline: EnumConfigParam   # pipeline implementation variants
    scheduler: EnumConfigParam  # wave scheduling strategies
    epilogue: EnumConfigParam   # epilogue variants
    pad_m: EnumConfigParam      # pad along M?
    pad_n: EnumConfigParam      # pad along N?
    pad_k: EnumConfigParam      # pad along K?
    persistent: EnumConfigParam # persistent-kernel mode?
@dataclass
class GemmConfig:
    """Main configuration class for GEMM operations"""

    problem: ProblemConfig      # per-matrix datatypes and layouts
    tile_config: TileConfig     # tile / warp / warp-tile dimensions
    trait_config: TraitConfig   # pipeline, scheduler, epilogue, padding, persistence

    @classmethod
    def from_json(
        cls: Type["GemmConfig"], filepath: str, datatype: str, layout: str
    ) -> "GemmConfig":
        """JSON configuration loader with validation controls.

        Args:
            filepath: path to the JSON configuration file.
            datatype: datatype tag applied to all matrices (e.g. "fp16");
                8-bit input types force a fp16 output type below.
            layout: 3-character layout string such as "rcr" ('r' = row major,
                'c' = column major); matrix C currently must be 'r'.

        Raises:
            FileNotFoundError: if the config file does not exist.
            ValueError: if the file is not valid JSON.
            KeyError: if a required configuration field is missing.
            AssertionError: if the layout string is malformed.
        """
        config_path = Path(filepath)
        try:
            if not config_path.exists():
                raise FileNotFoundError(f"Config file {filepath} not found")

            with config_path.open("r") as f:
                config_dict = json.load(f)

            a_type = datatype
            b_type = datatype
            c_type = datatype
            # int4 weights pair with fp16 activations; 8-bit (and int4)
            # inputs produce a fp16 output datatype.
            if b_type == "int4":
                a_type = "fp16"
            if b_type in ["bf8", "fp8", "int4"]:
                c_type = "fp16"

            layout_parts = layout.lower()
            assert len(layout_parts) == 3, (
                f"Invalid layout string: {layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
            )
            # Fixed messages: the matrix_b check previously reported
            # "matrix_a", and both messages contained a doubled "or".
            assert layout_parts[0] in ("r", "c"), (
                f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or 'c' for column major)"
            )
            assert layout_parts[1] in ("r", "c"), (
                f"Invalid matrix_b layout: {layout_parts[1]} (must be 'r' for row major or 'c' for column major)"
            )
            assert layout_parts[2] == "r", (
                f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
            )
            a_layout = layout_parts[0]
            b_layout = layout_parts[1]
            c_layout = layout_parts[2]

            # Parse problem config
            # TODO: Not reading datatype information from json file.
            problem = ProblemConfig(
                datatypes=(
                    EnumConfigParam(values=[a_type]),
                    EnumConfigParam(values=[b_type]),
                    EnumConfigParam(values=[c_type]),
                ),
                layouts=(
                    EnumConfigParam(values=[a_layout]),
                    EnumConfigParam(values=[b_layout]),
                    EnumConfigParam(values=[c_layout]),
                ),
            )

            # Parse tile config
            def create_param(param_dict):
                # A "values" key selects an enumeration; otherwise the dict
                # describes a numeric range (min/max/step with optional excludes).
                if "values" in param_dict:
                    return EnumConfigParam(values=param_dict["values"])
                else:
                    return RangeConfigParam(
                        min=param_dict["min"],
                        max=param_dict["max"],
                        step=param_dict["step"],
                        exclude=param_dict.get("exclude", []),
                    )

            tile_config = TileConfig(
                tile_m=create_param(config_dict["tile_config"]["tile_m"]),
                tile_n=create_param(config_dict["tile_config"]["tile_n"]),
                tile_k=create_param(config_dict["tile_config"]["tile_k"]),
                warp_m=create_param(config_dict["tile_config"]["warp_m"]),
                warp_n=create_param(config_dict["tile_config"]["warp_n"]),
                warp_k=create_param(config_dict["tile_config"]["warp_k"]),
                warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
                warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
                warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
            )

            # Parse trait config
            trait_config = TraitConfig(
                pipeline=EnumConfigParam(
                    values=config_dict["trait_config"]["pipeline"]["values"]
                ),
                scheduler=EnumConfigParam(
                    values=config_dict["trait_config"]["scheduler"]["values"]
                ),
                epilogue=EnumConfigParam(
                    values=config_dict["trait_config"]["epilogue"]["values"]
                ),
                pad_m=EnumConfigParam(
                    values=config_dict["trait_config"]["pad_m"]["values"]
                ),
                pad_n=EnumConfigParam(
                    values=config_dict["trait_config"]["pad_n"]["values"]
                ),
                pad_k=EnumConfigParam(
                    values=config_dict["trait_config"]["pad_k"]["values"]
                ),
                persistent=EnumConfigParam(
                    values=config_dict["trait_config"]["persistent"]["values"]
                ),
            )

            return cls(
                problem=problem, tile_config=tile_config, trait_config=trait_config
            )
        except json.JSONDecodeError as e:
            # Chain the original exception so the decode position is preserved.
            raise ValueError(f"Invalid JSON format: {str(e)}") from e
        except KeyError as e:
            raise KeyError(f"Missing required configuration field: {str(e)}") from e