add cmake option & modify

This commit is contained in:
Yanxing-Shi
2025-05-14 09:17:37 +00:00
parent 58ab4eb617
commit 4bbe7eca09
5 changed files with 133 additions and 70 deletions

View File

@@ -1,13 +1,22 @@
# USE_CUSTOM_CONFIG selects how gemm_instance_builder.py is invoked in this file:
#   ON  - configs/user_provided_config.json is passed via --config_json
#   OFF - the script is run without --config_json
option(USE_CUSTOM_CONFIG "Enable user-provided configuration file" ON)
# generate a list of kernels, but not actually emit files at config stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--list_blobs
RESULT_VARIABLE ret
)
# Run the builder script with --list_blobs at configure time.
# The command line is assembled once so both configurations stay in sync;
# --config_json is the only argument that depends on USE_CUSTOM_CONFIG.
set(gemm_list_blobs_cmd
    ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
    --working_path ${CMAKE_CURRENT_BINARY_DIR})
if(USE_CUSTOM_CONFIG)
    list(APPEND gemm_list_blobs_cmd
        --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json)
endif()
execute_process(
    COMMAND ${gemm_list_blobs_cmd} --list_blobs
    # Exit status of the script; checked by the caller right after this call.
    RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "Fail to list kernels via Python. ${ret}")
@@ -15,13 +24,22 @@ endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--gen_blobs
)
# Build-time rule that produces the files named in GEMM_CODEGEN_BLOBS by
# running the builder script with --gen_blobs.
# --config_json is passed only when USE_CUSTOM_CONFIG is enabled.
set(gemm_gen_blobs_args --working_path ${CMAKE_CURRENT_BINARY_DIR})
# Inputs of the rule: editing any of them must re-run the generator,
# otherwise the generated blobs go stale.
set(gemm_gen_blobs_deps ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py)
if(USE_CUSTOM_CONFIG)
    list(APPEND gemm_gen_blobs_args
        --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json)
    list(APPEND gemm_gen_blobs_deps
        ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json)
