diff --git a/Jenkinsfile b/Jenkinsfile index 9f1c021878..aec833587f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -800,7 +800,7 @@ def process_results(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=false;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true @@ -1216,9 +1216,12 @@ pipeline { -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx90a" \ + -D GEMM_DATATYPE="fp8;fp16" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ - ninja -j64 benchmark_gemm && \ - ./bin/benchmark_gemm """ + ninja -j64 benchmark_gemm_fp8 && \ + ./bin/benchmark_gemm_fp8 && \ + ninja -j64 benchmark_gemm_fp16 && \ + ./bin/benchmark_gemm_fp16 """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1238,9 +1241,12 @@ pipeline { -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx942" \ + -D GEMM_DATATYPE="fp8;fp16" \ -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& \ - ninja -j128 benchmark_gemm && \ - ./bin/benchmark_gemm """ + ninja -j128 benchmark_gemm_fp8 && \ + ./bin/benchmark_gemm_fp8 && \ + ninja -j128 benchmark_gemm_fp16 && \ + ./bin/benchmark_gemm_fp16 """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index c3c177487f..5db55f02d5 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,67 +1,134 @@ -# generate a list of kernels, but not actually emit files at config stage -execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - --working_path ${CMAKE_CURRENT_BINARY_DIR} - # --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json - --list_blobs - RESULT_VARIABLE ret -) -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "Fail to list kernels via Python. ${ret}") -endif() +set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS) - -set(GEMM_CODEGEN_CPP_FILES "") -set(GEMM_CODEGEN_HPP_FILES "") - -foreach(blob ${GEMM_CODEGEN_BLOBS}) - string(STRIP "${blob}" stripped_blob) - - if(stripped_blob MATCHES "\\.cpp$") - list(APPEND GEMM_CODEGEN_CPP_FILES "${stripped_blob}") - elseif(stripped_blob MATCHES "\\.hpp$") - list(APPEND GEMM_CODEGEN_HPP_FILES "${stripped_blob}") +function(build_gemm_for_datatype datatype) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + #set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json") + # Generate kernel list + execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --config_json ${json_blob} + 
--list_blobs + RESULT_VARIABLE ret + ) + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list kernels for ${datatype}: ${ret}") endif() + + file(STRINGS "${working_path}/gemm_instance_blobs.txt" codegen_blobs) + file(STRINGS "${working_path}/gemm_instance_blobs_range.txt" codegen_blobs_range) + + # Generate the blobs + add_custom_command( + OUTPUT ${codegen_blobs} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path "${working_path}" + --datatype ${datatype} + --config_json "${json_blob}" + --gen_blobs + COMMENT "Generating GEMM instance sources for ${datatype}" + ) + add_custom_target(gemm_gen_${datatype} DEPENDS ${codegen_blobs}) + + set(intermediate_libs) + list(LENGTH codegen_blobs codegen_blobs_len) + + foreach(blob IN LISTS codegen_blobs_range) + string(STRIP "${blob}" stripped_blob) + separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}") + # Each line is: + list(GET spilit_blob 0 name) + list(GET spilit_blob 1 first) + list(GET spilit_blob 2 last) + math(EXPR total_files "${last} - ${first}") + if(total_files EQUAL 0) + continue() # nothing for this trait + endif() + + # Object libraries (chunked) per trait + set(sub_intermediate_libs) + set(chunk_size 3) + math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}") + math(EXPR num_chunks_minus_1 "${num_chunks} - 1") + + foreach(i RANGE 0 ${num_chunks_minus_1}) + math(EXPR start "${first} + ${i} * ${chunk_size} ") + math(EXPR end "${start} + ${chunk_size} - 1") + + set(chunk_files) + foreach(j RANGE ${start} ${end}) + if(j LESS ${last} AND j LESS ${codegen_blobs_len}) + list(GET codegen_blobs ${j} f) + list(APPEND chunk_files "${f}") + endif() + endforeach() + + #list(LENGTH chunk_files chunk_files_len) + #if(chunk_files_len AND chunk_files_len GREATER 1) + if(chunk_files) + set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}") + add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files}) + list(APPEND 
sub_intermediate_libs ${sub_intermediate_lib_name}) + endif() + + endforeach() + + # ------------------ Bundle the object libs into one static lib --------- + #list(LENGTH sub_intermediate_libs sub_intermediate_libs_len) + #if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1) + if(sub_intermediate_libs) + set(intermediate_lib_name "gemm_staticlib_${name}_${datatype}") + # Collect the $ expressions + + set(obj_exprs) + foreach(objlib IN LISTS sub_intermediate_libs) + list(APPEND obj_exprs $) + endforeach() + + add_library(${intermediate_lib_name} STATIC ${obj_exprs}) + add_dependencies(${intermediate_lib_name} gemm_gen_${datatype}) + #foreach(objlib IN LISTS sub_intermediate_libs) + # target_sources(${intermediate_lib_name} PRIVATE $) + #endforeach() + list(APPEND intermediate_libs ${intermediate_lib_name}) + endif() + + endforeach() + + # Interface library for instances + add_library(gemm_template_instances_${datatype} INTERFACE) + add_dependencies(gemm_template_instances_${datatype} gemm_gen_${datatype}) + target_link_libraries(gemm_template_instances_${datatype} INTERFACE ${intermediate_libs}) + target_include_directories(gemm_template_instances_${datatype} INTERFACE + ${CMAKE_CURRENT_LIST_DIR} + "${working_path}" + ) + set_target_properties(gemm_template_instances_${datatype} PROPERTIES LINKER_LANGUAGE CXX) + + # Host API interface library + add_library(gemm_host_api_${datatype} INTERFACE) + target_link_libraries(gemm_host_api_${datatype} INTERFACE gemm_template_instances_${datatype}) + target_include_directories(gemm_host_api_${datatype} INTERFACE + ${CMAKE_CURRENT_LIST_DIR} + "${working_path}" + ) + + + # Executable per datatype + set(exec_name "benchmark_gemm_${datatype}") + add_executable(${exec_name} benchmark_gemm.cpp) + target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype}) + target_compile_options(${exec_name} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) +endfunction() + +# Process each 
datatype in isolation +foreach(dt IN LISTS GEMM_DATATYPE) + build_gemm_for_datatype(${dt}) endforeach() - -add_custom_command( - OUTPUT ${GEMM_CODEGEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - --working_path ${CMAKE_CURRENT_BINARY_DIR} - # --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json - --gen_blobs -) - -add_library(gemm_template_instances OBJECT EXCLUDE_FROM_ALL ${GEMM_CODEGEN_CPP_FILES}) -# Explicitly set LINKER_LANGUAGE to avoid build config failures with Ninja. -set_target_properties(gemm_template_instances PROPERTIES LINKER_LANGUAGE CXX) -target_include_directories(gemm_template_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(gemm_template_instances PRIVATE ${GEMM_CODEGEN_HPP_FILES}) - -set(BENCHMARK_GEMM_EXECUTABLE "benchmark_gemm") -message(DEBUG "adding example ${BENCHMARK_GEMM_EXECUTABLE}") - -include_directories(${CMAKE_CURRENT_BINARY_DIR}) - -add_library(gemm_host_api INTERFACE EXCLUDE_FROM_ALL) -target_include_directories(gemm_host_api INTERFACE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(gemm_host_api INTERFACE ${GEMM_CODEGEN_HPP_FILES} gemm_host_api.hpp) -target_link_libraries(gemm_host_api INTERFACE gemm_template_instances) - -add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp) -target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp) -target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api) - -set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS) - -list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS - -Wno-undefined-func-template - -Wno-float-equal - --offload-compress) - -target_compile_options(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS}) - -set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/tile_engine/ops/gemm/README.md 
b/tile_engine/ops/gemm/README.md index db624e576e..40cb9acd1c 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -15,16 +15,27 @@ If user does not provide kernel configuration, the tile engine uses default kern # in the root of composable kernel create build directory mkdir build && cd build # build composable kernel -sh ../script/cmake-ck-dev.sh ../ # replace with the appropriate architecture (example gfx942) or leave blank -# generate the executable -make benchmark_gemm -j +# replace [Arch] with the appropriate architecture or leave blank and +# replace [Datatype1;Datatype2;...] with a semicolon-separated list of datatypes (possible datatypes are [fp8, bf8, int8, fp16, bf16]) +sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" +# generate a separate executable for each passed datatype +make benchmark_gemm_[Datatype1] -j +make benchmark_gemm_[Datatype2] -j ``` -`benchmark_gemm` will be located in the `./bin/` directory. +`benchmark_gemm_[Datatypes]` will be located in the `./bin/` directory. -`benchmark_gemm` must be rebuilt everytime if configuration file is modified. +`benchmark_gemm_[Datatypes]` must be rebuilt every time the configuration file is modified.
``` bash -rm -rf tile_engine/ && make benchmark_gemm -j # rebuild +rm -rf tile_engine/ && make benchmark_gemm_[Datatypes] -j # rebuild +``` + +## For example, build for gfx942 with fp8 and fp16 datatypes +``` bash +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" +make benchmark_gemm_fp8 -j +make benchmark_gemm_fp16 -j ``` ## benchmark_gemm inputs diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index ae496636c6..9ff76724cc 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -199,7 +199,7 @@ warp_tile_supported_combinations = { [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], }, "gfx942": { "fp16_fp16_fp16": [ @@ -219,7 +219,7 @@ warp_tile_supported_combinations = { [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], }, "gfx950": { @@ -247,7 +247,7 @@ warp_tile_supported_combinations = { [16, 16, 128], [32, 32, 64], ], - "fp8_fp8_fp16": [ + "bf8_bf8_fp16": [ [32, 32, 16], [32, 32, 32], [16, 16, 64], diff --git a/tile_engine/ops/gemm/configs/benchmark.json b/tile_engine/ops/gemm/configs/benchmark.json new file mode 100644 index 0000000000..601784049b --- /dev/null +++ b/tile_engine/ops/gemm/configs/benchmark.json @@ -0,0 +1,116 @@ +{ + "problem": { + "layout_a": { + "values": [ + "r" + ] + }, + "layout_b": { + "values": [ + "c" + ] + }, + "layout_c": { + "values": [ + "r" + ] + } + }, + "tile_config": { + "tile_m": { + "max": 256, + "min": 64, + "step": 64, + "exclude": [192] + }, + "tile_n": { + "max": 256, + "min": 64, + "step": 64, + "exclude": [192] + }, + "tile_k": { + "max": 256, + "min": 64,
+ "step": 64, + "exclude": [192] + }, + "warp_m": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_n": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 4, + 16, + 32 + ] + }, + "warp_tile_n": { + "values": [ + 16, + 32, + 64 + ] + }, + "warp_tile_k": { + "values": [ + 8, + 16, + 32, + 64, + 128 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "compv4", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "cshuffle" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json index 9f71e430de..069a3b080c 100644 --- a/tile_engine/ops/gemm/configs/default_config.json +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -1,136 +1,115 @@ { - "problem": { - "layout_a": { - "values": [ - "r" - ] - }, - "layout_b": { - "values": [ - "c" - ] - }, - "layout_c": { - "values": [ - "r" - ] - }, - "datatype_a": { - "values": [ - "fp16" - ] - }, - "datatype_b": { - "values": [ - "fp16" - ] - }, - "datatype_c": { - "values": [ - "fp16" - ] - } + "problem": { + "layout_a": { + "values": [ + "r" + ] }, - "tile_config": { - "tile_m": { - "max": 256, - "min": 64, - "step": 64, - "exclude": [] - }, - "tile_n": { - "max": 256, - "min": 64, - "step": 32, - "exclude": [] - }, - "tile_k": { - "max": 256, - "min": 64, - "step": 64, - "exclude": [192] - }, - "warp_m": { - "values": [ - 4, - 2, - 1 - ] - }, - "warp_n": { - "values": [ - 4, - 2, - 1 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 4, - 8, - 16, - 32, - 64 - ] - }, - "warp_tile_n": { - "values": [ - 4, - 8, - 16, - 32, - 64 - ] - }, - "warp_tile_k": { - "values": [ - 8, - 16, - 32, - 64, - 128 - ] - } + 
"layout_b": { + "values": [ + "c" + ] }, - "trait_config": { - "pipeline": { - "values": [ - "compv4", - "compv3", - "mem" - ] - }, - "scheduler": { - "values": [ - "intrawave", - "interwave" - ] - }, - "epilogue": { - "values": [ - "default", - "cshuffle" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - } + "layout_c": { + "values": [ + "r" + ] } + }, + "tile_config": { + "tile_m": { + "values": [ + 256 + ] + }, + "tile_n": { + "values": [ + 128, + 256 + ] + }, + "tile_k": { + "values": [ + 32 + ] + }, + "warp_m": { + "values": [ + 1, + 2, + 4 + ] + }, + "warp_n": { + "values": [ + 1, + 2, + 4 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 4, + 16, + 32 + ] + }, + "warp_tile_n": { + "values": [ + 16, + 32, + 64 + ] + }, + "warp_tile_k": { + "values": [ + 8, + 16, + 32, + 64, + 128 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "compv4", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "cshuffle", + "default" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + } + } } \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/user_provided_config.json b/tile_engine/ops/gemm/configs/user_provided_config.json index 43c8784667..79bcced82a 100644 --- a/tile_engine/ops/gemm/configs/user_provided_config.json +++ b/tile_engine/ops/gemm/configs/user_provided_config.json @@ -14,27 +14,13 @@ "values": [ "r" ] - }, - "datatype_a": { - "values": [ - "int8" - ] - }, - "datatype_b": { - "values": [ - "int8" - ] - }, - "datatype_c": { - "values": [ - "int32" - ] } }, "tile_config": { "tile_m": { "values": [ - 128 + 128, + 256 ] }, "tile_n": { @@ -49,12 +35,12 @@ }, "warp_m": { "values": [ - 2 + 4 ] }, "warp_n": { "values": [ - 2 + 1 ] }, "warp_k": { diff --git 
a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index f217522feb..de1fd0bb62 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -62,7 +62,7 @@ class GemmCodeGenerator: file_path = w_p / "gemm_instance_blobs.txt" self._generate_all_traits() self._get_valid_trait_tile_combinations() - + file_range_map = {} # Write all file paths to the header file files_listed = 0 with file_path.open("w") as f: @@ -81,9 +81,10 @@ class GemmCodeGenerator: trait_file = f"gemm_{trait}.hpp" f.write(str(w_p / trait_file) + "\n") files_listed += 1 - + file_name = set() # Instance source files for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): + start_idx = files_listed for tile in tile_valid_params: for ( tile_m, @@ -92,38 +93,24 @@ class GemmCodeGenerator: warp_m, warp_n, warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, + _, + _, + _, ) in tile: - instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - sparse = ( - self.config.problem.datatype_map["matrix_a"] == "fp16" - and self.config.problem.datatype_map["matrix_b"] == "fp16" - and self.config.problem.datatype_map["matrix_c"] == "fp16" - and ( - ( - warp_tile_m == 32 - and warp_tile_n == 32 - and warp_tile_k == 16 - ) - or ( - warp_tile_m == 16 - and warp_tile_n == 16 - and warp_tile_k == 32 - ) - ) - ) - if sparse: - sparse_file = f"gemm_{trait}_{instance_name}_true.cpp" - f.write(str(w_p / sparse_file) + "\n") + instance_name = f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" + + if instance_name not in file_name: + file_name.add(instance_name) + f.write(str(w_p / instance_name) + "\n") files_listed += 1 - regular_file = f"gemm_{trait}_{instance_name}_false.cpp" - f.write(str(w_p / regular_file) + "\n") - files_listed += 1 - - print(f"File listing complete: {files_listed} files listed in {file_path}\n") + 
file_range_map[trait] = (start_idx, files_listed) + + file_path = w_p / 'gemm_instance_blobs_range.txt' + with file_path.open('w') as f: + for name, ranges in file_range_map.items(): + s, l = ranges + f.write(name + " " + f"{s}" + " " + f"{l}"+ "\n") def _generate_all_traits(self): """Generate all possible kernel traits names.""" @@ -246,7 +233,7 @@ struct GemmKernel {{ static constexpr bool kPadN = {pad_n}; static constexpr bool kPadK = {pad_k}; - static float launch(ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{ + static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ static constexpr bool permuteA = false; static constexpr bool permuteB = false; static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; @@ -360,7 +347,6 @@ struct GemmKernel {{ if(args.k_batch > 1) hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_)); - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_)); }}; ave_time = ck_tile::launch_kernel_preprocess( stream, @@ -577,8 +563,8 @@ struct GemmKernel {{ self.valid_trait_tile_combinations[trait].append(tile_valid_params) def _generate_instantiation_source_files(self): - """Generate kernel instance instantiation source files""" - + """Generate kernel instance instantiation source files """ + tile_map = {} for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): for tile in tile_valid_params: for ( @@ -592,17 +578,28 @@ struct GemmKernel {{ warp_tile_n, warp_tile_k, ) in tile: - instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}" + value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + if key not in tile_map: + tile_map[key] = set() + tile_map[key].add(value) + + files_listed = 0 + for trait, _ in self.valid_trait_tile_combinations.items(): + 
for block_tile, warp_tiles in tile_map.items(): + tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(int, block_tile.split('x')) - content = f""" + content = f""" // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - #include "gemm_{trait}.hpp" """ + for warp_tile in warp_tiles: + warp_tile_m, warp_tile_n, warp_tile_k = map(int, warp_tile.split("x")) + sparse = ( self.config.problem.datatype_map["matrix_a"] == "fp16" and self.config.problem.datatype_map["matrix_b"] == "fp16" @@ -621,23 +618,17 @@ struct GemmKernel {{ ) ) if sparse: - sparse_filename = f"gemm_{trait}_{instance_name}_true.cpp" - sparse_content = ( - content - + f""" -template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>; + files_listed = files_listed + 1 + content = content + f""" +template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;""" + files_listed = files_listed + 1 + content = content + f""" +template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;""" + content += f""" """ - ) - (self.output_dir / sparse_filename).write_text(sparse_content) - - no_sparse_filename = f"gemm_{trait}_{instance_name}_false.cpp" - no_sparse_content = ( - content - + f""" -template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>; -""" - ) - (self.output_dir / no_sparse_filename).write_text(no_sparse_content) + (self.output_dir / + f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp").write_text(content) + print(f"Generated {files_listed} kernel instances in total.") def _generate_dispatcher_file(self): """Generate the code block of dispatch mechanism.""" @@ -682,8 +673,7 @@ struct 
GemmDispatcher { return kernel_map; } - static void init(bool structured_sparsity) { - (void)structured_sparsity; // Suppress unused parameter warning + static void init([[maybe_unused]]bool structured_sparsity) { auto& kernel_map = get_kernel_map(); if(!kernel_map.empty()) return; \n""" @@ -703,7 +693,7 @@ struct GemmDispatcher { warp_tile_n, warp_tile_k, ) = tile[j] - content += f"""[=](ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{ """ + content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """ content += f""" if(structured_sparsity){{ // SMFMA""" sparse = ( @@ -795,7 +785,7 @@ def do_gen_blobs( def main(args): gemm_config = ( - GemmConfig.from_json(args.config_json) + GemmConfig.from_json(args.config_json, args.datatype) if args.config_json is not None else args.config_json ) @@ -829,6 +819,12 @@ if __name__ == "__main__": required=False, help="Path to the json which contains the configurations that user provide", ) + parser.add_argument( + "-d", + "--datatype", + required=True, + help="Specify what datatype to use for the kernel generation, e.g. 
fp16, bf16, int8, fp8, bf8" + ) parser.add_argument( "-l", "--list_blobs", diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index 272799e4d6..20f601d46e 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -23,7 +23,6 @@ class GemmProfiler void benchmark(GemmProblem& gemm_problem, std::vector( ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables) - ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables) { const ALayout layout_a = ALayout{}; const BLayout layout_b = BLayout{}; diff --git a/tile_engine/ops/gemm/json_config.py b/tile_engine/ops/gemm/json_config.py index aaf732c6a8..8b83977dd3 100644 --- a/tile_engine/ops/gemm/json_config.py +++ b/tile_engine/ops/gemm/json_config.py @@ -118,7 +118,7 @@ class GemmConfig: trait_config: TraitConfig @classmethod - def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig": + def from_json(cls: Type["GemmConfig"], filepath: str, datatype: str) -> "GemmConfig": """JSON configuration loader with validation controls""" config_path = Path(filepath) @@ -129,18 +129,24 @@ class GemmConfig: with config_path.open("r") as f: config_dict = json.load(f) + a_type = datatype + b_type = datatype + c_type = datatype + if b_type == 'int4': + a_type = "fp16" + if b_type in ['bf8', 'fp8', 'int4']: + c_type = "fp16" + # Parse problem config + #TODO: Not reading datatype information from json file. problem = ProblemConfig( datatypes=( EnumConfigParam( - values=config_dict["problem"]["datatype_a"]["values"] - ), + values=[a_type]), EnumConfigParam( - values=config_dict["problem"]["datatype_b"]["values"] - ), + values=[b_type]), EnumConfigParam( - values=config_dict["problem"]["datatype_c"]["values"] - ), + values=[c_type]) ), layouts=( EnumConfigParam(