mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Enabling diff datatypes for tile_engine and build with more granularity (#2392)
* merging recent changes to universal gemm to tile_engine * Reducing Linking time by generating less intermediate files * make small libs to build faster * Reducing the instances * reducing instances * Restoring default config * Restoring default config * warp_n reverted in default config * Adding diff json files for fp8 and fp16, cmake changes for fp8 * Restructure the CMake File * Added more granularity for build and some debugging code * removed some of debugging statements * added fp8 instances * tahe datatype from command line to enable both type of json files * updated README file * code cleanup * code cleanup * updated jenkinsfile * enable tile_engine daily builds * updating cmake file * updated CMakeLists.txt * Updating CMake code fixing gfx12 build * Updating CMake code fixing gfx12 build * Fix CMake file null checks * fixed traces of rebase * Update tile_engine/ops/gemm/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * fixing rebase issue --------- Co-authored-by: khushbu <khuagarw@gmail.com> Co-authored-by: ThomasNing <thomas.ning@amd.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com> Co-authored-by: AviralGoelAMD <aviral.goel@amd.com> Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
16
Jenkinsfile
vendored
16
Jenkinsfile
vendored
@@ -800,7 +800,7 @@ def process_results(Map conf=[:]){
|
||||
}
|
||||
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=false;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
@@ -1216,9 +1216,12 @@ pipeline {
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx90a" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j64 benchmark_gemm && \
|
||||
./bin/benchmark_gemm """
|
||||
ninja -j64 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j64 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1238,9 +1241,12 @@ pipeline {
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j128 benchmark_gemm && \
|
||||
./bin/benchmark_gemm """
|
||||
ninja -j128 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j128 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
|
||||
@@ -1,67 +1,134 @@
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Fail to list kernels via Python. ${ret}")
|
||||
endif()
|
||||
set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
|
||||
|
||||
set(GEMM_CODEGEN_CPP_FILES "")
|
||||
set(GEMM_CODEGEN_HPP_FILES "")
|
||||
|
||||
foreach(blob ${GEMM_CODEGEN_BLOBS})
|
||||
string(STRIP "${blob}" stripped_blob)
|
||||
|
||||
if(stripped_blob MATCHES "\\.cpp$")
|
||||
list(APPEND GEMM_CODEGEN_CPP_FILES "${stripped_blob}")
|
||||
elseif(stripped_blob MATCHES "\\.hpp$")
|
||||
list(APPEND GEMM_CODEGEN_HPP_FILES "${stripped_blob}")
|
||||
# Configure code generation, instance libraries and the benchmark executable
# for a single GEMM datatype.
#
# Arguments:
#   datatype - datatype name forwarded to gemm_instance_builder.py through
#              --datatype; also used as the suffix of every target created here.
#
# Targets created:
#   gemm_gen_<datatype>                 custom target that drives source generation
#   gemm_objlib_<trait>_<i>_<datatype>  object libraries, at most chunk_size sources each
#   gemm_staticlib_<trait>_<datatype>   static library bundling one trait's object libraries
#   gemm_template_instances_<datatype>  interface library linking every static library
#   gemm_host_api_<datatype>            interface library consumed by the benchmark
#   benchmark_gemm_<datatype>           benchmark executable
function(build_gemm_for_datatype datatype)
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/")
    # To build from a custom configuration, point json_blob at
    # configs/user_provided_config.json instead.
    set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
    set(generator_script "${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")

    # The list of generated files is computed at configure time, so a change to
    # the configuration or to the generator must trigger a re-configure.
    set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
        "${json_blob}"
        "${generator_script}"
    )

    # Generate kernel list (file names only, no sources are emitted here)
    execute_process(
        COMMAND "${Python3_EXECUTABLE}" "${generator_script}"
                --working_path "${working_path}"
                --datatype "${datatype}"
                --config_json "${json_blob}"
                --list_blobs
        RESULT_VARIABLE ret
    )
    if(NOT ret EQUAL 0)
        message(FATAL_ERROR "Failed to list kernels for ${datatype}: ${ret}")
    endif()

    file(STRINGS "${working_path}/gemm_instance_blobs.txt" codegen_blobs)
    file(STRINGS "${working_path}/gemm_instance_blobs_range.txt" codegen_blobs_range)

    # Generate the blobs at build time. DEPENDS makes the rule re-run whenever
    # the generator script or the JSON configuration changes.
    add_custom_command(
        OUTPUT ${codegen_blobs}
        COMMAND "${Python3_EXECUTABLE}" "${generator_script}"
                --working_path "${working_path}"
                --datatype "${datatype}"
                --config_json "${json_blob}"
                --gen_blobs
        DEPENDS "${generator_script}" "${json_blob}"
        COMMENT "Generating GEMM instance sources for ${datatype}"
        VERBATIM
    )
    add_custom_target(gemm_gen_${datatype} DEPENDS ${codegen_blobs})

    set(intermediate_libs)
    list(LENGTH codegen_blobs codegen_blobs_len)

    # Number of generated sources compiled into one object library.
    set(chunk_size 3)

    foreach(range_line IN LISTS codegen_blobs_range)
        string(STRIP "${range_line}" range_line)
        if("${range_line}" STREQUAL "")
            continue()
        endif()

        # Each line is: <trait_name> <first_index_inclusive> <last_index_exclusive>
        separate_arguments(range_fields UNIX_COMMAND "${range_line}")
        list(LENGTH range_fields range_fields_len)
        if(range_fields_len LESS 3)
            message(FATAL_ERROR
                "Malformed line in ${working_path}/gemm_instance_blobs_range.txt: '${range_line}'")
        endif()
        list(GET range_fields 0 name)
        list(GET range_fields 1 first)
        list(GET range_fields 2 last)

        math(EXPR total_files "${last} - ${first}")
        if(total_files LESS 1)
            continue() # nothing for this trait
        endif()

        # Object libraries (chunked) per trait
        set(sub_intermediate_libs)
        math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}")
        math(EXPR num_chunks_minus_1 "${num_chunks} - 1")

        foreach(i RANGE 0 ${num_chunks_minus_1})
            math(EXPR start "${first} + ${i} * ${chunk_size}")
            math(EXPR end "${start} + ${chunk_size} - 1")

            # Collect the sources of this chunk, clipped to the trait's range
            # and to the number of listed blobs.
            set(chunk_files)
            foreach(j RANGE ${start} ${end})
                if(j LESS ${last} AND j LESS ${codegen_blobs_len})
                    list(GET codegen_blobs ${j} f)
                    list(APPEND chunk_files "${f}")
                endif()
            endforeach()

            if(chunk_files)
                set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}")
                add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files})
                # Order every consumer of the generated sources after the single
                # generation target so the generator rule is not run
                # concurrently by independent targets.
                add_dependencies(${sub_intermediate_lib_name} gemm_gen_${datatype})
                list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name})
            endif()
        endforeach()

        # ------------------ Bundle the object libs into one static lib ---------
        if(sub_intermediate_libs)
            set(intermediate_lib_name "gemm_staticlib_${name}_${datatype}")

            # Collect the $<TARGET_OBJECTS:...> expressions
            set(obj_exprs)
            foreach(objlib IN LISTS sub_intermediate_libs)
                list(APPEND obj_exprs $<TARGET_OBJECTS:${objlib}>)
            endforeach()

            add_library(${intermediate_lib_name} STATIC ${obj_exprs})
            add_dependencies(${intermediate_lib_name} gemm_gen_${datatype})
            list(APPEND intermediate_libs ${intermediate_lib_name})
        endif()
    endforeach()

    # Interface library for instances
    add_library(gemm_template_instances_${datatype} INTERFACE)
    add_dependencies(gemm_template_instances_${datatype} gemm_gen_${datatype})
    target_link_libraries(gemm_template_instances_${datatype} INTERFACE ${intermediate_libs})
    target_include_directories(gemm_template_instances_${datatype} INTERFACE
        ${CMAKE_CURRENT_LIST_DIR}
        "${working_path}"
    )
    set_target_properties(gemm_template_instances_${datatype} PROPERTIES LINKER_LANGUAGE CXX)

    # Host API interface library
    add_library(gemm_host_api_${datatype} INTERFACE)
    target_link_libraries(gemm_host_api_${datatype} INTERFACE gemm_template_instances_${datatype})
    target_include_directories(gemm_host_api_${datatype} INTERFACE
        ${CMAKE_CURRENT_LIST_DIR}
        "${working_path}"
    )

    # Executable per datatype
    set(exec_name "benchmark_gemm_${datatype}")
    add_executable(${exec_name} benchmark_gemm.cpp)
    target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype})
    target_compile_options(${exec_name} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
    )
