[CKTILE] Layout Support for CK Tile engine (#2482)

* Updating runtime log message for CK TILE ENGINE

* CKTile layout from config

* CKTile custom config for CI

* Documentation for Layout Changes

* CKTile Layout changes  to Jenkins

* Fixing Clang Format

* Changes to Jenkins file to fix error

* fix(cmake-ck-dev): no longer sets invalid values as gpu arch

* style(py files): ruff formatting

* fix(cmake-ck-release): no longer sets invalid values as gpu arch

* chore(cmake-tile_engine): add reminder to uncomment user config json

* Changes to jenkin file to address more cases

* Changes to Jenkins to fix Error

* Changes to Jenkins file for fixing an error

* Update Jenkinsfile (#2517)

* Update Jenkinsfile

---------

Co-authored-by: ThruptiRajLakshmanaGowda <tlakshma@amd.com>
Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-07-17 14:19:41 -05:00
committed by GitHub
parent c08986b026
commit 0f3083ab5c
11 changed files with 239 additions and 124 deletions

View File

@@ -1,21 +1,32 @@
set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
set(GEMM_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
function(build_gemm_for_datatype datatype)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/")
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
#set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json")
function(build_gemm_for_datatype datatype layout)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Comment this if-else block when using user_provided_config
if(layout STREQUAL "rcr")
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
else()
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json")
endif()
# uncomment this if you want to use user_provided_config.json
# set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json")
# Generate kernel list
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--list_blobs
RESULT_VARIABLE ret
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype}: ${ret}")
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}")
endif()
file(STRINGS "${working_path}/gemm_instance_blobs.txt" codegen_blobs)
@@ -27,11 +38,12 @@ function(build_gemm_for_datatype datatype)
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path "${working_path}"
--datatype ${datatype}
--layout ${layout}
--config_json "${json_blob}"
--gen_blobs
COMMENT "Generating GEMM instance sources for ${datatype}"
COMMENT "Generating GEMM instance sources for ${datatype} ${layout}"
)
add_custom_target(gemm_gen_${datatype} DEPENDS ${codegen_blobs})
add_custom_target(gemm_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})
set(intermediate_libs)
list(LENGTH codegen_blobs codegen_blobs_len)
@@ -69,7 +81,7 @@ function(build_gemm_for_datatype datatype)
#list(LENGTH chunk_files chunk_files_len)
#if(chunk_files_len AND chunk_files_len GREATER 1)
if(chunk_files)
set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}")
set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}_${layout}")
add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files})
list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name})
endif()
@@ -80,7 +92,7 @@ function(build_gemm_for_datatype datatype)
#list(LENGTH sub_intermediate_libs sub_intermediate_libs_len)
#if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1)
if(sub_intermediate_libs)
set(intermediate_lib_name "gemm_staticlib_${name}_${datatype}")
set(intermediate_lib_name "gemm_staticlib_${name}_${datatype}_${layout}")
# Collect the $<TARGET_OBJECTS:...> expressions
set(obj_exprs)
@@ -89,7 +101,7 @@ function(build_gemm_for_datatype datatype)
endforeach()
add_library(${intermediate_lib_name} STATIC ${obj_exprs})
add_dependencies(${intermediate_lib_name} gemm_gen_${datatype})
add_dependencies(${intermediate_lib_name} gemm_gen_${datatype}_${layout})
#foreach(objlib IN LISTS sub_intermediate_libs)
# target_sources(${intermediate_lib_name} PRIVATE $<TARGET_OBJECTS:${objlib}>)
#endforeach()
@@ -99,28 +111,28 @@ function(build_gemm_for_datatype datatype)
endforeach()
# Interface library for instances
add_library(gemm_template_instances_${datatype} INTERFACE)
add_dependencies(gemm_template_instances_${datatype} gemm_gen_${datatype})
target_link_libraries(gemm_template_instances_${datatype} INTERFACE ${intermediate_libs})
target_include_directories(gemm_template_instances_${datatype} INTERFACE
add_library(gemm_template_instances_${datatype}_${layout} INTERFACE)
add_dependencies(gemm_template_instances_${datatype}_${layout} gemm_gen_${datatype}_${layout})
target_link_libraries(gemm_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs})
target_include_directories(gemm_template_instances_${datatype}_${layout} INTERFACE
${CMAKE_CURRENT_LIST_DIR}
"${working_path}"
)
set_target_properties(gemm_template_instances_${datatype} PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(gemm_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX)
# Host API interface library
add_library(gemm_host_api_${datatype} INTERFACE)
target_link_libraries(gemm_host_api_${datatype} INTERFACE gemm_template_instances_${datatype})
target_include_directories(gemm_host_api_${datatype} INTERFACE
add_library(gemm_host_api_${datatype}_${layout} INTERFACE)
target_link_libraries(gemm_host_api_${datatype}_${layout} INTERFACE gemm_template_instances_${datatype}_${layout})
target_include_directories(gemm_host_api_${datatype}_${layout} INTERFACE
${CMAKE_CURRENT_LIST_DIR}
"${working_path}"
)
# Executable per datatype
set(exec_name "benchmark_gemm_${datatype}")
set(exec_name "benchmark_gemm_${datatype}_${layout}")
add_executable(${exec_name} benchmark_gemm.cpp)
target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype})
target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype}_${layout})
target_compile_options(${exec_name} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
@@ -130,5 +142,7 @@ endfunction()
# Process each datatype in isolation
foreach(dt IN LISTS GEMM_DATATYPE)
build_gemm_for_datatype(${dt})
foreach(l IN LISTS GEMM_LAYOUT)
build_gemm_for_datatype(${dt} ${l})
endforeach()
endforeach()

View File

@@ -7,6 +7,7 @@ CK Tile Engine GEMM is used to generate and run GEMM kernels with different comb
Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time.
For reference please see `./configs/user_provided_config.json`.
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json`
If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark.
@@ -18,25 +19,28 @@ mkdir build && cd build
# build composable kernel
# replace [Arch] with the appropriate architecture or leave blank and
# replace [Datatype1;Datatype2;...] in comma separated datatypes string (possible datatypes are [fp8, bf8, int8, fp16, bf16])
sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]"
# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr])
sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]"
# generate different executable for each passed datatype
make benchmark_gemm_[Datatype1] -j
make benchmark_gemm_[Datatype2] -j
make benchmark_gemm_[Datatype1]_[Layout1] -j
make benchmark_gemm_[Datatype1]_[Layout2] -j
make benchmark_gemm_[Datatype2]_[Layout1] -j
make benchmark_gemm_[Datatype2]_[Layout2] -j
```
`benchmark_gemm_[Datatypes]` will be located in the `./bin/` directory.
`benchmark_gemm_[Datatype]_[Layout]` will be located in the `./bin/` directory.
`benchmark_gemm_[Datatypes]` must be rebuilt everytime if configuration file is modified.
`benchmark_gemm_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified.
``` bash
rm -rf tile_engine/ && make benchmark_gemm_[Datatypes] -j # rebuild
rm -rf tile_engine/ && make benchmark_gemm_[Datatypes]_[Layout] -j # rebuild
```
## For eaxmple build for gfx942 for fp8 and fp16 datatypes
## For eaxmple build for gfx942 for fp8 and fp16 datatypes with rcr layout
``` bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16"
make benchmark_gemm_fp8 -j
make benchmark_gemm_fp16 -j
sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr"
make benchmark_gemm_fp8_rcr -j
make benchmark_gemm_fp16_rcr -j
```
## benchmark_gemm inputs
@@ -103,7 +107,7 @@ The following JSON file specifies parameters used to generate and build GEMM ker
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
``` bash
./bin/benchmark_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default
./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default
```
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.

View File

@@ -1,20 +1,5 @@
{
"problem": {
"layout_a": {
"values": [
"r"
]
},
"layout_b": {
"values": [
"c"
]
},
"layout_c": {
"values": [
"r"
]
}
},
"tile_config": {
"tile_m": {

View File

@@ -0,0 +1,82 @@
{
"problem": {
},
"tile_config": {
"tile_m": {
"values": [
128 ]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
128
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
32
]
},
"warp_tile_n": {
"values": [
32
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -1,20 +1,5 @@
{
"problem": {
"layout_a": {
"values": [
"r"
]
},
"layout_b": {
"values": [
"c"
]
},
"layout_c": {
"values": [
"r"
]
}
},
"tile_config": {
"tile_m": {

View File

@@ -1,20 +1,5 @@
{
"problem": {
"layout_a": {
"values": [
"r"
]
},
"layout_b": {
"values": [
"c"
]
},
"layout_c": {
"values": [
"r"
]
}
},
"tile_config": {
"tile_m": {

View File

@@ -98,19 +98,19 @@ class GemmCodeGenerator:
_,
) in tile:
instance_name = f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
if instance_name not in file_name:
file_name.add(instance_name)
f.write(str(w_p / instance_name) + "\n")
files_listed += 1
file_range_map[trait] = (start_idx, files_listed)
file_path = w_p / 'gemm_instance_blobs_range.txt'
with file_path.open('w') as f:
file_path = w_p / "gemm_instance_blobs_range.txt"
with file_path.open("w") as f:
for name, ranges in file_range_map.items():
s, l = ranges
f.write(name + " " + f"{s}" + " " + f"{l}"+ "\n")
f.write(name + " " + f"{s}" + " " + f"{l}" + "\n")
def _generate_all_traits(self):
"""Generate all possible kernel traits names."""
@@ -563,7 +563,7 @@ struct GemmKernel {{
self.valid_trait_tile_combinations[trait].append(tile_valid_params)
def _generate_instantiation_source_files(self):
"""Generate kernel instance instantiation source files """
"""Generate kernel instance instantiation source files"""
tile_map = {}
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
for tile in tile_valid_params:
@@ -583,11 +583,13 @@ struct GemmKernel {{
if key not in tile_map:
tile_map[key] = set()
tile_map[key].add(value)
files_listed = 0
for trait, _ in self.valid_trait_tile_combinations.items():
for block_tile, warp_tiles in tile_map.items():
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(int, block_tile.split('x'))
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(
int, block_tile.split("x")
)
content = f"""
// SPDX-License-Identifier: MIT
@@ -598,8 +600,10 @@ struct GemmKernel {{
"""
for warp_tile in warp_tiles:
warp_tile_m, warp_tile_n, warp_tile_k = map(int, warp_tile.split("x"))
warp_tile_m, warp_tile_n, warp_tile_k = map(
int, warp_tile.split("x")
)
sparse = (
self.config.problem.datatype_map["matrix_a"] == "fp16"
and self.config.problem.datatype_map["matrix_b"] == "fp16"
@@ -619,15 +623,23 @@ struct GemmKernel {{
)
if sparse:
files_listed = files_listed + 1
content = content + f"""
content = (
content
+ f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;"""
)
files_listed = files_listed + 1
content = content + f"""
content = (
content
+ f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;"""
)
content += f"""
"""
(self.output_dir /
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp").write_text(content)
(
self.output_dir
/ f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
).write_text(content)
print(f"Generated {files_listed} kernel instances in total.")
def _generate_dispatcher_file(self):
@@ -785,7 +797,7 @@ def do_gen_blobs(
def main(args):
gemm_config = (
GemmConfig.from_json(args.config_json, args.datatype)
GemmConfig.from_json(args.config_json, args.datatype, args.layout)
if args.config_json is not None
else args.config_json
)
@@ -823,7 +835,13 @@ if __name__ == "__main__":
"-d",
"--datatype",
required=True,
help="Specify what datatype to use for the kernel generation, e.g. fp16, bf16, int8, fp8, bf8"
help="Specify what datatype to use for the kernel generation, e.g. fp16, bf16, int8, fp8, bf8",
)
parser.add_argument(
"-ly",
"--layout",
required=True,
help="Specify what layout to use for the kernel generation, e.g. rcr, rrr",
)
parser.add_argument(
"-l",

View File

@@ -118,7 +118,9 @@ class GemmConfig:
trait_config: TraitConfig
@classmethod
def from_json(cls: Type["GemmConfig"], filepath: str, datatype: str) -> "GemmConfig":
def from_json(
cls: Type["GemmConfig"], filepath: str, datatype: str, layout: str
) -> "GemmConfig":
"""JSON configuration loader with validation controls"""
config_path = Path(filepath)
@@ -132,32 +134,40 @@ class GemmConfig:
a_type = datatype
b_type = datatype
c_type = datatype
if b_type == 'int4':
if b_type == "int4":
a_type = "fp16"
if b_type in ['bf8', 'fp8', 'int4']:
if b_type in ["bf8", "fp8", "int4"]:
c_type = "fp16"
layout_parts = layout.lower()
assert len(layout_parts) == 3, (
f"Invalid layout string: {layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
)
assert layout_parts[0] in ("r", "c"), (
f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)"
)
assert layout_parts[1] in ("r", "c"), (
f"Invalid matrix_a layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)"
)
assert layout_parts[2] == "r", (
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
)
a_layout = layout_parts[0]
b_layout = layout_parts[1]
c_layout = layout_parts[2]
# Parse problem config
#TODO: Not reading datatype information from json file.
# TODO: Not reading datatype information from json file.
problem = ProblemConfig(
datatypes=(
EnumConfigParam(
values=[a_type]),
EnumConfigParam(
values=[b_type]),
EnumConfigParam(
values=[c_type])
EnumConfigParam(values=[a_type]),
EnumConfigParam(values=[b_type]),
EnumConfigParam(values=[c_type]),
),
layouts=(
EnumConfigParam(
values=config_dict["problem"]["layout_a"]["values"]
),
EnumConfigParam(
values=config_dict["problem"]["layout_b"]["values"]
),
EnumConfigParam(
values=config_dict["problem"]["layout_c"]["values"]
),
EnumConfigParam(values=[a_layout]),
EnumConfigParam(values=[b_layout]),
EnumConfigParam(values=[c_layout]),
),
)