add default config

This commit is contained in:
Yanxing-Shi
2025-05-07 10:59:36 +00:00
parent bc72ec4cfb
commit 1ccecf9a11
9 changed files with 344 additions and 170 deletions

View File

@@ -4,30 +4,39 @@
# Configure-time pass: run the generator script in --list_blobs mode. The
# listing it produces is read back from gemm_instance_blobs.txt further below.
# The former --json option is not passed: the script's argument parser now
# defines --problem_json / --config_json instead.
execute_process(
  COMMAND
    "${Python3_EXECUTABLE}" "${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py"
    --working_path "${CMAKE_CURRENT_BINARY_DIR}"
    --problem_json "${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json"
    --use_default_config
    --config_json "${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json"
    --list_blobs
  RESULT_VARIABLE ret
)
# Re-run the CMake configure step whenever the generator script or one of the
# JSON inputs it reads changes, because the blob list is computed at configure
# time. Only files that still exist are listed; a missing file here would make
# the build system consider the configuration permanently out of date.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
  "${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py"
  "${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json"
  "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json"
  "${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json"
)
# Abort configuration when the --list_blobs pass reported a non-zero result.
# A single FATAL_ERROR is enough: it stops processing immediately, so any
# second message after it could never be printed.
if(ret AND NOT ret EQUAL 0)
  message(FATAL_ERROR "Fail to list kernels via Python. ${ret}")
endif()
# Read the listing file line by line; each line becomes one element of
# GEMM_CODEGEN_BLOBS, which is used as the OUTPUT list of the generation rule.
file(STRINGS "${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt" GEMM_CODEGEN_BLOBS)
# Build-time rule: run the generator script in --gen_blobs mode to produce the
# files named in GEMM_CODEGEN_BLOBS. The rule is re-run when the script, the
# blob listing, or any JSON input changes.
add_custom_command(
  OUTPUT ${GEMM_CODEGEN_BLOBS}
  COMMAND
    "${Python3_EXECUTABLE}" "${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py"
    --working_path "${CMAKE_CURRENT_BINARY_DIR}"
    --problem_json "${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json"
    --use_default_config
    --config_json "${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json"
    --gen_blobs
  DEPENDS
    "${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py"
    "${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt"
    "${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json"
    "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json"
    "${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json"
  # VERBATIM keeps argument escaping identical across generators/platforms.
  VERBATIM
)
# Name used for the GEMM tile-engine executable; presumably consumed by an
# add_executable() call later in this file (not visible here) -- TODO confirm.
set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")

View File