endfunction()
|
||||
|
||||
# Process each datatype in isolation
|
||||
# Build one set of targets per requested datatype. A de-duplicated local copy
# is used so that a repeated entry (e.g. "fp16;fp16") does not attempt to
# create the same targets twice, and the user's cache value is left untouched.
set(gemm_datatypes ${GEMM_DATATYPE})
if(NOT "${gemm_datatypes}" STREQUAL "")
    list(REMOVE_DUPLICATES gemm_datatypes)
endif()
foreach(dt IN LISTS gemm_datatypes)
    build_gemm_for_datatype("${dt}")
endforeach()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${GEMM_CODEGEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--gen_blobs
|
||||
)
|
||||
|
||||
add_library(gemm_template_instances OBJECT EXCLUDE_FROM_ALL ${GEMM_CODEGEN_CPP_FILES})
|
||||
# Explicitly set LINKER_LANGUAGE to avoid build config failures with Ninja.
|
||||
set_target_properties(gemm_template_instances PROPERTIES LINKER_LANGUAGE CXX)
|
||||
target_include_directories(gemm_template_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(gemm_template_instances PRIVATE ${GEMM_CODEGEN_HPP_FILES})
|
||||
|
||||
set(BENCHMARK_GEMM_EXECUTABLE "benchmark_gemm")
|
||||
message(DEBUG "adding example ${BENCHMARK_GEMM_EXECUTABLE}")
|
||||
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
add_library(gemm_host_api INTERFACE EXCLUDE_FROM_ALL)
|
||||
target_include_directories(gemm_host_api INTERFACE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(gemm_host_api INTERFACE ${GEMM_CODEGEN_HPP_FILES} gemm_host_api.hpp)
|
||||
target_link_libraries(gemm_host_api INTERFACE gemm_template_instances)
|
||||
|
||||
add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp)
|
||||
target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp)
|
||||
target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api)
|
||||
|
||||
set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
|
||||
|
||||
list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress)
|
||||
|
||||
target_compile_options(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS})
|
||||
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
@@ -15,16 +15,27 @@ If user does not provide kernel configuration, the tile engine uses default kern
|
||||
# in the root of composable kernel create build directory
|
||||
mkdir build && cd build
|
||||
# build composable kernel
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with the appropriate architecture (example gfx942) or leave blank
|
||||
# generate the executable
|
||||
make benchmark_gemm -j
|
||||
# replace [Arch] with the appropriate architecture or leave blank and
|
||||
# replace [Datatype1;Datatype2;...] with a semicolon-separated list of datatypes (possible datatypes are fp8, bf8, int8, fp16, bf16)
|
||||
sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]"
|
||||
# generate a separate executable for each datatype passed
|
||||
make benchmark_gemm_[Datatype1] -j
|
||||
make benchmark_gemm_[Datatype2] -j
|
||||
```
|
||||
`benchmark_gemm` will be located in the `./bin/` directory.
|
||||
`benchmark_gemm_[Datatypes]` will be located in the `./bin/` directory.
|
||||
|
||||
`benchmark_gemm` must be rebuilt everytime if configuration file is modified.
|
||||
`benchmark_gemm_[Datatypes]` must be rebuilt every time the configuration file is modified.
|
||||
|
||||
``` bash
|
||||
rm -rf tile_engine/ && make benchmark_gemm -j # rebuild
|
||||
rm -rf tile_engine/ && make benchmark_gemm_[Datatypes] -j # rebuild
|
||||
```
|
||||
|
||||
## Example: build for gfx942 with the fp8 and fp16 datatypes
|
||||
``` bash
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16"
|
||||
make benchmark_gemm_fp8 -j
|
||||
make benchmark_gemm_fp16 -j
|
||||
```
|
||||
|
||||
## benchmark_gemm inputs
|
||||
|
||||
@@ -199,7 +199,7 @@ warp_tile_supported_combinations = {
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp16": [
|
||||
@@ -219,7 +219,7 @@ warp_tile_supported_combinations = {
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
@@ -247,7 +247,7 @@ warp_tile_supported_combinations = {
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
"bf8_bf8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
|
||||
116
tile_engine/ops/gemm/configs/benchmark.json
Normal file
116
tile_engine/ops/gemm/configs/benchmark.json
Normal file
@@ -0,0 +1,116 @@
|
||||
{
|
||||
"problem": {
|
||||
"layout_a": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"layout_b": {
|
||||
"values": [
|
||||
"c"
|
||||
]
|
||||
},
|
||||
"layout_c": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
}
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": [192]
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": [192]
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": [192]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
4,
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,136 +1,115 @@
|
||||
{
|
||||
"problem": {
|
||||
"layout_a": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"layout_b": {
|
||||
"values": [
|
||||
"c"
|
||||
]
|
||||
},
|
||||
"layout_c": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"datatype_a": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
},
|
||||
"datatype_b": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
},
|
||||
"datatype_c": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
}
|
||||
"problem": {
|
||||
"layout_a": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": []
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 32,
|
||||
"exclude": []
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": [192]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
"layout_b": {
|
||||
"values": [
|
||||
"c"
|
||||
]
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv4",
|
||||
"compv3",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default",
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
"layout_c": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
}
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128,
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
1,
|
||||
2,
|
||||
4
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
1,
|
||||
2,
|
||||
4
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
4,
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle",
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,27 +14,13 @@
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"datatype_a": {
|
||||
"values": [
|
||||
"int8"
|
||||
]
|
||||
},
|
||||
"datatype_b": {
|
||||
"values": [
|
||||
"int8"
|
||||
]
|
||||
},
|
||||
"datatype_c": {
|
||||
"values": [
|
||||
"int32"
|
||||
]
|
||||
}
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128
|
||||
128,
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
@@ -49,12 +35,12 @@
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
4
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
|
||||
@@ -62,7 +62,7 @@ class GemmCodeGenerator:
|
||||
file_path = w_p / "gemm_instance_blobs.txt"
|
||||
self._generate_all_traits()
|
||||
self._get_valid_trait_tile_combinations()
|
||||
|
||||
file_range_map = {}
|
||||
# Write all file paths to the header file
|
||||
files_listed = 0
|
||||
with file_path.open("w") as f:
|
||||
@@ -81,9 +81,10 @@ class GemmCodeGenerator:
|
||||
trait_file = f"gemm_{trait}.hpp"
|
||||
f.write(str(w_p / trait_file) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
file_name = set()
|
||||
# Instance source files
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
start_idx = files_listed
|
||||
for tile in tile_valid_params:
|
||||
for (
|
||||
tile_m,
|
||||
@@ -92,38 +93,24 @@ class GemmCodeGenerator:
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) in tile:
|
||||
instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
sparse = (
|
||||
self.config.problem.datatype_map["matrix_a"] == "fp16"
|
||||
and self.config.problem.datatype_map["matrix_b"] == "fp16"
|
||||
and self.config.problem.datatype_map["matrix_c"] == "fp16"
|
||||
and (
|
||||
(
|
||||
warp_tile_m == 32
|
||||
and warp_tile_n == 32
|
||||
and warp_tile_k == 16
|
||||
)
|
||||
or (
|
||||
warp_tile_m == 16
|
||||
and warp_tile_n == 16
|
||||
and warp_tile_k == 32
|
||||
)
|
||||
)
|
||||
)
|
||||
if sparse:
|
||||
sparse_file = f"gemm_{trait}_{instance_name}_true.cpp"
|
||||
f.write(str(w_p / sparse_file) + "\n")
|
||||
instance_name = f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
|
||||
|
||||
if instance_name not in file_name:
|
||||
file_name.add(instance_name)
|
||||
f.write(str(w_p / instance_name) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
regular_file = f"gemm_{trait}_{instance_name}_false.cpp"
|
||||
f.write(str(w_p / regular_file) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
print(f"File listing complete: {files_listed} files listed in {file_path}\n")
|
||||
file_range_map[trait] = (start_idx, files_listed)
|
||||
|
||||
file_path = w_p / 'gemm_instance_blobs_range.txt'
|
||||
with file_path.open('w') as f:
|
||||
for name, ranges in file_range_map.items():
|
||||
s, l = ranges
|
||||
f.write(name + " " + f"{s}" + " " + f"{l}"+ "\n")
|
||||
|
||||
def _generate_all_traits(self):
|
||||
"""Generate all possible kernel traits names."""
|
||||
@@ -246,7 +233,7 @@ struct GemmKernel {{
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
|
||||
static float launch(ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{
|
||||
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
|
||||
static constexpr bool permuteA = false;
|
||||
static constexpr bool permuteB = false;
|
||||
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
|
||||
@@ -360,7 +347,6 @@ struct GemmKernel {{
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
|
||||
}};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
stream,
|
||||
@@ -577,8 +563,8 @@ struct GemmKernel {{
|
||||
self.valid_trait_tile_combinations[trait].append(tile_valid_params)
|
||||
|
||||
def _generate_instantiation_source_files(self):
|
||||
"""Generate kernel instance instantiation source files"""
|
||||
|
||||
"""Generate kernel instance instantiation source files """
|
||||
tile_map = {}
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
for tile in tile_valid_params:
|
||||
for (
|
||||
@@ -592,17 +578,28 @@ struct GemmKernel {{
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) in tile:
|
||||
instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}"
|
||||
value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
if key not in tile_map:
|
||||
tile_map[key] = set()
|
||||
tile_map[key].add(value)
|
||||
|
||||
files_listed = 0
|
||||
for trait, _ in self.valid_trait_tile_combinations.items():
|
||||
for block_tile, warp_tiles in tile_map.items():
|
||||
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(int, block_tile.split('x'))
|
||||
|
||||
content = f"""
|
||||
content = f"""
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
|
||||
|
||||
#include "gemm_{trait}.hpp"
|
||||
|
||||
"""
|
||||
for warp_tile in warp_tiles:
|
||||
warp_tile_m, warp_tile_n, warp_tile_k = map(int, warp_tile.split("x"))
|
||||
|
||||
sparse = (
|
||||
self.config.problem.datatype_map["matrix_a"] == "fp16"
|
||||
and self.config.problem.datatype_map["matrix_b"] == "fp16"
|
||||
@@ -621,23 +618,17 @@ struct GemmKernel {{
|
||||
)
|
||||
)
|
||||
if sparse:
|
||||
sparse_filename = f"gemm_{trait}_{instance_name}_true.cpp"
|
||||
sparse_content = (
|
||||
content
|
||||
+ f"""
|
||||
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;
|
||||
files_listed = files_listed + 1
|
||||
content = content + f"""
|
||||
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;"""
|
||||
files_listed = files_listed + 1
|
||||
content = content + f"""
|
||||
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;"""
|
||||
content += f"""
|
||||
"""
|
||||
)
|
||||
(self.output_dir / sparse_filename).write_text(sparse_content)
|
||||
|
||||
no_sparse_filename = f"gemm_{trait}_{instance_name}_false.cpp"
|
||||
no_sparse_content = (
|
||||
content
|
||||
+ f"""
|
||||
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;
|
||||
"""
|
||||
)
|
||||
(self.output_dir / no_sparse_filename).write_text(no_sparse_content)
|
||||
(self.output_dir /
|
||||
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp").write_text(content)
|
||||
print(f"Generated {files_listed} kernel instances in total.")
|
||||
|
||||
def _generate_dispatcher_file(self):
|
||||
"""Generate the code block of dispatch mechanism."""
|
||||
@@ -682,8 +673,7 @@ struct GemmDispatcher {
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
static void init(bool structured_sparsity) {
|
||||
(void)structured_sparsity; // Suppress unused parameter warning
|
||||
static void init([[maybe_unused]]bool structured_sparsity) {
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
@@ -703,7 +693,7 @@ struct GemmDispatcher {
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) = tile[j]
|
||||
content += f"""[=](ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{ """
|
||||
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
|
||||
content += f"""
|
||||
if(structured_sparsity){{ // SMFMA"""
|
||||
sparse = (
|
||||
@@ -795,7 +785,7 @@ def do_gen_blobs(
|
||||
|
||||
def main(args):
|
||||
gemm_config = (
|
||||
GemmConfig.from_json(args.config_json)
|
||||
GemmConfig.from_json(args.config_json, args.datatype)
|
||||
if args.config_json is not None
|
||||
else args.config_json
|
||||
)
|
||||
@@ -829,6 +819,12 @@ if __name__ == "__main__":
|
||||
required=False,
|
||||
help="Path to the json which contains the configurations that user provide",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--datatype",
|
||||
required=True,
|
||||
help="Specify what datatype to use for the kernel generation, e.g. fp16, bf16, int8, fp8, bf8"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
|
||||
@@ -23,7 +23,6 @@ class GemmProfiler
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
|
||||
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
|
||||
@@ -118,7 +118,7 @@ class GemmConfig:
|
||||
trait_config: TraitConfig
|
||||
|
||||
@classmethod
|
||||
def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig":
|
||||
def from_json(cls: Type["GemmConfig"], filepath: str, datatype: str) -> "GemmConfig":
|
||||
"""JSON configuration loader with validation controls"""
|
||||
config_path = Path(filepath)
|
||||
|
||||
@@ -129,18 +129,24 @@ class GemmConfig:
|
||||
with config_path.open("r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
a_type = datatype
|
||||
b_type = datatype
|
||||
c_type = datatype
|
||||
if b_type == 'int4':
|
||||
a_type = "fp16"
|
||||
if b_type in ['bf8', 'fp8', 'int4']:
|
||||
c_type = "fp16"
|
||||
|
||||
# Parse problem config
|
||||
#TODO: Not reading datatype information from json file.
|
||||
problem = ProblemConfig(
|
||||
datatypes=(
|
||||
EnumConfigParam(
|
||||
values=config_dict["problem"]["datatype_a"]["values"]
|
||||
),
|
||||
values=[a_type]),
|
||||
EnumConfigParam(
|
||||
values=config_dict["problem"]["datatype_b"]["values"]
|
||||
),
|
||||
values=[b_type]),
|
||||
EnumConfigParam(
|
||||
values=config_dict["problem"]["datatype_c"]["values"]
|
||||
),
|
||||
values=[c_type])
|
||||
),
|
||||
layouts=(
|
||||
EnumConfigParam(
|
||||
|
||||
Reference in New Issue
Block a user