mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK Tile] Int8 Support on CK Tile GEMM (#2267)
* updates to support int8 in 03_gemm example * added comments, using aliases, helper functions * test(gemm_universal): add test cases for int8 gemm pipeline * fix(test_gemm): fix for failing test unit test for int8 * test(ck_tile): add int8 unit test for gemm universal * refactor(gemm_universal): GPU reference verification for GEMM code improved * style(gemm_universal): removed extra comments and did clang format * merging recent changes to universal gemm to tile_engine * ck tile engine integration work * feat(tile_engine): add int8 support to tile engine ops/gemm * feat(tile_engine): added 32 32 16 mfma instances to tile engine for int8 * style: Format code with clang-format-12 * refactor(tile_engine): address review comments * style: removed unhelpful comments & unused variables. * build: tile engine uses default config * feat: add int8 support for CK_TILE GEMM * style: added trailing commas to codegen_utils.py * refactor: tile engine * refactor: formatting and code review * refactor: code formatting for python files * fix: suppress build warning * add support for gfx950 * refactor:KWarpTile size in gemms util * Fix the branch and wrap up the k warp tile * Add bf8 integration * refactor: clang format and rebase --------- Co-authored-by: zjli2013 <leezhengjiang@gmail.com> Co-authored-by: AviralGoelAMD <aviral.goel@amd.com> Co-authored-by: Khushbu Agarwal <khuagarw@amd.com>
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
|
||||
@@ -11,17 +11,21 @@ import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
DATA_TYPE_MAP = {'fp32': 'float',
|
||||
'fp16': 'ck_tile::half_t',
|
||||
'bf16': 'ck_tile::bf16_t',
|
||||
'int8': 'ck_tile::int8_t',
|
||||
'fp8': 'ck_tile::fp8_t',
|
||||
'bf8': 'ck_tile::bf8_t',
|
||||
'int4': 'ck_tile::pk_int4_t'
|
||||
}
|
||||
DATA_TYPE_MAP = {
|
||||
"fp32": "float",
|
||||
"fp16": "ck_tile::half_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"int8": "ck_tile::int8_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"int4": "ck_tile::pk_int4_t",
|
||||
"int32": "ck_tile::int32_t",
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {'r': 'ck_tile::tensor_layout::gemm::RowMajor',
|
||||
'c': 'ck_tile::tensor_layout::gemm::ColumnMajor'}
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
@@ -149,44 +153,109 @@ RUN_COMPV4 = """
|
||||
"""
|
||||
|
||||
|
||||
PIPELINE_MAP = {'mem': ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
|
||||
'compv3': ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
|
||||
'compv4': ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
|
||||
PIPELINE_MAP = {
|
||||
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
|
||||
"compv3": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV3",
|
||||
],
|
||||
"compv4": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV4",
|
||||
],
|
||||
}
|
||||
|
||||
SCHEDULER_MAP = {'interwave': 'ck_tile::GemmPipelineScheduler::Interwave',
|
||||
'intrawave': 'ck_tile::GemmPipelineScheduler::Intrawave'}
|
||||
SCHEDULER_MAP = {
|
||||
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
}
|
||||
|
||||
EPILOGUE_MAP = {'default': DEFAULT_EPILOGUE,
|
||||
'cshuffle': CSHUFFLE_EPILOGUE}
|
||||
EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
|
||||
|
||||
HOT_LOOP_TRUE = {'mem': RUN_MEM,
|
||||
'compv3': RUN_COMPV3,
|
||||
'compv4': RUN_COMPV4}
|
||||
HOT_LOOP_TRUE = {"mem": RUN_MEM, "compv3": RUN_COMPV3, "compv4": RUN_COMPV4}
|
||||
|
||||
|
||||
def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
|
||||
def BOOL_MAP(b_):
|
||||
return {True: "true", False: "false"}[bool(b_)]
|
||||
|
||||
|
||||
# To Do: add some more supported combinations
|
||||
warp_tile_supported_combinations = {
|
||||
"gfx90a": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
||||
}
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# To Do: remove some unsupported combinations
|
||||
@@ -194,24 +263,30 @@ trait_unsupported_combinations = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave")
|
||||
("compv4", "default", "interwave"),
|
||||
}
|
||||
|
||||
|
||||
ELEMENT_SIZE_MAP = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"int8": 1,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int4": 0.5,
|
||||
"int32": 4,
|
||||
}
|
||||
|
||||
|
||||
def element_size(data_type: str) -> float:
|
||||
"""Calculate the size (in bytes) of a single element for given data type."""
|
||||
data_type = data_type.lower()
|
||||
if data_type in {'fp16', 'bf16'}:
|
||||
return 2
|
||||
elif data_type in {'int8', 'fp8', 'bf8'}:
|
||||
return 1
|
||||
elif data_type == 'int4':
|
||||
return 0.5
|
||||
else:
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r'Name:\s*(gfx\d+\w*)')
|
||||
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@@ -219,10 +294,7 @@ def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"],
|
||||
text=True,
|
||||
stderr=subprocess.PIPE,
|
||||
timeout=5
|
||||
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
|
||||
@@ -33,19 +33,19 @@
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 512,
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": []
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 512,
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 32,
|
||||
"exclude": []
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 512,
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": [192]
|
||||
|
||||
@@ -17,17 +17,17 @@
|
||||
},
|
||||
"datatype_a": {
|
||||
"values": [
|
||||
"fp16"
|
||||
"int8"
|
||||
]
|
||||
},
|
||||
"datatype_b": {
|
||||
"values": [
|
||||
"fp16"
|
||||
"int8"
|
||||
]
|
||||
},
|
||||
"datatype_c": {
|
||||
"values": [
|
||||
"fp16"
|
||||
"int32"
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -44,7 +44,7 @@
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
128
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
@@ -64,17 +64,17 @@
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
32
|
||||
16, 32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
32
|
||||
16, 32
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
16, 32
|
||||
]
|
||||
}
|
||||
},
|
||||
|
||||
@@ -50,6 +50,18 @@ struct DataTypeTraits<ck_tile::bf8_t>
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
|
||||
@@ -29,10 +29,9 @@ from codegen_utils import (
|
||||
warp_tile_supported_combinations,
|
||||
trait_unsupported_combinations,
|
||||
element_size,
|
||||
get_gpu_name_by_id
|
||||
get_gpu_name_by_id,
|
||||
)
|
||||
import logging
|
||||
import time
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@@ -40,16 +39,18 @@ logging.basicConfig(level=logging.INFO)
|
||||
class GemmCodeGenerator:
|
||||
"""GEMM (General Matrix Multiplication) code generator."""
|
||||
|
||||
def __init__(self, output_dir: str,
|
||||
user_provided_config: Optional[GemmConfig] = None):
|
||||
def __init__(
|
||||
self, output_dir: str, user_provided_config: Optional[GemmConfig] = None
|
||||
):
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if user_provided_config is not None:
|
||||
self.config = user_provided_config
|
||||
else:
|
||||
config_path = Path(__file__).resolve().parent / \
|
||||
"configs" / "default_config.json"
|
||||
config_path = (
|
||||
Path(__file__).resolve().parent / "configs" / "default_config.json"
|
||||
)
|
||||
self.config = GemmConfig.from_json(config_path)
|
||||
|
||||
self.valid_trait_names: List[str] = []
|
||||
@@ -58,46 +59,82 @@ class GemmCodeGenerator:
|
||||
def list_all_trait_names(self):
|
||||
"""List all possible kernel trait names into file."""
|
||||
w_p = Path(self.output_dir)
|
||||
file_path = w_p / 'gemm_instance_blobs.txt'
|
||||
file_path = w_p / "gemm_instance_blobs.txt"
|
||||
self._generate_all_traits()
|
||||
self._get_valid_trait_tile_combinations()
|
||||
|
||||
# Write all file paths to the header file
|
||||
with file_path.open('w') as f:
|
||||
f.write(str(w_p / "gemm_common.hpp") + "\n")
|
||||
f.write(str(w_p / "gemm_instances.hpp") + "\n")
|
||||
f.write(str(w_p / "gemm_dispatcher.hpp") + "\n")
|
||||
files_listed = 0
|
||||
with file_path.open("w") as f:
|
||||
# Core files
|
||||
core_files = [
|
||||
"gemm_common.hpp",
|
||||
"gemm_instances.hpp",
|
||||
"gemm_dispatcher.hpp",
|
||||
]
|
||||
for core_file in core_files:
|
||||
f.write(str(w_p / core_file) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
# Trait header files
|
||||
for trait in self.valid_trait_names:
|
||||
f.write(str(w_p / f"gemm_{trait}.hpp") + "\n")
|
||||
trait_file = f"gemm_{trait}.hpp"
|
||||
f.write(str(w_p / trait_file) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
# Instance source files
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
for tile in tile_valid_params:
|
||||
for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile:
|
||||
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
|
||||
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
|
||||
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
|
||||
((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
|
||||
(warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
|
||||
for (
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) in tile:
|
||||
instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
sparse = (
|
||||
self.config.problem.datatype_map["matrix_a"] == "fp16"
|
||||
and self.config.problem.datatype_map["matrix_b"] == "fp16"
|
||||
and self.config.problem.datatype_map["matrix_c"] == "fp16"
|
||||
and (
|
||||
(
|
||||
warp_tile_m == 32
|
||||
and warp_tile_n == 32
|
||||
and warp_tile_k == 16
|
||||
)
|
||||
or (
|
||||
warp_tile_m == 16
|
||||
and warp_tile_n == 16
|
||||
and warp_tile_k == 32
|
||||
)
|
||||
)
|
||||
)
|
||||
if sparse:
|
||||
f.write(str(
|
||||
w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp") + "\n")
|
||||
f.write(str(
|
||||
w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp") + "\n")
|
||||
sparse_file = f"gemm_{trait}_{instance_name}_true.cpp"
|
||||
f.write(str(w_p / sparse_file) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
regular_file = f"gemm_{trait}_{instance_name}_false.cpp"
|
||||
f.write(str(w_p / regular_file) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
print(f"File listing complete: {files_listed} files listed in {file_path}\n")
|
||||
|
||||
def _generate_all_traits(self):
|
||||
"""Generate all possible kernel traits names."""
|
||||
params = [
|
||||
"pipeline",
|
||||
"epilogue",
|
||||
"scheduler",
|
||||
"pad_m",
|
||||
"pad_n",
|
||||
"pad_k"]
|
||||
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
|
||||
|
||||
# Generate all unique_combinations
|
||||
_unique = set(itertools.product(*[
|
||||
getattr(self.config.trait_config, param).values
|
||||
for param in params
|
||||
]))
|
||||
_unique = set(
|
||||
itertools.product(
|
||||
*[getattr(self.config.trait_config, param).values for param in params]
|
||||
)
|
||||
)
|
||||
|
||||
for combo in _unique:
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
|
||||
@@ -110,9 +147,7 @@ class GemmCodeGenerator:
|
||||
)
|
||||
self.valid_trait_names.append(trait_name)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Invalid combination: {pipeline}-{epilogue}-{scheduler}"
|
||||
)
|
||||
logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}")
|
||||
|
||||
def generate_all_instance_files(self):
|
||||
"""Generate all kernel instances files."""
|
||||
@@ -123,6 +158,16 @@ class GemmCodeGenerator:
|
||||
def _generate_common_header_file(self):
|
||||
"""Generate common header file with datatypes and layout."""
|
||||
|
||||
# Determine appropriate accumulation type based on input types
|
||||
a_type = self.config.problem.datatype_map["matrix_a"]
|
||||
b_type = self.config.problem.datatype_map["matrix_b"]
|
||||
c_type = self.config.problem.datatype_map["matrix_c"]
|
||||
|
||||
if a_type in ["int8", "int4"] and b_type in ["int8", "int4"]:
|
||||
acc_type = "ck_tile::int32_t"
|
||||
else:
|
||||
acc_type = "float"
|
||||
|
||||
content = f"""// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
@@ -132,15 +177,15 @@ class GemmCodeGenerator:
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
// Data types
|
||||
using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_a']]};
|
||||
using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_b']]};
|
||||
using AccDataType = float;
|
||||
using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_c']]};
|
||||
using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_a"]]};
|
||||
using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_b"]]};
|
||||
using AccDataType = {acc_type};
|
||||
using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_c"]]};
|
||||
|
||||
// Layout configurations
|
||||
using ALayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_a']]};
|
||||
using BLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_b']]};
|
||||
using CLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_c']]};
|
||||
using ALayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_a"]]};
|
||||
using BLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_b"]]};
|
||||
using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]};
|
||||
"""
|
||||
|
||||
(self.output_dir / "gemm_common.hpp").write_text(content)
|
||||
@@ -174,13 +219,21 @@ namespace {trait} {{
|
||||
"""
|
||||
# Add template struct with configuration
|
||||
content += self._generate_kernel_struct(
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k)
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
|
||||
)
|
||||
|
||||
content += f"\n}} // namespace {trait}\n"
|
||||
(self.output_dir / filename).write_text(content)
|
||||
|
||||
def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str,
|
||||
pad_m: str, pad_n: str, pad_k: str) -> str:
|
||||
def _generate_kernel_struct(
|
||||
self,
|
||||
pipeline: str,
|
||||
epilogue: str,
|
||||
scheduler: str,
|
||||
pad_m: str,
|
||||
pad_n: str,
|
||||
pad_k: str,
|
||||
) -> str:
|
||||
"""Generate the code block of kernel struct"""
|
||||
return f"""
|
||||
|
||||
@@ -193,7 +246,7 @@ struct GemmKernel {{
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
|
||||
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
|
||||
static float launch(ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{
|
||||
static constexpr bool permuteA = false;
|
||||
static constexpr bool permuteB = false;
|
||||
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
|
||||
@@ -307,6 +360,7 @@ struct GemmKernel {{
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
|
||||
}};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
stream,
|
||||
@@ -367,28 +421,36 @@ struct GemmKernel {{
|
||||
#pragma once
|
||||
"""
|
||||
for trait in self.valid_trait_names:
|
||||
content += f"#include \"gemm_{trait}.hpp\"\n"
|
||||
content += f'#include "gemm_{trait}.hpp"\n'
|
||||
(self.output_dir / "gemm_instances.hpp").write_text(content)
|
||||
|
||||
def is_tile_valid(self, tile: tuple, trait: str) -> bool:
|
||||
"""Check if the tile configuration is valid for the given trait."""
|
||||
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile
|
||||
(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) = tile
|
||||
pipeline, *_ = trait.split("_")
|
||||
|
||||
# Parameter validity check
|
||||
invalid_params = []
|
||||
if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]:
|
||||
invalid_params.append(
|
||||
f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})")
|
||||
f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})"
|
||||
)
|
||||
if (warp_m * warp_tile_m) == 0:
|
||||
invalid_params.append(
|
||||
f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
|
||||
invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
|
||||
if (warp_n * warp_tile_n) == 0:
|
||||
invalid_params.append(
|
||||
f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
|
||||
invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
|
||||
if (warp_k * warp_tile_k) == 0:
|
||||
invalid_params.append(
|
||||
f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
|
||||
invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
|
||||
|
||||
if invalid_params:
|
||||
logging.debug(
|
||||
@@ -397,18 +459,20 @@ struct GemmKernel {{
|
||||
f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})"
|
||||
)
|
||||
return False
|
||||
|
||||
# Dimension alignment check
|
||||
alignment_issues = []
|
||||
if tile_m % (warp_m * warp_tile_m) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}")
|
||||
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}"
|
||||
)
|
||||
if tile_n % (warp_n * warp_tile_n) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}")
|
||||
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}"
|
||||
)
|
||||
if tile_k % (warp_k * warp_tile_k) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}")
|
||||
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}"
|
||||
)
|
||||
|
||||
if alignment_issues:
|
||||
logging.debug(
|
||||
@@ -419,17 +483,20 @@ struct GemmKernel {{
|
||||
return False
|
||||
|
||||
# LDS capacity verification
|
||||
matrix_a_size = (tile_m * tile_k) * \
|
||||
element_size(self.config.problem.datatype_map['matrix_a'])
|
||||
matrix_b_size = (tile_n * tile_k) * \
|
||||
element_size(self.config.problem.datatype_map['matrix_b'])
|
||||
matrix_a_size = (tile_m * tile_k) * element_size(
|
||||
self.config.problem.datatype_map["matrix_a"]
|
||||
)
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(
|
||||
self.config.problem.datatype_map["matrix_b"]
|
||||
)
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
|
||||
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
logging.debug(
|
||||
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > "
|
||||
f"maximum allowed {max_tile_size:,}B ({max_tile_size/1024}KB). Breakdown:\n"
|
||||
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
|
||||
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
|
||||
f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
|
||||
f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
|
||||
)
|
||||
@@ -440,16 +507,19 @@ struct GemmKernel {{
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
|
||||
if not gpu_warp_tile_key:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
|
||||
)
|
||||
return False
|
||||
|
||||
allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, [])
|
||||
if not allowed_combinations:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
|
||||
)
|
||||
return False
|
||||
|
||||
if current_combination not in allowed_combinations:
|
||||
@@ -462,49 +532,68 @@ struct GemmKernel {{
|
||||
return True
|
||||
|
||||
def _get_valid_trait_tile_combinations(self):
|
||||
def get_tile_value(tile_param): return tile_param.generate_candidates(
|
||||
) if isinstance(tile_param, RangeConfigParam) else tile_param.values
|
||||
def get_tile_value(tile_param):
|
||||
return (
|
||||
tile_param.generate_candidates()
|
||||
if isinstance(tile_param, RangeConfigParam)
|
||||
else tile_param.values
|
||||
)
|
||||
|
||||
tile_group = list(itertools.product(
|
||||
get_tile_value(self.config.tile_config.tile_m),
|
||||
get_tile_value(self.config.tile_config.tile_n),
|
||||
get_tile_value(self.config.tile_config.tile_k)
|
||||
))
|
||||
tile_group = list(
|
||||
itertools.product(
|
||||
get_tile_value(self.config.tile_config.tile_m),
|
||||
get_tile_value(self.config.tile_config.tile_n),
|
||||
get_tile_value(self.config.tile_config.tile_k),
|
||||
)
|
||||
)
|
||||
|
||||
warp_group = list(itertools.product(
|
||||
get_tile_value(self.config.tile_config.warp_m),
|
||||
get_tile_value(self.config.tile_config.warp_n),
|
||||
get_tile_value(self.config.tile_config.warp_k)
|
||||
))
|
||||
warp_group = list(
|
||||
itertools.product(
|
||||
get_tile_value(self.config.tile_config.warp_m),
|
||||
get_tile_value(self.config.tile_config.warp_n),
|
||||
get_tile_value(self.config.tile_config.warp_k),
|
||||
)
|
||||
)
|
||||
|
||||
warp_tile_group = list(itertools.product(
|
||||
get_tile_value(self.config.tile_config.warp_tile_m),
|
||||
get_tile_value(self.config.tile_config.warp_tile_n),
|
||||
get_tile_value(self.config.tile_config.warp_tile_k)
|
||||
))
|
||||
warp_tile_group = list(
|
||||
itertools.product(
|
||||
get_tile_value(self.config.tile_config.warp_tile_m),
|
||||
get_tile_value(self.config.tile_config.warp_tile_n),
|
||||
get_tile_value(self.config.tile_config.warp_tile_k),
|
||||
)
|
||||
)
|
||||
|
||||
tile_params = {
|
||||
t + w + wt
|
||||
for t in tile_group
|
||||
for w in warp_group
|
||||
for wt in warp_tile_group
|
||||
t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group
|
||||
}
|
||||
|
||||
for trait in self.valid_trait_names:
|
||||
tile_valid_params = list(
|
||||
filter(lambda t: self.is_tile_valid(t, trait), tile_params))
|
||||
tile_valid_params = [
|
||||
tile for tile in tile_params if self.is_tile_valid(tile, trait)
|
||||
]
|
||||
|
||||
# if len(tile_valid_params) == 0:
|
||||
# raise RuntimeError(f"No valid kernel instance selected for trait: {trait}")
|
||||
if trait not in self.valid_trait_tile_combinations:
|
||||
self.valid_trait_tile_combinations[trait] = []
|
||||
self.valid_trait_tile_combinations[trait].append(tile_valid_params)
|
||||
|
||||
def _generate_instantiation_source_files(self):
|
||||
"""Generate kernel instance instantiation source files """
|
||||
"""Generate kernel instance instantiation source files"""
|
||||
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
for tile in tile_valid_params:
|
||||
for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile:
|
||||
for (
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) in tile:
|
||||
instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
|
||||
content = f"""
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
@@ -514,23 +603,41 @@ struct GemmKernel {{
|
||||
#include "gemm_{trait}.hpp"
|
||||
|
||||
"""
|
||||
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
|
||||
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
|
||||
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
|
||||
((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
|
||||
(warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
|
||||
sparse = (
|
||||
self.config.problem.datatype_map["matrix_a"] == "fp16"
|
||||
and self.config.problem.datatype_map["matrix_b"] == "fp16"
|
||||
and self.config.problem.datatype_map["matrix_c"] == "fp16"
|
||||
and (
|
||||
(
|
||||
warp_tile_m == 32
|
||||
and warp_tile_n == 32
|
||||
and warp_tile_k == 16
|
||||
)
|
||||
or (
|
||||
warp_tile_m == 16
|
||||
and warp_tile_n == 16
|
||||
and warp_tile_k == 32
|
||||
)
|
||||
)
|
||||
)
|
||||
if sparse:
|
||||
sparse_content = content + f"""
|
||||
sparse_filename = f"gemm_{trait}_{instance_name}_true.cpp"
|
||||
sparse_content = (
|
||||
content
|
||||
+ f"""
|
||||
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;
|
||||
"""
|
||||
(self.output_dir /
|
||||
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp").write_text(sparse_content)
|
||||
)
|
||||
(self.output_dir / sparse_filename).write_text(sparse_content)
|
||||
|
||||
no_sparse_content = content + f"""
|
||||
no_sparse_filename = f"gemm_{trait}_{instance_name}_false.cpp"
|
||||
no_sparse_content = (
|
||||
content
|
||||
+ f"""
|
||||
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;
|
||||
"""
|
||||
(self.output_dir /
|
||||
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp").write_text(no_sparse_content)
|
||||
)
|
||||
(self.output_dir / no_sparse_filename).write_text(no_sparse_content)
|
||||
|
||||
def _generate_dispatcher_file(self):
|
||||
"""Generate the code block of dispatch mechanism."""
|
||||
@@ -576,7 +683,7 @@ struct GemmDispatcher {
|
||||
}
|
||||
|
||||
static void init(bool structured_sparsity) {
|
||||
ck_tile::ignore = structured_sparsity;
|
||||
(void)structured_sparsity; // Suppress unused parameter warning
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
@@ -585,16 +692,37 @@ struct GemmDispatcher {
|
||||
content += f""" kernel_map["{trait}"] = {{"""
|
||||
for _, tile in enumerate(tile_valid_params):
|
||||
for j in range(len(tile)):
|
||||
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[
|
||||
j]
|
||||
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
|
||||
(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) = tile[j]
|
||||
content += f"""[=](ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{ """
|
||||
content += f"""
|
||||
if(structured_sparsity){{ // SMFMA"""
|
||||
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
|
||||
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
|
||||
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
|
||||
((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
|
||||
(warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
|
||||
sparse = (
|
||||
self.config.problem.datatype_map["matrix_a"] == "fp16"
|
||||
and self.config.problem.datatype_map["matrix_b"] == "fp16"
|
||||
and self.config.problem.datatype_map["matrix_c"] == "fp16"
|
||||
and (
|
||||
(
|
||||
warp_tile_m == 32
|
||||
and warp_tile_n == 32
|
||||
and warp_tile_k == 16
|
||||
)
|
||||
or (
|
||||
warp_tile_m == 16
|
||||
and warp_tile_n == 16
|
||||
and warp_tile_k == 32
|
||||
)
|
||||
)
|
||||
)
|
||||
content += f"""
|
||||
return run_kernel<{trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, {BOOL_MAP(sparse)}>>(args, stream);"""
|
||||
content += f"""
|
||||
@@ -604,7 +732,7 @@ struct GemmDispatcher {
|
||||
content += f"""
|
||||
}} """
|
||||
|
||||
if j == len(tile)-1:
|
||||
if j == len(tile) - 1:
|
||||
content += f"""
|
||||
}} """
|
||||
else:
|
||||
@@ -651,22 +779,26 @@ private:
|
||||
(self.output_dir / "gemm_dispatcher.hpp").write_text(content)
|
||||
|
||||
|
||||
def do_list_blobs(args: argparse.Namespace,
|
||||
user_provide_config: Optional[GemmConfig] = None):
|
||||
def do_list_blobs(
|
||||
args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None
|
||||
):
|
||||
generator = GemmCodeGenerator(args.working_path, user_provide_config)
|
||||
generator.list_all_trait_names()
|
||||
|
||||
|
||||
def do_gen_blobs(args: argparse.Namespace,
|
||||
user_provide_config: Optional[GemmConfig] = None):
|
||||
def do_gen_blobs(
|
||||
args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None
|
||||
):
|
||||
generator = GemmCodeGenerator(args.working_path, user_provide_config)
|
||||
generator.generate_all_instance_files()
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
gemm_config = GemmConfig.from_json(
|
||||
args.config_json) if args.config_json is not None else args.config_json
|
||||
gemm_config = (
|
||||
GemmConfig.from_json(args.config_json)
|
||||
if args.config_json is not None
|
||||
else args.config_json
|
||||
)
|
||||
|
||||
if args.list_blobs:
|
||||
do_list_blobs(args, gemm_config)
|
||||
@@ -674,7 +806,8 @@ def main(args):
|
||||
do_gen_blobs(args, gemm_config)
|
||||
else:
|
||||
logging.warning(
|
||||
"No mode specified (use --list_blobs or --gen_blobs). Generating by default...")
|
||||
"No mode specified (use --list_blobs or --gen_blobs). Generating by default..."
|
||||
)
|
||||
do_gen_blobs(args, gemm_config)
|
||||
|
||||
|
||||
@@ -684,16 +817,29 @@ if __name__ == "__main__":
|
||||
description="gen API for CK gemm kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w", "--working_path", default="./", required=False, help="The path where all the blobs are going to be generated"
|
||||
"-w",
|
||||
"--working_path",
|
||||
default="./",
|
||||
required=False,
|
||||
help="The path where all the blobs are going to be generated",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j", "--config_json", required=False, help="Path to the json which contains the configurations that user provide"
|
||||
"-j",
|
||||
"--config_json",
|
||||
required=False,
|
||||
help="Path to the json which contains the configurations that user provide",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list_blobs", action='store_true', help="List all kernel instances to file"
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
action="store_true",
|
||||
help="List all kernel instances to file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--gen_blobs", action='store_true', help="Generate all kernel instances into different files"
|
||||
"-g",
|
||||
"--gen_blobs",
|
||||
action="store_true",
|
||||
help="Generate all kernel instances into different files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -23,6 +23,7 @@ class GemmProfiler
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
|
||||
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
@@ -89,17 +90,20 @@ class GemmProfiler
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs<> gemm_args;
|
||||
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.k_batch = gemm_problem.split_k_;
|
||||
gemm_args.M = gemm_problem.m_;
|
||||
gemm_args.N = gemm_problem.n_;
|
||||
gemm_args.K = gemm_problem.k_;
|
||||
gemm_args.stride_A = gemm_problem.stride_a_;
|
||||
gemm_args.stride_B = gemm_problem.stride_b_;
|
||||
gemm_args.stride_C = gemm_problem.stride_c_;
|
||||
ck_tile::GemmHostArgs<> gemm_args = {
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
{}, // ds_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
gemm_problem.split_k_,
|
||||
gemm_problem.m_,
|
||||
gemm_problem.n_,
|
||||
gemm_problem.k_,
|
||||
gemm_problem.stride_a_,
|
||||
gemm_problem.stride_b_,
|
||||
{}, // stride_Ds
|
||||
gemm_problem.stride_c_,
|
||||
};
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_result(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
|
||||
|
||||
@@ -16,12 +16,14 @@ import json
|
||||
@dataclass
|
||||
class EnumConfigParam:
|
||||
"""Represents an enumeration-type configuration parameter"""
|
||||
|
||||
values: List[Union[int, str, bool]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeConfigParam:
|
||||
"""Represents a numeric range-type configuration parameter"""
|
||||
|
||||
min: int
|
||||
max: int
|
||||
step: int
|
||||
@@ -31,17 +33,13 @@ class RangeConfigParam:
|
||||
"""Generates valid candidates after applying range constraints"""
|
||||
|
||||
if self.min > self.max:
|
||||
raise ValueError(
|
||||
f"Invalid range: min({self.min}) > max({self.max})"
|
||||
)
|
||||
raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
|
||||
if self.step <= 0:
|
||||
raise ValueError(
|
||||
f"Step must be positive, got {self.step}"
|
||||
)
|
||||
raise ValueError(f"Step must be positive, got {self.step}")
|
||||
|
||||
candidates = list(range(self.min, self.max + 1, self.step))
|
||||
|
||||
if hasattr(self, 'exclude') and self.exclude:
|
||||
if hasattr(self, "exclude") and self.exclude:
|
||||
if not isinstance(self.exclude, list):
|
||||
raise TypeError("exclude must be list type")
|
||||
exclude_set = set(self.exclude)
|
||||
@@ -59,6 +57,7 @@ class RangeConfigParam:
|
||||
@dataclass
|
||||
class ProblemConfig:
|
||||
"""configuration class for problem parameter."""
|
||||
|
||||
datatypes: Tuple[EnumConfigParam, ...]
|
||||
layouts: Tuple[EnumConfigParam, ...]
|
||||
|
||||
@@ -66,24 +65,25 @@ class ProblemConfig:
|
||||
def datatype_map(self) -> Dict[str, str]:
|
||||
"""Get datatype as a key-value map."""
|
||||
return {
|
||||
'matrix_a': self.datatypes[0].values[0],
|
||||
'matrix_b': self.datatypes[1].values[0],
|
||||
'matrix_c': self.datatypes[2].values[0]
|
||||
"matrix_a": self.datatypes[0].values[0],
|
||||
"matrix_b": self.datatypes[1].values[0],
|
||||
"matrix_c": self.datatypes[2].values[0],
|
||||
}
|
||||
|
||||
@property
|
||||
def layout_map(self) -> Dict[str, str]:
|
||||
"""Get layout as a key-value map."""
|
||||
return {
|
||||
'matrix_a': self.layouts[0].values[0],
|
||||
'matrix_b': self.layouts[1].values[0],
|
||||
'matrix_c': self.layouts[2].values[0]
|
||||
"matrix_a": self.layouts[0].values[0],
|
||||
"matrix_b": self.layouts[1].values[0],
|
||||
"matrix_c": self.layouts[2].values[0],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Configuration class for tile parameter."""
|
||||
|
||||
tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
@@ -100,6 +100,7 @@ class TileConfig:
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Configuration class for kernel traits."""
|
||||
|
||||
pipeline: EnumConfigParam
|
||||
scheduler: EnumConfigParam
|
||||
epilogue: EnumConfigParam
|
||||
@@ -110,7 +111,8 @@ class TraitConfig:
|
||||
|
||||
@dataclass
|
||||
class GemmConfig:
|
||||
"""Main configuration class for GEMM operations """
|
||||
"""Main configuration class for GEMM operations"""
|
||||
|
||||
problem: ProblemConfig
|
||||
tile_config: TileConfig
|
||||
trait_config: TraitConfig
|
||||
@@ -124,76 +126,83 @@ class GemmConfig:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file {filepath} not found")
|
||||
|
||||
with config_path.open('r') as f:
|
||||
with config_path.open("r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Parse problem config
|
||||
problem = ProblemConfig(
|
||||
datatypes=(
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_a']['values']),
|
||||
values=config_dict["problem"]["datatype_a"]["values"]
|
||||
),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_b']['values']),
|
||||
values=config_dict["problem"]["datatype_b"]["values"]
|
||||
),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_c']['values'])
|
||||
values=config_dict["problem"]["datatype_c"]["values"]
|
||||
),
|
||||
),
|
||||
layouts=(
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_a']['values']),
|
||||
values=config_dict["problem"]["layout_a"]["values"]
|
||||
),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_b']['values']),
|
||||
values=config_dict["problem"]["layout_b"]["values"]
|
||||
),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_c']['values'])
|
||||
)
|
||||
values=config_dict["problem"]["layout_c"]["values"]
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Parse tile config
|
||||
def create_param(param_dict):
|
||||
if 'values' in param_dict:
|
||||
return EnumConfigParam(values=param_dict['values'])
|
||||
if "values" in param_dict:
|
||||
return EnumConfigParam(values=param_dict["values"])
|
||||
else:
|
||||
return RangeConfigParam(
|
||||
min=param_dict['min'],
|
||||
max=param_dict['max'],
|
||||
step=param_dict['step'],
|
||||
exclude=param_dict.get('exclude', [])
|
||||
min=param_dict["min"],
|
||||
max=param_dict["max"],
|
||||
step=param_dict["step"],
|
||||
exclude=param_dict.get("exclude", []),
|
||||
)
|
||||
|
||||
tile_config = TileConfig(
|
||||
tile_m=create_param(config_dict['tile_config']['tile_m']),
|
||||
tile_n=create_param(config_dict['tile_config']['tile_n']),
|
||||
tile_k=create_param(config_dict['tile_config']['tile_k']),
|
||||
warp_m=create_param(config_dict['tile_config']['warp_m']),
|
||||
warp_n=create_param(config_dict['tile_config']['warp_n']),
|
||||
warp_k=create_param(config_dict['tile_config']['warp_k']),
|
||||
warp_tile_m=create_param(
|
||||
config_dict['tile_config']['warp_tile_m']),
|
||||
warp_tile_n=create_param(
|
||||
config_dict['tile_config']['warp_tile_n']),
|
||||
warp_tile_k=create_param(
|
||||
config_dict['tile_config']['warp_tile_k'])
|
||||
tile_m=create_param(config_dict["tile_config"]["tile_m"]),
|
||||
tile_n=create_param(config_dict["tile_config"]["tile_n"]),
|
||||
tile_k=create_param(config_dict["tile_config"]["tile_k"]),
|
||||
warp_m=create_param(config_dict["tile_config"]["warp_m"]),
|
||||
warp_n=create_param(config_dict["tile_config"]["warp_n"]),
|
||||
warp_k=create_param(config_dict["tile_config"]["warp_k"]),
|
||||
warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
|
||||
warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
|
||||
warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
|
||||
)
|
||||
|
||||
# Parse trait config
|
||||
trait_config = TraitConfig(
|
||||
pipeline=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pipeline']['values']),
|
||||
values=config_dict["trait_config"]["pipeline"]["values"]
|
||||
),
|
||||
scheduler=EnumConfigParam(
|
||||
values=config_dict['trait_config']['scheduler']['values']),
|
||||
values=config_dict["trait_config"]["scheduler"]["values"]
|
||||
),
|
||||
epilogue=EnumConfigParam(
|
||||
values=config_dict['trait_config']['epilogue']['values']),
|
||||
values=config_dict["trait_config"]["epilogue"]["values"]
|
||||
),
|
||||
pad_m=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_m']['values']),
|
||||
values=config_dict["trait_config"]["pad_m"]["values"]
|
||||
),
|
||||
pad_n=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_n']['values']),
|
||||
values=config_dict["trait_config"]["pad_n"]["values"]
|
||||
),
|
||||
pad_k=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_k']['values'])
|
||||
values=config_dict["trait_config"]["pad_k"]["values"]
|
||||
),
|
||||
)
|
||||
|
||||
return cls(
|
||||
problem=problem,
|
||||
tile_config=tile_config,
|
||||
trait_config=trait_config
|
||||
problem=problem, tile_config=tile_config, trait_config=trait_config
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
|
||||
Reference in New Issue
Block a user