@@ -1,10 +1,16 @@
# GEMM Matrix Multiplication
Use the files in this folder to generate and build applications that run Matrix multiplications using ck_tile programming based on the kernel parameters mentioned in the config file `./configs/instance_combination.json`.
Use the files in this folder to generate and build applications that run matrix multiplications with ck_tile, based on the kernel parameters specified in the config files.
# Gemm Problem
Users need to describe the GEMM problem, i.e. the data type and the matrix layouts, in a config file. For reference, see `./configs/gemm_problem.json`.
# Kernel Configurations
User needs to provide kernel configuration such as datatype, layout, tile size, warp size, padding, pipeline, scheduler and epilogue in the config file. For reference please see `./configs/instance_combination.json`
Users can provide kernel configurations such as tile size, warp size, padding, pipeline, scheduler and epilogue in a config file. For reference, see `./configs/user_provide_config.json`. The tile engine also ships a default kernel configuration, saved in `./configs/default_config.json`, which widens the range of generated kernel configurations.
## Build
```

View File

@@ -57,7 +57,7 @@ class Profiler
BLayout::name,
CLayout::name};
KernelInstance kernel_instance{environment_, description, problem, {-1.0f, -1.0f, -1.0f}};
KernelInstance kernel_instance{description, problem, {-1.0f, -1.0f, -1.0f}};
float avg_time = Kernel::launch(args, s);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

View File

@@ -0,0 +1,81 @@
{
"tile_m": {
"values": [
256
]
},
"tile_n": {
"values": [
256
]
},
"tile_k": {
"values": [
64,
32
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
32
]
},
"warp_tile_n": {
"values": [
32
]
},
"warp_tile_k": {
"values": [
16
]
},
"kPadM": {
"values": [
false
]
},
"kPadN": {
"values": [
false
]
},
"kPadK": {
"values": [
false
]
},
"pipeline": {
"values": [
"compv3",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
}
}

View File

@@ -0,0 +1,22 @@
{
"layout_a": {
"values": [
"r"
]
},
"layout_b": {
"values": [
"c"
]
},
"layout_c": {
"values": [
"r"
]
},
"datatype": {
"values": [
"fp16"
]
}
}

View File

@@ -1,60 +0,0 @@
{
"layout_a": {
"values": ["r"]
},
"layout_b": {
"values": ["c"]
},
"layout_c": {
"values": ["r"]
},
"datatype": {
"values": ["fp16"]
},
"tile_m": {
"values": [256]
},
"tile_n": {
"values": [256]
},
"tile_k": {
"values": [64, 32]
},
"warp_m": {
"values": [2]
},
"warp_n": {
"values": [2]
},
"warp_k": {
"values": [1]
},
"warp_tile_m": {
"values": [32]
},
"warp_tile_n": {
"values": [32]
},
"warp_tile_k": {
"values": [16]
},
"kPadM": {
"values": [false]
},
"kPadN": {
"values": [false]
},
"kPadK": {
"values": [false]
},
"pipeline": {
"values": ["compv3", "mem"]
},
"scheduler": {
"values": ["intrawave", "interwave"]
},
"epilogue": {
"values": ["default", "cshuffle"]
}
}

View File

@@ -0,0 +1,81 @@
{
"tile_m": {
"values": [
256
]
},
"tile_n": {
"values": [
256
]
},
"tile_k": {
"values": [
64,
32
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
32
]
},
"warp_tile_n": {
"values": [
32
]
},
"warp_tile_k": {
"values": [
16
]
},
"kPadM": {
"values": [
false
]
},
"kPadN": {
"values": [
false
]
},
"kPadK": {
"values": [
false
]
},
"pipeline": {
"values": [
"compv3",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
}
}

View File

@@ -39,26 +39,25 @@ struct GemmProblem
std::string dtype_a, dtype_b, dtype_acc, dtype_c;
std::string layout_a, layout_b, layout_c;
std::string serialize() const
friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
{
std::ostringstream oss;
oss << "{"
<< "\"split_k\":" << split_k << ","
<< "\"m\":" << m << ","
<< "\"n\":" << n << ","
<< "\"k\":" << k << ","
<< "\"stride_a\":" << stride_a << ","
<< "\"stride_b\":" << stride_b << ","
<< "\"stride_c\":" << stride_c << ","
<< "\"dtype_a\":\"" << dtype_a << "\","
<< "\"dtype_b\":\"" << dtype_b << "\","
<< "\"dtype_acc\":\"" << dtype_acc << "\","
<< "\"dtype_c\":\"" << dtype_c << "\","
<< "\"layout_a\":\"" << layout_a << "\","
<< "\"layout_b\":\"" << layout_b << "\","
<< "\"layout_c\":\"" << layout_c << "\""
<< "}";
return oss.str();
os << "{\n"
<< " \"split_k\":" << problem.split_k << ",\n"
<< " \"m\":" << problem.m << ",\n"
<< " \"n\":" << problem.n << ",\n"
<< " \"k\":" << problem.k << ",\n"
<< " \"stride_a\":" << problem.stride_a << ",\n"
<< " \"stride_b\":" << problem.stride_b << ",\n"
<< " \"stride_c\":" << problem.stride_c << ",\n"
<< " \"dtype_a\":\"" << problem.dtype_a << "\",\n"
<< " \"dtype_b\":\"" << problem.dtype_b << "\",\n"
<< " \"dtype_acc\":\"" << problem.dtype_acc << "\",\n"
<< " \"dtype_c\":\"" << problem.dtype_c << "\",\n"
<< " \"layout_a\":\"" << problem.layout_a << "\",\n"
<< " \"layout_b\":\"" << problem.layout_b << "\",\n"
<< " \"layout_c\":\"" << problem.layout_c << "\"\n"
<< "}";
return os;
}
};
@@ -80,14 +79,15 @@ struct PerformanceResult
return false;
}
std::string serialize() const
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
{
std::ostringstream oss;
oss << "{"
<< "\"latency(ms)\":" << latency << ","
<< "\"tflops(TFlops)\":" << tflops << ","
<< "\"bandwidth(GB/s)\":" << bandwidth << "}";
return oss.str();
os << "{\n"
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency
<< ",\n"
<< " \"tflops(TFlops)\": " << result.tflops << ",\n"
<< " \"bandwidth(GB/s)\": " << result.bandwidth << "\n"
<< "}";
return os;
}
};
@@ -96,20 +96,18 @@ struct Environment
std::string rocm_version;
std::string device_name;
std::string serialize() const
friend std::ostream& operator<<(std::ostream& os, const Environment& env)
{
std::ostringstream oss;
oss << "{"
<< "\"rocm_version\":\"" << rocm_version << "\","
<< "\"device_name\":\"" << device_name << "\""
<< "}";
return oss.str();
os << "{\n"
<< " \"rocm_version\": \"" << env.rocm_version << "\",\n"
<< " \"device_name\": \"" << env.device_name << "\"\n"
<< "}";
return os;
}
};
struct KernelInstance
{
Environment env;
std::string name;
GemmProblem problem;
PerformanceResult perf_result;
@@ -122,11 +120,12 @@ struct KernelInstance
/// Streams a kernel instance (name, problem and performance result) to `os`
/// as a multi-line, JSON-style object and returns `os`.
///
/// `problem` and `perf_result` are written through their own stream
/// operators, which already emit a complete `{ ... }` object; they are
/// therefore inserted as nested objects, not wrapped in string quotes
/// (quoting an object that itself contains double quotes yields malformed
/// output).
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
{
    os << "{\n"
       // NOTE(review): the name is emitted as a quoted string that contains
       // braces and raw newlines; kept as-is because the intended shape of
       // `name` cannot be determined from here -- confirm whether it should be
       // a plain string or a nested object.
       << " \"name\": \"" << "{\n"
       << obj.name << "\n}" << "\",\n"
       << " \"problem\": " << obj.problem << ",\n"
       << " \"perf_result\": " << obj.perf_result << "\n"
       << "}";
    return os;
}
};
@@ -257,6 +256,7 @@ inline auto create_args(int argc, char* argv[])
"compv3",
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.")
// Scheduler option: the help text describes the scheduler itself (it used to
// repeat the pipeline option's description).
.insert("scheduler",
        "intrawave",
        "The type of scheduler. Possible values are intrawave or interwave. Default is "
        "intrawave.")
.insert(

View File

@@ -183,55 +183,76 @@ def BOOL_MAP(b_) -> str:
return 'false'
@dataclass
class GemmProblem:
    """GEMM problem description (data type and matrix layouts).

    Wraps the dictionary loaded from the problem JSON file. Every entry is
    expected to have the form ``{"values": [...]}``; the accessors below return
    the first element of the corresponding ``values`` list.
    """

    def __init__(self, config_data):
        # Shallow copy so later changes to the caller's mapping do not leak in.
        self.data: Dict[str, Any] = {}
        for key, value in config_data.items():
            self.data[key] = value

    @property
    def datatype(self) -> str:
        """First (and, after validation, only) configured data type."""
        return self.data["datatype"]["values"][0]

    @property
    def layouts(self) -> List[str]:
        """Layouts of A, B and C, in that order."""
        return [
            self.data["layout_a"]["values"][0],
            self.data["layout_b"]["values"][0],
            self.data["layout_c"]["values"][0],
        ]
@dataclass
class GemmConfig:
    """Kernel configuration loaded from a JSON file.

    Holds a shallow copy of the loaded mapping in ``data``. Each entry is
    expected to have the form ``{"values": [...]}``.
    """

    def __init__(self, config_data):
        self.data: Dict[str, Any] = {}
        for key, value in config_data.items():
            self.data[key] = value
class GemmCodeGenerator:
    def __init__(self, output_dir: str, problem: GemmProblem, use_default_config: bool,
                 user_provide_config: Optional[GemmConfig] = None):
        """Set up the generator and validate its inputs.

        Args:
            output_dir: Directory the generated files are written to.
            problem: GEMM problem description (data type and layouts).
            use_default_config: When true, ``configs/default_config.json`` next
                to this script is loaded and used as the first configuration.
            user_provide_config: Optional user configuration. It is appended
                after the default configuration when both are used, and it is
                mandatory when ``use_default_config`` is false.

        Raises:
            TypeError: ``user_provide_config`` is given but is not a GemmConfig.
            ValueError: no configuration is available, or validation fails.
        """
        self.output_dir = Path(output_dir)
        self.problem = problem
        # List of GemmConfig objects; their value lists are combined later.
        self.config = []

        if use_default_config:
            config_path = Path(__file__).resolve().parent / "configs" / "default_config.json"
            with open(config_path, 'r') as json_file:
                self.config.append(GemmConfig(json.load(json_file)))
        elif user_provide_config is None:
            raise ValueError("user_provide_config must be provided when use_default_config=False")

        if user_provide_config is not None:
            if not isinstance(user_provide_config, GemmConfig):
                raise TypeError("user_provide_config must be a GemmConfig instance")
            self.config.append(user_provide_config)

        self.all_kernels = []
        self.unique_configs = []

        # Validate configurations
        self._check_validate()

    def _check_validate(self):
        """Validate matrix problem and kernel configurations.

        Raises:
            ValueError: a problem parameter is missing or does not have exactly
                one value, or a kernel parameter is missing or has no values.
        """
        # Matrix problem validation: a missing key is reported the same way as
        # a wrong number of values instead of surfacing as a bare KeyError.
        for param in ["datatype", "layout_a", "layout_b", "layout_c"]:
            values = self.problem.data.get(param, {}).get("values") or []
            if len(values) != 1:
                raise ValueError(f"Matrix problem {param} must have exactly one value")

        # Kernel config validation: every configuration must define every
        # required parameter with a non-empty value list.
        required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k",
                           "warp_tile_m", "warp_tile_n", "warp_tile_k", "pipeline",
                           "epilogue", "scheduler", "kPadM", "kPadN", "kPadK"]
        for config in self.config:
            for param in required_params:
                if not config.data.get(param, {}).get("values"):
                    raise ValueError(f"Missing kernel parameter: {param}")
def list_all(self):
"""List all possible kernel configurations"""
@@ -258,7 +279,12 @@ class GemmCodeGenerator:
]
# Generate all unique_combinations
_unique = set(itertools.product(*[self.config.impl_cfg[p]["values"] for (p, _) in params]))
_unique = set(
itertools.product(*[
[value for config in self.config for value in config.data[p]["values"]]
for (p, _) in params
])
)
for combo in _unique:
config = {name: value for (_, name), value in zip(params, combo)}
pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = config.values()
@@ -280,12 +306,12 @@ class GemmCodeGenerator:
def _generate_common_header(self):
"""Generate common header with datatypes and layout"""
ctype = self.config.datatype
atype = self.config.datatype
btype = self.config.datatype
if self.config.datatype in ['fp8', 'bf8']:
ctype = self.problem.datatype
atype = self.problem.datatype
btype = self.problem.datatype
if self.problem.datatype in ['fp8', 'bf8']:
ctype = 'fp16'
elif self.config.datatype in ['int4']:
elif self.problem.datatype in ['int4']:
atype = 'fp16'
ctype = 'fp16'
@@ -302,9 +328,9 @@ using AccDataType = float;
using CDataType = {DATA_TYPE_MAP[ctype]};
// Layout configurations
using ALayout = {LAYOUT_MAP[self.config.layouts[0]]};
using BLayout = {LAYOUT_MAP[self.config.layouts[1]]};
using CLayout = {LAYOUT_MAP[self.config.layouts[2]]};
using ALayout = {LAYOUT_MAP[self.problem.layouts[0]]};
using BLayout = {LAYOUT_MAP[self.problem.layouts[1]]};
using CLayout = {LAYOUT_MAP[self.problem.layouts[2]]};
"""
@@ -504,19 +530,18 @@ struct GemmDispatcher {
auto& kernel_map = get_kernel_map();
if(!kernel_map.empty()) return;
\n"""
# Add tile/warp instantiations
tile_params = set(itertools.product(
self.config.impl_cfg["tile_m"]["values"],
self.config.impl_cfg["tile_n"]["values"],
self.config.impl_cfg["tile_k"]["values"],
self.config.impl_cfg["warp_m"]["values"],
self.config.impl_cfg["warp_n"]["values"],
self.config.impl_cfg["warp_k"]["values"],
self.config.impl_cfg["warp_tile_m"]["values"],
self.config.impl_cfg["warp_tile_n"]["values"],
self.config.impl_cfg["warp_tile_k"]["values"]
))
# Add tile/warp instantiations
tile_params = set(
itertools.product(*[
[value for config in self.config
for value in config.data[param]["values"]]
for param in [
"tile_m", "tile_n", "tile_k",
"warp_m", "warp_n", "warp_k",
"warp_tile_m", "warp_tile_n", "warp_tile_k"
]
])
)
for group in self.all_kernels:
content += f""" kernel_map["{group}"] = [](Profiler& profiler,
@@ -576,27 +601,31 @@ private:
(self.output_dir / "gemm_dispatcher.hpp").write_text(content)
def do_list_blobs(args, gemm_problem, user_provide_config):
    """Build a generator from the parsed arguments and list all kernel instances.

    Uses ``args.working_path`` as output directory and ``args.use_default_config``
    to decide whether the default configuration is loaded as well.
    """
    generator = GemmCodeGenerator(
        args.working_path,
        gemm_problem,
        args.use_default_config,
        user_provide_config,
    )
    generator.list_all()
def do_gen_blobs(args, gemm_problem, user_provide_config):
    """Build a generator from the parsed arguments and generate all kernel files.

    Uses ``args.working_path`` as output directory and ``args.use_default_config``
    to decide whether the default configuration is loaded as well.
    """
    generator = GemmCodeGenerator(
        args.working_path,
        gemm_problem,
        args.use_default_config,
        user_provide_config,
    )
    generator.generate_all()
def main(args):
# Read json file
with open(args.json, 'r') as json_file:
# Read problem json file
with open(args.problem_json, 'r') as json_file:
config_data = json.load(json_file)
gemm_problem = GemmProblem(config_data)
# Read user provide json file
with open(args.config_json, 'r') as json_file:
config_data = json.load(json_file)
gemm_config = GemmConfig(config_data)
if args.list_blobs:
do_list_blobs(args, gemm_config)
do_list_blobs(args, gemm_problem, gemm_config)
elif args.gen_blobs:
do_gen_blobs(args, gemm_config)
do_gen_blobs(args, gemm_problem, gemm_config)
else:
# If neither was specified, either do nothing or default to gen_blobs
print("No mode specified (use --list_blobs or --gen_blobs). Generating by default...")
@@ -610,16 +639,22 @@ if __name__ == "__main__":
description="gen API for CK gemm kernel",
)
# Command-line options. --problem_json and --config_json replace the former
# single --json option; --list_blobs / --gen_blobs select the mode.
parser.add_argument(
    "-w", "--working_path", default="./", required=False, help="The path where all the blobs are going to be generated"
)
parser.add_argument(
    "-pj", "--problem_json", required=True, help="Path to the json which defines gemm problem"
)
parser.add_argument(
    "-u", "--use_default_config", action = 'store_true', help="Whether use default config json file to generate kernel instance or not"
)
parser.add_argument(
    "-cj", "--config_json", required=True, help="Path to the json which contains the kernel configurations that user provide"
)
parser.add_argument(
    "-l", "--list_blobs", action = 'store_true', help="List all kernel instance to file"
)
parser.add_argument(
    "-g", "--gen_blobs", action = 'store_true', help="Generate all kernels instance into different files"
)
args = parser.parse_args()