Enabling diff datatypes for tile_engine and build with more granularity (#2392)

* merging recent changes to universal gemm to tile_engine

* Reducing Linking time by generating less intermediate files

* make small libs to build faster

* Reducing the instances

* reducing instances

* Restoring default config

* Restoring default config

* warp_n reverted in default config

* Adding diff json files for fp8 and fp16, cmake changes for fp8

* Restructure the CMake File

* Added more granularity for build and some debugging code

* removed some of debugging statements

* added fp8 instances

* tahe datatype from command line to enable both type of json files

* updated README file

* code cleanup

* code cleanup

* updated jenkinsfile

* enable tile_engine daily builds

* updating cmake file

* updated CMakeLists.txt

* Updating CMake code fixing gfx12 build

* Updating CMake code fixing gfx12 build

* Fix CMake file null checks

* fixed traces of rebase

* Update tile_engine/ops/gemm/README.md

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update tile_engine/ops/gemm/README.md

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update tile_engine/ops/gemm/README.md

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* fixing rebase issue

---------

Co-authored-by: khushbu <khuagarw@gmail.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-06-25 15:18:24 -07:00
committed by GitHub
parent e03293ebce
commit a14753b86f
10 changed files with 458 additions and 292 deletions

16
Jenkinsfile vendored
View File

@@ -800,7 +800,7 @@ def process_results(Map conf=[:]){
}
//launch develop branch daily jobs
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=false;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
@@ -1216,9 +1216,12 @@ pipeline {
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx90a" \
-D GEMM_DATATYPE="fp8;fp16" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
ninja -j64 benchmark_gemm && \
./bin/benchmark_gemm """
ninja -j64 benchmark_gemm_fp8 && \
./bin/benchmark_gemm_fp8 && \
ninja -j64 benchmark_gemm_fp16 && \
./bin/benchmark_gemm_fp16 """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
@@ -1238,9 +1241,12 @@ pipeline {
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx942" \
-D GEMM_DATATYPE="fp8;fp16" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
ninja -j128 benchmark_gemm && \
./bin/benchmark_gemm """
ninja -j128 benchmark_gemm_fp8 && \
./bin/benchmark_gemm_fp8 && \
ninja -j128 benchmark_gemm_fp16 && \
./bin/benchmark_gemm_fp16 """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)

View File

@@ -1,67 +1,134 @@
# generate a list of kernels, but not actually emit files at config stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--list_blobs
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "Fail to list kernels via Python. ${ret}")
endif()
set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
set(GEMM_CODEGEN_CPP_FILES "")
set(GEMM_CODEGEN_HPP_FILES "")
foreach(blob ${GEMM_CODEGEN_BLOBS})
string(STRIP "${blob}" stripped_blob)
if(stripped_blob MATCHES "\\.cpp$")
list(APPEND GEMM_CODEGEN_CPP_FILES "${stripped_blob}")
elseif(stripped_blob MATCHES "\\.hpp$")
list(APPEND GEMM_CODEGEN_HPP_FILES "${stripped_blob}")
function(build_gemm_for_datatype datatype)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/")
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
#set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json")
# Generate kernel list
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--config_json ${json_blob}
--list_blobs
RESULT_VARIABLE ret
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype}: ${ret}")
endif()
file(STRINGS "${working_path}/gemm_instance_blobs.txt" codegen_blobs)
file(STRINGS "${working_path}/gemm_instance_blobs_range.txt" codegen_blobs_range)
# Generate the blobs
add_custom_command(
OUTPUT ${codegen_blobs}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path "${working_path}"
--datatype ${datatype}
--config_json "${json_blob}"
--gen_blobs
COMMENT "Generating GEMM instance sources for ${datatype}"
)
add_custom_target(gemm_gen_${datatype} DEPENDS ${codegen_blobs})
set(intermediate_libs)
list(LENGTH codegen_blobs codegen_blobs_len)
foreach(blob IN LISTS codegen_blobs_range)
string(STRIP "${blob}" stripped_blob)
separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}")
# Each line is: <trait_name> <first_index_inclusive> <last_index_exclusive>
list(GET spilit_blob 0 name)
list(GET spilit_blob 1 first)
list(GET spilit_blob 2 last)
math(EXPR total_files "${last} - ${first}")
if(total_files EQUAL 0)
continue() # nothing for this trait
endif()
# Object libraries (chunked) per trait
set(sub_intermediate_libs)
set(chunk_size 3)
math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}")
math(EXPR num_chunks_minus_1 "${num_chunks} - 1")
foreach(i RANGE 0 ${num_chunks_minus_1})
math(EXPR start "${first} + ${i} * ${chunk_size} ")
math(EXPR end "${start} + ${chunk_size} - 1")
set(chunk_files)
foreach(j RANGE ${start} ${end})
if(j LESS ${last} AND j LESS ${codegen_blobs_len})
list(GET codegen_blobs ${j} f)
list(APPEND chunk_files "${f}")
endif()
endforeach()
#list(LENGTH chunk_files chunk_files_len)
#if(chunk_files_len AND chunk_files_len GREATER 1)
if(chunk_files)
set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}")
add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files})
list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name})
endif()
endforeach()
# ------------------ Bundle the object libs into one static lib ---------
#list(LENGTH sub_intermediate_libs sub_intermediate_libs_len)
#if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1)
if(sub_intermediate_libs)
set(intermediate_lib_name "gemm_staticlib_${name}_${datatype}")
# Collect the $<TARGET_OBJECTS:...> expressions
set(obj_exprs)
foreach(objlib IN LISTS sub_intermediate_libs)
list(APPEND obj_exprs $<TARGET_OBJECTS:${objlib}>)
endforeach()
add_library(${intermediate_lib_name} STATIC ${obj_exprs})
add_dependencies(${intermediate_lib_name} gemm_gen_${datatype})
#foreach(objlib IN LISTS sub_intermediate_libs)
# target_sources(${intermediate_lib_name} PRIVATE $<TARGET_OBJECTS:${objlib}>)
#endforeach()
list(APPEND intermediate_libs ${intermediate_lib_name})
endif()
endforeach()
# Interface library for instances
add_library(gemm_template_instances_${datatype} INTERFACE)
add_dependencies(gemm_template_instances_${datatype} gemm_gen_${datatype})
target_link_libraries(gemm_template_instances_${datatype} INTERFACE ${intermediate_libs})
target_include_directories(gemm_template_instances_${datatype} INTERFACE
${CMAKE_CURRENT_LIST_DIR}
"${working_path}"
)
set_target_properties(gemm_template_instances_${datatype} PROPERTIES LINKER_LANGUAGE CXX)
# Host API interface library
add_library(gemm_host_api_${datatype} INTERFACE)
target_link_libraries(gemm_host_api_${datatype} INTERFACE gemm_template_instances_${datatype})
target_include_directories(gemm_host_api_${datatype} INTERFACE
${CMAKE_CURRENT_LIST_DIR}
"${working_path}"
)
# Executable per datatype
set(exec_name "benchmark_gemm_${datatype}")
add_executable(${exec_name} benchmark_gemm.cpp)
target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype})
target_compile_options(${exec_name} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
)
endfunction()
# Process each datatype in isolation
foreach(dt IN LISTS GEMM_DATATYPE)
build_gemm_for_datatype(${dt})
endforeach()
add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--gen_blobs
)
add_library(gemm_template_instances OBJECT EXCLUDE_FROM_ALL ${GEMM_CODEGEN_CPP_FILES})
# Explicitly set LINKER_LANGUAGE to avoid build config failures with Ninja.
set_target_properties(gemm_template_instances PROPERTIES LINKER_LANGUAGE CXX)
target_include_directories(gemm_template_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(gemm_template_instances PRIVATE ${GEMM_CODEGEN_HPP_FILES})
set(BENCHMARK_GEMM_EXECUTABLE "benchmark_gemm")
message(DEBUG "adding example ${BENCHMARK_GEMM_EXECUTABLE}")
include_directories(${CMAKE_CURRENT_BINARY_DIR})
add_library(gemm_host_api INTERFACE EXCLUDE_FROM_ALL)
target_include_directories(gemm_host_api INTERFACE ${CMAKE_CURRENT_LIST_DIR})
target_sources(gemm_host_api INTERFACE ${GEMM_CODEGEN_HPP_FILES} gemm_host_api.hpp)
target_link_libraries(gemm_host_api INTERFACE gemm_template_instances)
add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp)
target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp)
target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api)
set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress)
target_compile_options(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS})
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -15,16 +15,27 @@ If user does not provide kernel configuration, the tile engine uses default kern
# in the root of composable kernel create build directory
mkdir build && cd build
# build composable kernel
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with the appropriate architecture (example gfx942) or leave blank
# generate the executable
make benchmark_gemm -j
# replace [Arch] with the appropriate architecture or leave blank and
# replace [Datatype1;Datatype2;...] in comma separated datatypes string (possible datatypes are [fp8, bf8, int8, fp16, bf16])
sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]"
# generate different executable for each passed datatype
make benchmark_gemm_[Datatype1] -j
make benchmark_gemm_[Datatype2] -j
```
`benchmark_gemm` will be located in the `./bin/` directory.
`benchmark_gemm_[Datatypes]` will be located in the `./bin/` directory.
`benchmark_gemm` must be rebuilt everytime if configuration file is modified.
`benchmark_gemm_[Datatypes]` must be rebuilt everytime if configuration file is modified.
``` bash
rm -rf tile_engine/ && make benchmark_gemm -j # rebuild
rm -rf tile_engine/ && make benchmark_gemm_[Datatypes] -j # rebuild
```
## For eaxmple build for gfx942 for fp8 and fp16 datatypes
``` bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16"
make benchmark_gemm_fp8 -j
make benchmark_gemm_fp16 -j
```
## benchmark_gemm inputs

View File

@@ -199,7 +199,7 @@ warp_tile_supported_combinations = {
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
},
"gfx942": {
"fp16_fp16_fp16": [
@@ -219,7 +219,7 @@ warp_tile_supported_combinations = {
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
},
"gfx950": {
@@ -247,7 +247,7 @@ warp_tile_supported_combinations = {
[16, 16, 128],
[32, 32, 64],
],
"fp8_fp8_fp16": [
"bf8_bf8_fp16": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 64],

View File

@@ -0,0 +1,116 @@
{
"problem": {
"layout_a": {
"values": [
"r"
]
},
"layout_b": {
"values": [
"c"
]
},
"layout_c": {
"values": [
"r"
]
}
},
"tile_config": {
"tile_m": {
"max": 256,
"min": 64,
"step": 64,
"exclude": [192]
},
"tile_n": {
"max": 256,
"min": 64,
"step": 64,
"exclude": [192]
},
"tile_k": {
"max": 256,
"min": 64,
"step": 64,
"exclude": [192]
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
4,
16,
32
]
},
"warp_tile_n": {
"values": [
16,
32,
64
]
},
"warp_tile_k": {
"values": [
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -1,136 +1,115 @@
{
"problem": {
"layout_a": {
"values": [
"r"
]
},
"layout_b": {
"values": [
"c"
]
},
"layout_c": {
"values": [
"r"
]
},
"datatype_a": {
"values": [
"fp16"
]
},
"datatype_b": {
"values": [
"fp16"
]
},
"datatype_c": {
"values": [
"fp16"
]
}
"problem": {
"layout_a": {
"values": [
"r"
]
},
"tile_config": {
"tile_m": {
"max": 256,
"min": 64,
"step": 64,
"exclude": []
},
"tile_n": {
"max": 256,
"min": 64,
"step": 32,
"exclude": []
},
"tile_k": {
"max": 256,
"min": 64,
"step": 64,
"exclude": [192]
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
4,
8,
16,
32,
64
]
},
"warp_tile_n": {
"values": [
4,
8,
16,
32,
64
]
},
"warp_tile_k": {
"values": [
8,
16,
32,
64,
128
]
}
"layout_b": {
"values": [
"c"
]
},
"trait_config": {
"pipeline": {
"values": [
"compv4",
"compv3",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
"layout_c": {
"values": [
"r"
]
}
},
"tile_config": {
"tile_m": {
"values": [
256
]
},
"tile_n": {
"values": [
128,
256
]
},
"tile_k": {
"values": [
32
]
},
"warp_m": {
"values": [
1,
2,
4
]
},
"warp_n": {
"values": [
1,
2,
4
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
4,
16,
32
]
},
"warp_tile_n": {
"values": [
16,
32,
64
]
},
"warp_tile_k": {
"values": [
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle",
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -14,27 +14,13 @@
"values": [
"r"
]
},
"datatype_a": {
"values": [
"int8"
]
},
"datatype_b": {
"values": [
"int8"
]
},
"datatype_c": {
"values": [
"int32"
]
}
},
"tile_config": {
"tile_m": {
"values": [
128
128,
256
]
},
"tile_n": {
@@ -49,12 +35,12 @@
},
"warp_m": {
"values": [
2
4
]
},
"warp_n": {
"values": [
2
1
]
},
"warp_k": {

View File

@@ -62,7 +62,7 @@ class GemmCodeGenerator:
file_path = w_p / "gemm_instance_blobs.txt"
self._generate_all_traits()
self._get_valid_trait_tile_combinations()
file_range_map = {}
# Write all file paths to the header file
files_listed = 0
with file_path.open("w") as f:
@@ -81,9 +81,10 @@ class GemmCodeGenerator:
trait_file = f"gemm_{trait}.hpp"
f.write(str(w_p / trait_file) + "\n")
files_listed += 1
file_name = set()
# Instance source files
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
start_idx = files_listed
for tile in tile_valid_params:
for (
tile_m,
@@ -92,38 +93,24 @@ class GemmCodeGenerator:
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
_,
_,
_,
) in tile:
instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
sparse = (
self.config.problem.datatype_map["matrix_a"] == "fp16"
and self.config.problem.datatype_map["matrix_b"] == "fp16"
and self.config.problem.datatype_map["matrix_c"] == "fp16"
and (
(
warp_tile_m == 32
and warp_tile_n == 32
and warp_tile_k == 16
)
or (
warp_tile_m == 16
and warp_tile_n == 16
and warp_tile_k == 32
)
)
)
if sparse:
sparse_file = f"gemm_{trait}_{instance_name}_true.cpp"
f.write(str(w_p / sparse_file) + "\n")
instance_name = f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
if instance_name not in file_name:
file_name.add(instance_name)
f.write(str(w_p / instance_name) + "\n")
files_listed += 1
regular_file = f"gemm_{trait}_{instance_name}_false.cpp"
f.write(str(w_p / regular_file) + "\n")
files_listed += 1
print(f"File listing complete: {files_listed} files listed in {file_path}\n")
file_range_map[trait] = (start_idx, files_listed)
file_path = w_p / 'gemm_instance_blobs_range.txt'
with file_path.open('w') as f:
for name, ranges in file_range_map.items():
s, l = ranges
f.write(name + " " + f"{s}" + " " + f"{l}"+ "\n")
def _generate_all_traits(self):
"""Generate all possible kernel traits names."""
@@ -246,7 +233,7 @@ struct GemmKernel {{
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static float launch(ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
@@ -360,7 +347,6 @@ struct GemmKernel {{
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
}};
ave_time = ck_tile::launch_kernel_preprocess(
stream,
@@ -577,8 +563,8 @@ struct GemmKernel {{
self.valid_trait_tile_combinations[trait].append(tile_valid_params)
def _generate_instantiation_source_files(self):
"""Generate kernel instance instantiation source files"""
"""Generate kernel instance instantiation source files """
tile_map = {}
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
for tile in tile_valid_params:
for (
@@ -592,17 +578,28 @@ struct GemmKernel {{
warp_tile_n,
warp_tile_k,
) in tile:
instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}"
value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
if key not in tile_map:
tile_map[key] = set()
tile_map[key].add(value)
files_listed = 0
for trait, _ in self.valid_trait_tile_combinations.items():
for block_tile, warp_tiles in tile_map.items():
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(int, block_tile.split('x'))
content = f"""
content = f"""
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_{trait}.hpp"
"""
for warp_tile in warp_tiles:
warp_tile_m, warp_tile_n, warp_tile_k = map(int, warp_tile.split("x"))
sparse = (
self.config.problem.datatype_map["matrix_a"] == "fp16"
and self.config.problem.datatype_map["matrix_b"] == "fp16"
@@ -621,23 +618,17 @@ struct GemmKernel {{
)
)
if sparse:
sparse_filename = f"gemm_{trait}_{instance_name}_true.cpp"
sparse_content = (
content
+ f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;
files_listed = files_listed + 1
content = content + f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;"""
files_listed = files_listed + 1
content = content + f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;"""
content += f"""
"""
)
(self.output_dir / sparse_filename).write_text(sparse_content)
no_sparse_filename = f"gemm_{trait}_{instance_name}_false.cpp"
no_sparse_content = (
content
+ f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;
"""
)
(self.output_dir / no_sparse_filename).write_text(no_sparse_content)
(self.output_dir /
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp").write_text(content)
print(f"Generated {files_listed} kernel instances in total.")
def _generate_dispatcher_file(self):
"""Generate the code block of dispatch mechanism."""
@@ -682,8 +673,7 @@ struct GemmDispatcher {
return kernel_map;
}
static void init(bool structured_sparsity) {
(void)structured_sparsity; // Suppress unused parameter warning
static void init([[maybe_unused]]bool structured_sparsity) {
auto& kernel_map = get_kernel_map();
if(!kernel_map.empty()) return;
\n"""
@@ -703,7 +693,7 @@ struct GemmDispatcher {
warp_tile_n,
warp_tile_k,
) = tile[j]
content += f"""[=](ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{ """
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
content += f"""
if(structured_sparsity){{ // SMFMA"""
sparse = (
@@ -795,7 +785,7 @@ def do_gen_blobs(
def main(args):
gemm_config = (
GemmConfig.from_json(args.config_json)
GemmConfig.from_json(args.config_json, args.datatype)
if args.config_json is not None
else args.config_json
)
@@ -829,6 +819,12 @@ if __name__ == "__main__":
required=False,
help="Path to the json which contains the configurations that user provide",
)
parser.add_argument(
"-d",
"--datatype",
required=True,
help="Specify what datatype to use for the kernel generation, e.g. fp16, bf16, int8, fp8, bf8"
)
parser.add_argument(
"-l",
"--list_blobs",

View File

@@ -23,7 +23,6 @@ class GemmProfiler
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};

View File

@@ -118,7 +118,7 @@ class GemmConfig:
trait_config: TraitConfig
@classmethod
def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig":
def from_json(cls: Type["GemmConfig"], filepath: str, datatype: str) -> "GemmConfig":
"""JSON configuration loader with validation controls"""
config_path = Path(filepath)
@@ -129,18 +129,24 @@ class GemmConfig:
with config_path.open("r") as f:
config_dict = json.load(f)
a_type = datatype
b_type = datatype
c_type = datatype
if b_type == 'int4':
a_type = "fp16"
if b_type in ['bf8', 'fp8', 'int4']:
c_type = "fp16"
# Parse problem config
#TODO: Not reading datatype information from json file.
problem = ProblemConfig(
datatypes=(
EnumConfigParam(
values=config_dict["problem"]["datatype_a"]["values"]
),
values=[a_type]),
EnumConfigParam(
values=config_dict["problem"]["datatype_b"]["values"]
),
values=[b_type]),
EnumConfigParam(
values=config_dict["problem"]["datatype_c"]["values"]
),
values=[c_type])
),
layouts=(
EnumConfigParam(