mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
add cmake option & modify
This commit is contained in:
@@ -1,13 +1,22 @@
|
||||
|
||||
option(USE_CUSTOM_CONFIG "Enable user-provided configuration file" ON)
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(USE_CUSTOM_CONFIG)
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
else()
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
endif()
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Fail to list kernels via Python. ${ret}")
|
||||
@@ -15,13 +24,22 @@ endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${GEMM_CODEGEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--gen_blobs
|
||||
)
|
||||
if(USE_CUSTOM_CONFIG)
|
||||
add_custom_command(
|
||||
OUTPUT ${GEMM_CODEGEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--gen_blobs
|
||||
)
|
||||
else()
|
||||
add_custom_command(
|
||||
OUTPUT ${GEMM_CODEGEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--gen_blobs
|
||||
)
|
||||
endif()
|
||||
|
||||
set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
|
||||
message("adding example ${EXECUTABLE_GEMM_INSTANCE}")
|
||||
|
||||
@@ -153,7 +153,8 @@ HOT_LOOP_TRUE = {'mem': RUN_MEM,
|
||||
def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
|
||||
|
||||
|
||||
warp_tile_combinations = {
|
||||
# To Do: add some more supported combinations
|
||||
warp_tile_supported_combinations = {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
# last 2 were not supported by MI300 architecture.
|
||||
@@ -161,13 +162,21 @@ warp_tile_combinations = {
|
||||
'bf8_bf8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
||||
}
|
||||
|
||||
# To Do: remove some unsupported combinations
|
||||
trait_unsupported_combinations = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave")
|
||||
}
|
||||
|
||||
def size_of(data_type):
|
||||
if data_type == 'fp16' or data_type == 'bf16':
|
||||
def element_size(data_type: str) -> float:
|
||||
data_type = data_type.lower()
|
||||
if data_type in {'fp16', 'bf16'}:
|
||||
return 2
|
||||
elif data_type == 'int8' or data_type == 'fp8' or data_type == 'bf8':
|
||||
elif data_type in {'int8', 'fp8', 'bf8'}:
|
||||
return 1
|
||||
elif data_type == 'int4': # TODO:: needs to confirm
|
||||
elif data_type == 'int4':
|
||||
return 0.5
|
||||
else:
|
||||
return 4
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
|
||||
@@ -34,53 +34,84 @@
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 512,
|
||||
"min": 256,
|
||||
"step": 64
|
||||
"min": 64,
|
||||
"step": 8
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
"max": 512,
|
||||
"min": 64,
|
||||
"step": 8
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
"max": 512,
|
||||
"min": 64,
|
||||
"step": 8
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
32
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
32
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv4",
|
||||
"compv3",
|
||||
"mem"
|
||||
]
|
||||
|
||||
@@ -26,8 +26,9 @@ from codegen_utils import (
|
||||
EPILOGUE_MAP,
|
||||
HOT_LOOP_TRUE,
|
||||
BOOL_MAP,
|
||||
warp_tile_combinations,
|
||||
size_of
|
||||
warp_tile_supported_combinations,
|
||||
trait_unsupported_combinations,
|
||||
element_size
|
||||
)
|
||||
import logging
|
||||
|
||||
@@ -73,14 +74,6 @@ class GemmCodeGenerator:
|
||||
"pad_n",
|
||||
"pad_k"]
|
||||
|
||||
# To remove some unsupported combinations
|
||||
unsupported_combinations = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave")
|
||||
}
|
||||
|
||||
# Generate all unique_combinations
|
||||
_unique = set(itertools.product(*[
|
||||
getattr(self.config.trait_config, param).values
|
||||
@@ -91,7 +84,7 @@ class GemmCodeGenerator:
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
|
||||
current_combination = (pipeline, epilogue, scheduler)
|
||||
|
||||
if current_combination not in unsupported_combinations:
|
||||
if current_combination not in trait_unsupported_combinations:
|
||||
trait_name = (
|
||||
f"{pipeline}_{epilogue}_{scheduler}_"
|
||||
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
|
||||
@@ -118,15 +111,15 @@ class GemmCodeGenerator:
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
// Data types
|
||||
using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_values[0]]};
|
||||
using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_values[1]]};
|
||||
using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_a']]};
|
||||
using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_b']]};
|
||||
using AccDataType = float;
|
||||
using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_values[2]]};
|
||||
using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_c']]};
|
||||
|
||||
// Layout configurations
|
||||
using ALayout = {LAYOUT_MAP[self.config.problem.layout_values[0]]};
|
||||
using BLayout = {LAYOUT_MAP[self.config.problem.layout_values[1]]};
|
||||
using CLayout = {LAYOUT_MAP[self.config.problem.layout_values[2]]};
|
||||
using ALayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_a']]};
|
||||
using BLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_b']]};
|
||||
using CLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_c']]};
|
||||
"""
|
||||
|
||||
(self.output_dir / "gemm_common.hpp").write_text(content)
|
||||
@@ -358,15 +351,15 @@ struct GemmKernel {{
|
||||
logging.warning(
|
||||
f"Dimension alignment failed [{trait}]: {', '.join(alignment_issues)}. "
|
||||
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
|
||||
f"[warp×tile] {warp_m}x{warp_n}x{warp_k} × {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
f"[warpxtile] {warp_m}x{warp_n}x{warp_k} x {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
)
|
||||
return False
|
||||
|
||||
# LDS capacity verification
|
||||
matrix_a_size = (tile_m * tile_k) * \
|
||||
size_of(self.config.problem.datatype_values[0])
|
||||
element_size(self.config.problem.datatype_map['matrix_a'])
|
||||
matrix_b_size = (tile_n * tile_k) * \
|
||||
size_of(self.config.problem.datatype_values[1])
|
||||
element_size(self.config.problem.datatype_map['matrix_b'])
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**16 if pipeline == "compv4" else 2**15
|
||||
@@ -374,15 +367,15 @@ struct GemmKernel {{
|
||||
logging.warning(
|
||||
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > "
|
||||
f"maximum allowed {max_tile_size:,}B ({max_tile_size/1024}KB). Breakdown:\n"
|
||||
f"- Matrix A ({self.config.problem.datatype_values[0]}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
|
||||
f"- Matrix B ({self.config.problem.datatype_values[1]}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
|
||||
f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
|
||||
f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
|
||||
)
|
||||
return False
|
||||
|
||||
# Warp combination validation
|
||||
warp_tile_key = f"{self.config.problem.datatype_values[0]}_{self.config.problem.datatype_values[1]}_{self.config.problem.datatype_values[2]}"
|
||||
warp_tile_key = f"{self.config.problem.datatype_map['matrix_a']}_{self.config.problem.datatype_map['matrix_b']}_{self.config.problem.datatype_map['matrix_c']}"
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
allowed_combinations = warp_tile_combinations.get(warp_tile_key, [])
|
||||
allowed_combinations = warp_tile_supported_combinations.get(warp_tile_key, [])
|
||||
|
||||
if current_combination not in allowed_combinations:
|
||||
logging.warning(
|
||||
@@ -455,7 +448,9 @@ struct GemmDispatcher {
|
||||
if(structured_sparsity){{ // SMFMA"""
|
||||
for tile in tile_params:
|
||||
if self.is_tile_valid(tile, trait):
|
||||
sparse = self.config.problem.datatype_values[0] == 'fp16' and \
|
||||
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
|
||||
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
|
||||
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
|
||||
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
|
||||
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
|
||||
content += f"""
|
||||
|
||||
@@ -228,12 +228,22 @@ class ProblemConfig:
|
||||
)
|
||||
|
||||
@property
|
||||
def datatype_values(self) -> list:
|
||||
return [p.values[0] for p in self.datatypes]
|
||||
def datatype_map(self) -> dict[str, str]:
|
||||
"""Get current layout selections as a key-value map."""
|
||||
return {
|
||||
'matrix_a': self.datatypes[0].values[0],
|
||||
'matrix_b': self.datatypes[1].values[0],
|
||||
'matrix_c': self.datatypes[2].values[0]
|
||||
}
|
||||
|
||||
@property
|
||||
def layout_values(self) -> list:
|
||||
return [p.values[0] for p in self.layouts]
|
||||
def layout_map(self) -> dict[str, str]:
|
||||
"""Get current layout selections as a key-value map."""
|
||||
return {
|
||||
'matrix_a': self.layouts[0].values[0],
|
||||
'matrix_b': self.layouts[1].values[0],
|
||||
'matrix_c': self.layouts[2].values[0]
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -257,33 +267,33 @@ class TileConfig:
|
||||
|
||||
warp_m: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[256]
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_n: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[256]
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_k: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[256]
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
|
||||
warp_tile_m: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[256]
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_tile_n: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[256]
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_tile_k: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[256]
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user