mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
add default config
This commit is contained in:
@@ -4,30 +4,39 @@
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
|
||||
--problem_json ${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json
|
||||
--use_default_config
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
|
||||
message( FATAL_ERROR "Fail to list kernels via Python. ${ret}")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
|
||||
|
||||
message("test--------------------------------------------------")
|
||||
add_custom_command(
|
||||
OUTPUT ${GEMM_CODEGEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
|
||||
--problem_json ${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json
|
||||
--use_default_config
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json
|
||||
--gen_blobs
|
||||
DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json
|
||||
)
|
||||
|
||||
set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
# GEMM Matrix Multiplication
|
||||
|
||||
Use the files in this folder to generate and build applications that run Matrix multiplications using ck_tile programming based on the kernel parameters mentioned in the config file `./configs/instance_combination.json`.
|
||||
Use the files in this folder to generate and build applications that run matrix multiplications with ck_tile, based on the kernel parameters given in the config files.
|
||||
|
||||
# Gemm Problem
|
||||
|
||||
Users need to describe the GEMM problem, such as the data type and the matrix layouts, in a config file. For reference, please see `./configs/gemm_problem.json`.
|
||||
|
||||
|
||||
# Kernel Configurations
|
||||
|
||||
User needs to provide kernel configuration such as datatype, layout, tile size, warp size, padding, pipeline, scheduler and epilogue in the config file. For reference please see `./configs/instance_combination.json`
|
||||
Users can provide kernel configurations such as tile size, warp size, padding, pipeline, scheduler and epilogue in a config file. For reference, please see `./configs/user_provide_config.json`. The tile engine also has a default kernel configuration, saved in `./configs/default_config.json`, which expands the range of kernel configurations.
|
||||
|
||||
|
||||
## Build
|
||||
```
|
||||
|
||||
@@ -57,7 +57,7 @@ class Profiler
|
||||
BLayout::name,
|
||||
CLayout::name};
|
||||
|
||||
KernelInstance kernel_instance{environment_, description, problem, {-1.0f, -1.0f, -1.0f}};
|
||||
KernelInstance kernel_instance{description, problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
float avg_time = Kernel::launch(args, s);
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
81
tile_engine/ops/gemm/configs/default_config.json
Normal file
81
tile_engine/ops/gemm/configs/default_config.json
Normal file
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"kPadM": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"kPadN": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"kPadK": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default",
|
||||
"cshuffle"
|
||||
]
|
||||
}
|
||||
}
|
||||
22
tile_engine/ops/gemm/configs/gemm_problem.json
Normal file
22
tile_engine/ops/gemm/configs/gemm_problem.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"layout_a": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"layout_b": {
|
||||
"values": [
|
||||
"c"
|
||||
]
|
||||
},
|
||||
"layout_c": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"datatype": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
{
|
||||
|
||||
"layout_a": {
|
||||
"values": ["r"]
|
||||
},
|
||||
"layout_b": {
|
||||
"values": ["c"]
|
||||
},
|
||||
"layout_c": {
|
||||
"values": ["r"]
|
||||
},
|
||||
"datatype": {
|
||||
"values": ["fp16"]
|
||||
},
|
||||
"tile_m": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [64, 32]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [2]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [2]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [1]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [32]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [32]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [16]
|
||||
},
|
||||
"kPadM": {
|
||||
"values": [false]
|
||||
},
|
||||
"kPadN": {
|
||||
"values": [false]
|
||||
},
|
||||
"kPadK": {
|
||||
"values": [false]
|
||||
},
|
||||
"pipeline": {
|
||||
"values": ["compv3", "mem"]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": ["intrawave", "interwave"]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": ["default", "cshuffle"]
|
||||
}
|
||||
}
|
||||
81
tile_engine/ops/gemm/configs/user_provide_config.json
Normal file
81
tile_engine/ops/gemm/configs/user_provide_config.json
Normal file
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"kPadM": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"kPadN": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"kPadK": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default",
|
||||
"cshuffle"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -39,26 +39,25 @@ struct GemmProblem
|
||||
std::string dtype_a, dtype_b, dtype_acc, dtype_c;
|
||||
std::string layout_a, layout_b, layout_c;
|
||||
|
||||
std::string serialize() const
|
||||
friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "{"
|
||||
<< "\"split_k\":" << split_k << ","
|
||||
<< "\"m\":" << m << ","
|
||||
<< "\"n\":" << n << ","
|
||||
<< "\"k\":" << k << ","
|
||||
<< "\"stride_a\":" << stride_a << ","
|
||||
<< "\"stride_b\":" << stride_b << ","
|
||||
<< "\"stride_c\":" << stride_c << ","
|
||||
<< "\"dtype_a\":\"" << dtype_a << "\","
|
||||
<< "\"dtype_b\":\"" << dtype_b << "\","
|
||||
<< "\"dtype_acc\":\"" << dtype_acc << "\","
|
||||
<< "\"dtype_c\":\"" << dtype_c << "\","
|
||||
<< "\"layout_a\":\"" << layout_a << "\","
|
||||
<< "\"layout_b\":\"" << layout_b << "\","
|
||||
<< "\"layout_c\":\"" << layout_c << "\""
|
||||
<< "}";
|
||||
return oss.str();
|
||||
os << "{\n"
|
||||
<< " \"split_k\":" << problem.split_k << ",\n"
|
||||
<< " \"m\":" << problem.m << ",\n"
|
||||
<< " \"n\":" << problem.n << ",\n"
|
||||
<< " \"k\":" << problem.k << ",\n"
|
||||
<< " \"stride_a\":" << problem.stride_a << ",\n"
|
||||
<< " \"stride_b\":" << problem.stride_b << ",\n"
|
||||
<< " \"stride_c\":" << problem.stride_c << ",\n"
|
||||
<< " \"dtype_a\":\"" << problem.dtype_a << "\",\n"
|
||||
<< " \"dtype_b\":\"" << problem.dtype_b << "\",\n"
|
||||
<< " \"dtype_acc\":\"" << problem.dtype_acc << "\",\n"
|
||||
<< " \"dtype_c\":\"" << problem.dtype_c << "\",\n"
|
||||
<< " \"layout_a\":\"" << problem.layout_a << "\",\n"
|
||||
<< " \"layout_b\":\"" << problem.layout_b << "\",\n"
|
||||
<< " \"layout_c\":\"" << problem.layout_c << "\"\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -80,14 +79,15 @@ struct PerformanceResult
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string serialize() const
|
||||
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "{"
|
||||
<< "\"latency(ms)\":" << latency << ","
|
||||
<< "\"tflops(TFlops)\":" << tflops << ","
|
||||
<< "\"bandwidth(GB/s)\":" << bandwidth << "}";
|
||||
return oss.str();
|
||||
os << "{\n"
|
||||
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency
|
||||
<< ",\n"
|
||||
<< " \"tflops(TFlops)\": " << result.tflops << ",\n"
|
||||
<< " \"bandwidth(GB/s)\": " << result.bandwidth << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -96,20 +96,18 @@ struct Environment
|
||||
std::string rocm_version;
|
||||
std::string device_name;
|
||||
|
||||
std::string serialize() const
|
||||
friend std::ostream& operator<<(std::ostream& os, const Environment& env)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "{"
|
||||
<< "\"rocm_version\":\"" << rocm_version << "\","
|
||||
<< "\"device_name\":\"" << device_name << "\""
|
||||
<< "}";
|
||||
return oss.str();
|
||||
os << "{\n"
|
||||
<< " \"rocm_version\": \"" << env.rocm_version << "\",\n"
|
||||
<< " \"device_name\": \"" << env.device_name << "\"\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct KernelInstance
|
||||
{
|
||||
Environment env;
|
||||
std::string name;
|
||||
GemmProblem problem;
|
||||
PerformanceResult perf_result;
|
||||
@@ -122,11 +120,12 @@ struct KernelInstance
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
||||
{
|
||||
os << "{"
|
||||
<< "\"env\":" << obj.env.serialize() << ","
|
||||
<< "\"name\":\"" << obj.name << "\","
|
||||
<< "\"problem\":" << obj.problem.serialize() << ","
|
||||
<< "\"perf_result\":" << obj.perf_result.serialize() << "}";
|
||||
os << "{\n"
|
||||
<< " \"name\": \"" << "{\n"
|
||||
<< obj.name << "\n}" << "\",\n"
|
||||
<< " \"problem\": \"" << obj.problem << "\",\n"
|
||||
<< " \"perf_result\": " << obj.perf_result << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
@@ -257,6 +256,7 @@ inline auto create_args(int argc, char* argv[])
|
||||
"compv3",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.")
|
||||
.insert("scheduler",
|
||||
"intrawave",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is "
|
||||
"compv3.")
|
||||
.insert(
|
||||
|
||||
@@ -183,55 +183,76 @@ def BOOL_MAP(b_) -> str:
|
||||
return 'false'
|
||||
|
||||
@dataclass
|
||||
class GemmConfig:
|
||||
class GemmProblem:
|
||||
def __init__(self, config_data):
|
||||
self.matrix_cfg : Dict[str, Any] = {}
|
||||
self.impl_cfg : Dict[str, Any] = {}
|
||||
self.data : Dict[str, Any] = {}
|
||||
for key, value in config_data.items():
|
||||
if key in ["datatype", "layout_a", "layout_b", "layout_c"]:
|
||||
self.matrix_cfg[key] = value
|
||||
else:
|
||||
self.impl_cfg[key] = value
|
||||
|
||||
self.data[key] = value
|
||||
|
||||
@property
|
||||
def datatype(self) -> str:
|
||||
return self.matrix_cfg["datatype"]["values"][0]
|
||||
return self.data["datatype"]["values"][0]
|
||||
|
||||
@property
|
||||
def layouts(self) -> List[str]:
|
||||
return [
|
||||
self.matrix_cfg["layout_a"]["values"][0],
|
||||
self.matrix_cfg["layout_b"]["values"][0],
|
||||
self.matrix_cfg["layout_c"]["values"][0]
|
||||
self.data["layout_a"]["values"][0],
|
||||
self.data["layout_b"]["values"][0],
|
||||
self.data["layout_c"]["values"][0]
|
||||
]
|
||||
|
||||
|
||||
class GemmCodeGenerator:
|
||||
def __init__(self, output_dir: str, config: GemmConfig):
|
||||
self.output_dir = Path(output_dir)
|
||||
if not self.output_dir.exists():
|
||||
self.output_dir.mkdir()
|
||||
@dataclass
|
||||
class GemmConfig:
|
||||
def __init__(self, config_data):
|
||||
self.data : Dict[str, Any] = {}
|
||||
for key, value in config_data.items():
|
||||
self.data[key] = value
|
||||
|
||||
|
||||
class GemmCodeGenerator:
|
||||
def __init__(self, output_dir: str, problem: GemmProblem, use_default_config: bool, user_provide_config: Optional[GemmConfig] = None):
|
||||
self.output_dir = Path(output_dir)
|
||||
self.problem = problem
|
||||
|
||||
self.config = []
|
||||
if use_default_config:
|
||||
config_path = Path(__file__).resolve().parent / "configs" / "default_config.json"
|
||||
with open(config_path, 'r') as json_file:
|
||||
config_data = json.load(json_file)
|
||||
default_config = GemmConfig(config_data)
|
||||
self.config.append(default_config)
|
||||
if user_provide_config is not None:
|
||||
if not isinstance(user_provide_config, GemmConfig):
|
||||
raise TypeError("user_provide_config must be a GemmConfig instance")
|
||||
self.config.append(user_provide_config)
|
||||
else:
|
||||
if user_provide_config is None:
|
||||
raise ValueError("user_provide_config must be provided when use_default_config=False")
|
||||
if not isinstance(user_provide_config, GemmConfig):
|
||||
raise TypeError("user_provide_config must be a GemmConfig instance")
|
||||
self.config.append(user_provide_config)
|
||||
|
||||
self.config = config
|
||||
self.all_kernels = []
|
||||
self.unique_configs = []
|
||||
# Validate configurations
|
||||
self._validate_config()
|
||||
self._check_validate()
|
||||
|
||||
def _validate_config(self):
|
||||
"""Validate matrix and implementation configurations"""
|
||||
# Matrix config validation
|
||||
def _check_validate(self):
|
||||
"""Validate matrix problem and kernel configurations"""
|
||||
# Matrix problem validation
|
||||
for param in ["datatype", "layout_a", "layout_b", "layout_c"]:
|
||||
if len(self.config.matrix_cfg[param]["values"]) != 1:
|
||||
raise ValueError(f"Matrix config {param} must have exactly one value")
|
||||
if len(self.problem.data[param]["values"]) != 1:
|
||||
raise ValueError(f"Matrix problem {param} must have exactly one value")
|
||||
|
||||
# Implementation traits validation
|
||||
# kernel config validation
|
||||
required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k",
|
||||
"warp_tile_m", "warp_tile_n", "warp_tile_k", "pipeline",
|
||||
"epilogue", "scheduler", "kPadM", "kPadN", "kPadK"]
|
||||
for param in required_params:
|
||||
if not self.config.impl_cfg.get(param, {}).get("values"):
|
||||
raise ValueError(f"Missing implementation parameter: {param}")
|
||||
for config in self.config:
|
||||
for param in required_params:
|
||||
if not config.data.get(param, {}).get("values"):
|
||||
raise ValueError(f"Missing kernel parameter: {param}")
|
||||
|
||||
def list_all(self):
|
||||
"""List all possible kernel configurations"""
|
||||
@@ -258,7 +279,12 @@ class GemmCodeGenerator:
|
||||
]
|
||||
|
||||
# Generate all unique_combinations
|
||||
_unique = set(itertools.product(*[self.config.impl_cfg[p]["values"] for (p, _) in params]))
|
||||
_unique = set(
|
||||
itertools.product(*[
|
||||
[value for config in self.config for value in config.data[p]["values"]]
|
||||
for (p, _) in params
|
||||
])
|
||||
)
|
||||
for combo in _unique:
|
||||
config = {name: value for (_, name), value in zip(params, combo)}
|
||||
pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = config.values()
|
||||
@@ -280,12 +306,12 @@ class GemmCodeGenerator:
|
||||
|
||||
def _generate_common_header(self):
|
||||
"""Generate common header with datatypes and layout"""
|
||||
ctype = self.config.datatype
|
||||
atype = self.config.datatype
|
||||
btype = self.config.datatype
|
||||
if self.config.datatype in ['fp8', 'bf8']:
|
||||
ctype = self.problem.datatype
|
||||
atype = self.problem.datatype
|
||||
btype = self.problem.datatype
|
||||
if self.problem.datatype in ['fp8', 'bf8']:
|
||||
ctype = 'fp16'
|
||||
elif self.config.datatype in ['int4']:
|
||||
elif self.problem.datatype in ['int4']:
|
||||
atype = 'fp16'
|
||||
ctype = 'fp16'
|
||||
|
||||
@@ -302,9 +328,9 @@ using AccDataType = float;
|
||||
using CDataType = {DATA_TYPE_MAP[ctype]};
|
||||
|
||||
// Layout configurations
|
||||
using ALayout = {LAYOUT_MAP[self.config.layouts[0]]};
|
||||
using BLayout = {LAYOUT_MAP[self.config.layouts[1]]};
|
||||
using CLayout = {LAYOUT_MAP[self.config.layouts[2]]};
|
||||
using ALayout = {LAYOUT_MAP[self.problem.layouts[0]]};
|
||||
using BLayout = {LAYOUT_MAP[self.problem.layouts[1]]};
|
||||
using CLayout = {LAYOUT_MAP[self.problem.layouts[2]]};
|
||||
"""
|
||||
|
||||
|
||||
@@ -504,19 +530,18 @@ struct GemmDispatcher {
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
# Add tile/warp instantiations
|
||||
tile_params = set(itertools.product(
|
||||
self.config.impl_cfg["tile_m"]["values"],
|
||||
self.config.impl_cfg["tile_n"]["values"],
|
||||
self.config.impl_cfg["tile_k"]["values"],
|
||||
self.config.impl_cfg["warp_m"]["values"],
|
||||
self.config.impl_cfg["warp_n"]["values"],
|
||||
self.config.impl_cfg["warp_k"]["values"],
|
||||
self.config.impl_cfg["warp_tile_m"]["values"],
|
||||
self.config.impl_cfg["warp_tile_n"]["values"],
|
||||
self.config.impl_cfg["warp_tile_k"]["values"]
|
||||
))
|
||||
|
||||
# Add tile/warp instantiations
|
||||
tile_params = set(
|
||||
itertools.product(*[
|
||||
[value for config in self.config
|
||||
for value in config.data[param]["values"]]
|
||||
for param in [
|
||||
"tile_m", "tile_n", "tile_k",
|
||||
"warp_m", "warp_n", "warp_k",
|
||||
"warp_tile_m", "warp_tile_n", "warp_tile_k"
|
||||
]
|
||||
])
|
||||
)
|
||||
|
||||
for group in self.all_kernels:
|
||||
content += f""" kernel_map["{group}"] = [](Profiler& profiler,
|
||||
@@ -576,27 +601,31 @@ private:
|
||||
(self.output_dir / "gemm_dispatcher.hpp").write_text(content)
|
||||
|
||||
|
||||
def do_list_blobs(args, gemm_config):
|
||||
generator = GemmCodeGenerator(args.working_path, gemm_config)
|
||||
def do_list_blobs(args, gemm_problem, user_provide_config):
|
||||
generator = GemmCodeGenerator(args.working_path, gemm_problem, args.use_default_config, user_provide_config)
|
||||
generator.list_all()
|
||||
|
||||
def do_gen_blobs(args, gemm_config):
|
||||
generator = GemmCodeGenerator(args.working_path, gemm_config)
|
||||
def do_gen_blobs(args, gemm_problem, user_provide_config):
|
||||
generator = GemmCodeGenerator(args.working_path, gemm_problem, args.use_default_config, user_provide_config)
|
||||
generator.generate_all()
|
||||
|
||||
|
||||
|
||||
def main(args):
|
||||
# Read json file
|
||||
with open(args.json, 'r') as json_file:
|
||||
# Read problem json file
|
||||
with open(args.problem_json, 'r') as json_file:
|
||||
config_data = json.load(json_file)
|
||||
gemm_problem = GemmProblem(config_data)
|
||||
|
||||
# Read user provide json file
|
||||
with open(args.config_json, 'r') as json_file:
|
||||
config_data = json.load(json_file)
|
||||
|
||||
gemm_config = GemmConfig(config_data)
|
||||
|
||||
if args.list_blobs:
|
||||
do_list_blobs(args, gemm_config)
|
||||
do_list_blobs(args, gemm_problem, gemm_config)
|
||||
elif args.gen_blobs:
|
||||
do_gen_blobs(args, gemm_config)
|
||||
do_gen_blobs(args, gemm_problem, gemm_config)
|
||||
else:
|
||||
# If neither was specified, either do nothing or default to gen_blobs
|
||||
print("No mode specified (use --list_blobs or --gen_blobs). Generating by default...")
|
||||
@@ -610,16 +639,22 @@ if __name__ == "__main__":
|
||||
description="gen API for CK gemm kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w", "--working_path", default="./", required=False, help="the path where all the blobs are going to be generated"
|
||||
"-w", "--working_path", default="./", required=False, help="The path where all the blobs are going to be generated"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j", "--json", required=True, help="Path to the json which contains the kernel configurations"
|
||||
"-pj", "--problem_json", required=True, help="Path to the json which defines gemm problem"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list_blobs", action = 'store_true', help="List all kernel to file"
|
||||
"-u", "--use_default_config", action = 'store_true', help="Wether use default config json file to generate kernel instance or not"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--gen_blobs", action = 'store_true', help="Generate all kernels into different files"
|
||||
"-cj", "--config_json", required=True, help="Path to the json which contains the kernel configurations that user provide"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list_blobs", action = 'store_true', help="List all kernel instance to file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--gen_blobs", action = 'store_true', help="Generate all kernels instance into different files"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user