endif()
# NOTE(review): the script presumably imports sibling Python modules as well;
# add them to DEPENDS once their paths are confirmed.
add_custom_command(
    OUTPUT ${GEMM_CODEGEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
            ${gemm_gen_blobs_args}
            --gen_blobs
    DEPENDS ${gemm_gen_blobs_deps}
    # Pass arguments to the tool unchanged on every platform/generator.
    VERBATIM
)
set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
message("adding example ${EXECUTABLE_GEMM_INSTANCE}")

View File

@@ -153,7 +153,8 @@ HOT_LOOP_TRUE = {'mem': RUN_MEM,
def BOOL_MAP(b_):
    """Return the lowercase literal 'true' or 'false' for the truthiness of ``b_``."""
    if b_:
        return 'true'
    return 'false'
warp_tile_combinations = {
# To Do: add some more supported combinations
warp_tile_supported_combinations = {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
# last 2 were not supported by MI300 architecture.
@@ -161,13 +162,21 @@ warp_tile_combinations = {
'bf8_bf8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
}
# Trait triples that must be excluded from generation.
# Each entry is a tuple of three strings; currently every excluded triple ends
# with "interwave" and pairs it with a compv3/compv4 first element and a
# cshuffle/default second element.
# TODO: extend this set as further unsupported combinations are identified.
trait_unsupported_combinations = {
    (first, second, "interwave")
    for first in ("compv3", "compv4")
    for second in ("cshuffle", "default")
}
def size_of(data_type):
if data_type == 'fp16' or data_type == 'bf16':
def element_size(data_type: str) -> float:
data_type = data_type.lower()
if data_type in {'fp16', 'bf16'}:
return 2
elif data_type == 'int8' or data_type == 'fp8' or data_type == 'bf8':
elif data_type in {'int8', 'fp8', 'bf8'}:
return 1
elif data_type == 'int4': # TODO:: needs to confirm
elif data_type == 'int4':
return 0.5
else:
return 4
raise ValueError(f"Unsupported data type: {data_type}")

View File

@@ -34,53 +34,84 @@
"tile_config": {
"tile_m": {
"max": 512,
"min": 256,
"step": 64
"min": 64,
"step": 8
},
"tile_n": {
"values": [
256
]
"max": 512,
"min": 64,
"step": 8
},
"tile_k": {
"values": [
32
]
"max": 512,
"min": 64,
"step": 8
},
"warp_m": {
"values": [
2
4,
8,
16,
32,
64,
128
]
},
"warp_n": {
"values": [
2
4,
8,
16,
32,
64,
128
]
},
"warp_k": {
"values": [
1
4,
8,
16,
32,
64,
128
]
},
"warp_tile_m": {
"values": [
32
4,
8,
16,
32,
64,
128
]
},
"warp_tile_n": {
"values": [
32
4,
8,
16,
32,
64,
128
]
},
"warp_tile_k": {
"values": [
8
4,
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv4",
"compv3",
"mem"
]

View File

@@ -26,8 +26,9 @@ from codegen_utils import (
EPILOGUE_MAP,
HOT_LOOP_TRUE,
BOOL_MAP,
warp_tile_combinations,
size_of
warp_tile_supported_combinations,
trait_unsupported_combinations,
element_size
)
import logging
@@ -73,14 +74,6 @@ class GemmCodeGenerator:
"pad_n",
"pad_k"]
# To remove some unsupported combinations
unsupported_combinations = {
("compv3", "cshuffle", "interwave"),
("compv3", "default", "interwave"),
("compv4", "cshuffle", "interwave"),
("compv4", "default", "interwave")
}
# Generate all unique_combinations
_unique = set(itertools.product(*[
getattr(self.config.trait_config, param).values
@@ -91,7 +84,7 @@ class GemmCodeGenerator:
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
current_combination = (pipeline, epilogue, scheduler)
if current_combination not in unsupported_combinations:
if current_combination not in trait_unsupported_combinations:
trait_name = (
f"{pipeline}_{epilogue}_{scheduler}_"
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
@@ -118,15 +111,15 @@ class GemmCodeGenerator:
#include "ck_tile/core.hpp"
// Data types
using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_values[0]]};
using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_values[1]]};
using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_a']]};
using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_b']]};
using AccDataType = float;
using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_values[2]]};
using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_c']]};
// Layout configurations
using ALayout = {LAYOUT_MAP[self.config.problem.layout_values[0]]};
using BLayout = {LAYOUT_MAP[self.config.problem.layout_values[1]]};
using CLayout = {LAYOUT_MAP[self.config.problem.layout_values[2]]};
using ALayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_a']]};
using BLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_b']]};
using CLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_c']]};
"""
(self.output_dir / "gemm_common.hpp").write_text(content)
@@ -358,15 +351,15 @@ struct GemmKernel {{
logging.warning(
f"Dimension alignment failed [{trait}]: {', '.join(alignment_issues)}. "
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
f"[warp×tile] {warp_m}x{warp_n}x{warp_k} × {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
f"[warpxtile] {warp_m}x{warp_n}x{warp_k} x {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
)
return False
# LDS capacity verification
matrix_a_size = (tile_m * tile_k) * \
size_of(self.config.problem.datatype_values[0])
element_size(self.config.problem.datatype_map['matrix_a'])
matrix_b_size = (tile_n * tile_k) * \
size_of(self.config.problem.datatype_values[1])
element_size(self.config.problem.datatype_map['matrix_b'])
total_tile_in_lds = matrix_a_size + matrix_b_size
max_tile_size = 2**16 if pipeline == "compv4" else 2**15
@@ -374,15 +367,15 @@ struct GemmKernel {{
logging.warning(
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > "
f"maximum allowed {max_tile_size:,}B ({max_tile_size/1024}KB). Breakdown:\n"
f"- Matrix A ({self.config.problem.datatype_values[0]}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
f"- Matrix B ({self.config.problem.datatype_values[1]}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
)
return False
# Warp combination validation
warp_tile_key = f"{self.config.problem.datatype_values[0]}_{self.config.problem.datatype_values[1]}_{self.config.problem.datatype_values[2]}"
warp_tile_key = f"{self.config.problem.datatype_map['matrix_a']}_{self.config.problem.datatype_map['matrix_b']}_{self.config.problem.datatype_map['matrix_c']}"
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
allowed_combinations = warp_tile_combinations.get(warp_tile_key, [])
allowed_combinations = warp_tile_supported_combinations.get(warp_tile_key, [])
if current_combination not in allowed_combinations:
logging.warning(
@@ -455,7 +448,9 @@ struct GemmDispatcher {
if(structured_sparsity){{ // SMFMA"""
for tile in tile_params:
if self.is_tile_valid(tile, trait):
sparse = self.config.problem.datatype_values[0] == 'fp16' and \
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
content += f"""

View File

@@ -228,12 +228,22 @@ class ProblemConfig:
)
@property
def datatype_values(self) -> list:
return [p.values[0] for p in self.datatypes]
def datatype_map(self) -> dict[str, str]:
"""Get current datatype selections as a key-value map."""
return {
'matrix_a': self.datatypes[0].values[0],
'matrix_b': self.datatypes[1].values[0],
'matrix_c': self.datatypes[2].values[0]
}
@property
def layout_values(self) -> list:
return [p.values[0] for p in self.layouts]
def layout_map(self) -> dict[str, str]:
"""Get current layout selections as a key-value map."""
return {
'matrix_a': self.layouts[0].values[0],
'matrix_b': self.layouts[1].values[0],
'matrix_c': self.layouts[2].values[0]
}
@dataclass
@@ -257,33 +267,33 @@ class TileConfig:
warp_m: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[256]
values=[8]
)
)
warp_n: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[256]
values=[8]
)
)
warp_k: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[256]
values=[8]
)
)
warp_tile_m: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[256]
values=[8]
)
)
warp_tile_n: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[256]
values=[8]
)
)
warp_tile_k: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[256]
values=[8]
)
)