mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
GEMM Multi D for CK Tile Engine (#2660)
* Readme for GEMM Multi D * GEMM Multi D partial Progress * GEMM Multi D partial Progress! * CK Tile Engine GEMM Multi D : All Python files generated * Partial Progress * Partial Progress * Partial Progress * Partial Progress : Incorrect Result * Partial Progress : Debugging * Partial Progress : Correct Results * Partial Progress - Incorrect Results * Partial Progress - Commenting Passthrough bypass logic * Changing Passthrough to MultiplyMultiply * Correct Results! * Fix and debug the pass through feature * Sample commit * Correct Results : MultiplyMultiply * Code Cleanup * Removing Failed Instances * Working code before Unary element support * Custom Elementwise Function support and working implementation for Mul and Add * Updating README * Working for Passthrough * Review Comments : Minor Fixes * Review Comments : Minor Fixes * Readme Updated * Partial Changes after Rebase * Working Code : Changes after Rebase * Updating Jenkins file * Removing default value changed while testing * Configuration changes in config files * Tile Handler changes in GEMM Multi D Tile Engine * Tile Handler changes in GEMM Multi D Example * Change log for Gemm Multi D in CK Tile Engine * Configuration changes in config files --------- Co-authored-by: ThomasNing <thomasning@amd.com>
This commit is contained in:
committed by
GitHub
parent
30dafe8281
commit
3f57ec3d2d
152
tile_engine/ops/gemm_multi_d/CMakeLists.txt
Normal file
152
tile_engine/ops/gemm_multi_d/CMakeLists.txt
Normal file
@@ -0,0 +1,152 @@
|
||||
|
||||
set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi D (semicolon-separated)")
|
||||
set(GEMM_MULTI_D_LAYOUT "rcrr" CACHE STRING "List of layout for GEMM Multi D(semicolon-separated)")
|
||||
set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function")
|
||||
|
||||
function(build_gemm_multi_d_for_datatype_layout datatype layout)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Comment this if-else block when using user_provided_config
|
||||
if(layout STREQUAL "rcrr")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
else()
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json")
|
||||
endif()
|
||||
|
||||
# uncomment this if you want to use user_provided_config.json
|
||||
# set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json")
|
||||
|
||||
# Generate kernel list
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json ${json_blob}
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}")
|
||||
endif()
|
||||
|
||||
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs.txt" codegen_blobs)
|
||||
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs_range.txt" codegen_blobs_range)
|
||||
|
||||
# Generate the blobs
|
||||
add_custom_command(
|
||||
OUTPUT ${codegen_blobs}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
|
||||
--working_path "${working_path}"
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json "${json_blob}"
|
||||
--gen_blobs
|
||||
COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}"
|
||||
)
|
||||
add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})
|
||||
|
||||
set(intermediate_libs)
|
||||
list(LENGTH codegen_blobs codegen_blobs_len)
|
||||
|
||||
foreach(blob IN LISTS codegen_blobs_range)
|
||||
string(STRIP "${blob}" stripped_blob)
|
||||
separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}")
|
||||
# Each line is: <trait_name> <first_index_inclusive> <last_index_exclusive>
|
||||
list(GET spilit_blob 0 name)
|
||||
list(GET spilit_blob 1 first)
|
||||
list(GET spilit_blob 2 last)
|
||||
math(EXPR total_files "${last} - ${first}")
|
||||
if(total_files EQUAL 0)
|
||||
continue() # nothing for this trait
|
||||
endif()
|
||||
|
||||
# Object libraries (chunked) per trait
|
||||
set(sub_intermediate_libs)
|
||||
set(chunk_size 3)
|
||||
math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}")
|
||||
math(EXPR num_chunks_minus_1 "${num_chunks} - 1")
|
||||
|
||||
foreach(i RANGE 0 ${num_chunks_minus_1})
|
||||
math(EXPR start "${first} + ${i} * ${chunk_size} ")
|
||||
math(EXPR end "${start} + ${chunk_size} - 1")
|
||||
|
||||
set(chunk_files)
|
||||
foreach(j RANGE ${start} ${end})
|
||||
if(j LESS ${last} AND j LESS ${codegen_blobs_len})
|
||||
list(GET codegen_blobs ${j} f)
|
||||
list(APPEND chunk_files "${f}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
#list(LENGTH chunk_files chunk_files_len)
|
||||
#if(chunk_files_len AND chunk_files_len GREATER 1)
|
||||
if(chunk_files)
|
||||
set(sub_intermediate_lib_name "gemm_multi_d_objlib_${name}_${i}_${datatype}_${layout}")
|
||||
add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files})
|
||||
list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name})
|
||||
endif()
|
||||
|
||||
endforeach()
|
||||
|
||||
# ------------------ Bundle the object libs into one static lib ---------
|
||||
#list(LENGTH sub_intermediate_libs sub_intermediate_libs_len)
|
||||
#if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1)
|
||||
if(sub_intermediate_libs)
|
||||
set(intermediate_lib_name "gemm_multi_d_staticlib_${name}_${datatype}_${layout}")
|
||||
# Collect the $<TARGET_OBJECTS:...> expressions
|
||||
|
||||
set(obj_exprs)
|
||||
foreach(objlib IN LISTS sub_intermediate_libs)
|
||||
list(APPEND obj_exprs $<TARGET_OBJECTS:${objlib}>)
|
||||
endforeach()
|
||||
|
||||
add_library(${intermediate_lib_name} STATIC ${obj_exprs})
|
||||
add_dependencies(${intermediate_lib_name} gemm_multi_d_gen_${datatype}_${layout})
|
||||
#foreach(objlib IN LISTS sub_intermediate_libs)
|
||||
# target_sources(${intermediate_lib_name} PRIVATE $<TARGET_OBJECTS:${objlib}>)
|
||||
#endforeach()
|
||||
list(APPEND intermediate_libs ${intermediate_lib_name})
|
||||
endif()
|
||||
|
||||
endforeach()
|
||||
|
||||
# Interface library for instances
|
||||
add_library(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE)
|
||||
add_dependencies(gemm_multi_d_template_instances_${datatype}_${layout} gemm_multi_d_gen_${datatype}_${layout})
|
||||
target_link_libraries(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs})
|
||||
target_include_directories(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
"${working_path}"
|
||||
)
|
||||
set_target_properties(gemm_multi_d_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX)
|
||||
|
||||
# Host API interface library
|
||||
add_library(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE)
|
||||
target_link_libraries(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE gemm_multi_d_template_instances_${datatype}_${layout})
|
||||
target_include_directories(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
"${working_path}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Executable per datatype
|
||||
set(exec_name "benchmark_gemm_multi_d_${datatype}_${layout}")
|
||||
add_executable(${exec_name} benchmark_gemm_multi_d.cpp)
|
||||
target_link_libraries(${exec_name} PRIVATE gemm_multi_d_host_api_${datatype}_${layout})
|
||||
target_compile_options(${exec_name} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
)
|
||||
endfunction()
|
||||
|
||||
# Process each datatype in isolation
|
||||
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)
|
||||
build_gemm_multi_d_for_datatype_layout(${dt} ${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
110
tile_engine/ops/gemm_multi_d/README.md
Normal file
110
tile_engine/ops/gemm_multi_d/README.md
Normal file
@@ -0,0 +1,110 @@
|
||||
|
||||
CK Tile Engine for GEMM Multi D is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues while able to give custom datatype and Layout selections
|
||||
|
||||
# Kernel Configurations
|
||||
|
||||
# User Specific
|
||||
Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time.
|
||||
For reference please see `./configs/user_provided_config.json`.
|
||||
|
||||
# Default
|
||||
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json`
|
||||
|
||||
If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark.
|
||||
|
||||
## Build Instructions
|
||||
``` bash
|
||||
# in the root of composable kernel create build directory
|
||||
mkdir build && cd build
|
||||
# build composable kernel
|
||||
# replace [Arch] with the appropriate architecture or leave blank and
|
||||
# replace [Datatype] in comma separated datatypes string (possible datatypes are [fp16])
|
||||
# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr])
|
||||
# replace "mul" with either of mul,add,passthrough for Elementwise function as Multiply, Add or Passthrough respectively. If this is not specified it is considered as mul by default.
|
||||
sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_MULTI_D_DATATYPE="[Datatype]" -DGEMM_MULTI_D_LAYOUT="[Layout1;Layout2]" -DGEMM_MULTI_D_ELEMENTWISE_FUNCTION="mul"
|
||||
# generate different executable for each passed datatype
|
||||
make benchmark_gemm_multi_d_[Datatype]_[Layout1] -j
|
||||
make benchmark_gemm_multi_d_[Datatype]_[Layout2] -j
|
||||
```
|
||||
`benchmark_gemm_multi_d_[Datatype]_[Layout]` will be located in the `./bin/` directory.
|
||||
|
||||
`benchmark_gemm_multi_d_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified.
|
||||
|
||||
``` bash
|
||||
rm -rf tile_engine/ && make benchmark_gemm_multi_d_[Datatype]_[Layout] -j # rebuild
|
||||
```
|
||||
|
||||
## For eaxmple build for gfx942 for datatype with rcr layout
|
||||
``` bash
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_MULTI_D_DATATYPE="fp16" -DGEMM_MULTI_D_LAYOUT="rcrr"
|
||||
make benchmark_gemm_multi_d_fp16_rcrr -j
|
||||
|
||||
## benchmark_gemm inputs
|
||||
```
|
||||
-m The value for m dimension. Default is 3840.
|
||||
-n The value for n dimension. Default is 4096.
|
||||
-k The value for k dimension. Default is 2048.
|
||||
-stride_a The stride value for tensor A. Default is 0.
|
||||
-stride_b The stride value for tensor B. Default is 0.
|
||||
-stride_ds The stride value for tensor Ds. Default is 0.
|
||||
-stride_e The stride value for tensor E. Default is 0.
|
||||
-split_k The split value for k dimension. Default is 1.
|
||||
-verify The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 1, validation on CPU, as validation on GPU is not supported.
|
||||
-log Wether output kernel instance information or not. Possible values are true or false. Default is false.
|
||||
-warmup The number of iterations before benchmark the kernel. Default is 50.
|
||||
-repeat The number of iterations to benchmark the kernel. Default is 100.
|
||||
-timer Whether if the timer is gpu timer or not. Possible values are false or true. Default is true.
|
||||
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
|
||||
-flush_cache To flush cache, possible values are true or false. Default is false.
|
||||
-rotating_count Number of iterations to rotate the cache. Default is 5.
|
||||
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
|
||||
-csv_filename The filename of benchmark result. Default is gemm_multi_d_kernel.
|
||||
-pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.
|
||||
-scheduler The type of scheduler. Possible values are intrawave. Default is intrawave.
|
||||
-epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.
|
||||
-pad_m Whether pad or not in m direction. Possible values are true or false. Default is false.
|
||||
-pad_n Whether pad or not in n direction. Possible values are true or false. Default is false.
|
||||
-pad_k Whether pad or not in k direction. Possible values are true or false. Default is false.
|
||||
|
||||
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json
|
||||
```
|
||||
Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above.
|
||||
|
||||
## Example
|
||||
|
||||
The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes.
|
||||
|
||||
```json
|
||||
{
|
||||
/// other parameters ///
|
||||
|
||||
"tile_m": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [64, 32]
|
||||
},
|
||||
|
||||
/// other parameters ///
|
||||
|
||||
"pipeline": {
|
||||
"values": ["compv3", "compv4", "mem"]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": ["intrawave", "interwave"]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": ["cshuffle"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
|
||||
``` bash
|
||||
./bin/benchmark_gemm_multi_d_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=cshuffle
|
||||
```
|
||||
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and cshuffle epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.
|
||||
73
tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp
Normal file
73
tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <exception>
|
||||
|
||||
#include "benchmark_gemm_multi_d.hpp"
|
||||
#include "gemm_multi_d_profiler.hpp"
|
||||
|
||||
void benchmark_gemm_multi_d(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
GemmMultiDProblem gemm_multi_d_problem{arg_parser.get_int("split_k"),
|
||||
arg_parser.get_int("m"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("stride_a"),
|
||||
arg_parser.get_int("stride_b"),
|
||||
arg_parser.get_int("stride_ds"),
|
||||
arg_parser.get_int("stride_ds"),
|
||||
arg_parser.get_int("stride_e"),
|
||||
DataTypeTraits<ADataType>::name,
|
||||
DataTypeTraits<BDataType>::name,
|
||||
DataTypeTraits<D0DataType>::name,
|
||||
DataTypeTraits<D1DataType>::name,
|
||||
DataTypeTraits<AccDataType>::name,
|
||||
DataTypeTraits<EDataType>::name,
|
||||
ALayout::name,
|
||||
BLayout::name,
|
||||
D0Layout::name,
|
||||
D1Layout::name,
|
||||
ELayout::name};
|
||||
|
||||
Setting setting{arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
arg_parser.get_int("verify"),
|
||||
arg_parser.get_int("init"),
|
||||
arg_parser.get_bool("log"),
|
||||
arg_parser.get_str("csv_filename"),
|
||||
arg_parser.get_bool("flush_cache"),
|
||||
arg_parser.get_int("rotating_count")};
|
||||
|
||||
auto& profiler = GemmMultiDProfiler::instance(setting);
|
||||
|
||||
try
|
||||
{
|
||||
auto kernel_func = get_kernel_func_by_trait(arg_parser);
|
||||
profiler.benchmark(gemm_multi_d_problem, kernel_func);
|
||||
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Benchmark failed: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
benchmark_gemm_multi_d(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
218
tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp
Normal file
218
tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp
Normal file
@@ -0,0 +1,218 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "gemm_multi_d_host_api.hpp"
|
||||
|
||||
struct GemmMultiDProblem
|
||||
{
|
||||
int split_k_;
|
||||
int m_, n_, k_;
|
||||
int stride_a_, stride_b_, stride_d0_, stride_d1_, stride_e_;
|
||||
|
||||
std::string dtype_a_, dtype_b_, dtype_d0_, dtype_d1_, dtype_acc_, dtype_e_;
|
||||
std::string layout_a_, layout_b_, layout_d0_, layout_d1_, layout_e_;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"split_k\":" << problem.split_k_ << ",\n"
|
||||
<< " \"m\":" << problem.m_ << ",\n"
|
||||
<< " \"n\":" << problem.n_ << ",\n"
|
||||
<< " \"k\":" << problem.k_ << ",\n"
|
||||
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
|
||||
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
|
||||
<< " \"stride_d0\":" << problem.stride_d0_ << ",\n"
|
||||
<< " \"stride_d1\":" << problem.stride_d1_ << ",\n"
|
||||
<< " \"stride_e\":" << problem.stride_e_ << ",\n"
|
||||
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
|
||||
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
|
||||
<< " \"dtype_d0\":\"" << problem.dtype_d0_ << "\",\n"
|
||||
<< " \"dtype_d1\":\"" << problem.dtype_d1_ << "\",\n"
|
||||
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
|
||||
<< " \"dtype_e\":\"" << problem.dtype_e_ << "\",\n"
|
||||
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
|
||||
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
|
||||
<< " \"layout_d0\":\"" << problem.layout_d0_ << "\",\n"
|
||||
<< " \"layout_d1\":\"" << problem.layout_d1_ << "\",\n"
|
||||
<< " \"layout_e\":\"" << problem.layout_e_ << "\"\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct Setting
|
||||
{
|
||||
int n_warmup_;
|
||||
int n_repeat_;
|
||||
bool is_gpu_timer_;
|
||||
int verify_;
|
||||
int init_method_;
|
||||
bool log_;
|
||||
std::string csv_filename_;
|
||||
bool flush_cache_;
|
||||
int rotating_count_;
|
||||
};
|
||||
|
||||
// @brief Function to get the kernel output with reference implementation on CPU
|
||||
void gemm_multi_d_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<D0DataType>& d0_m_n,
|
||||
ck_tile::HostTensor<D1DataType>& d1_m_n,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
|
||||
{
|
||||
if(verify > 0)
|
||||
{
|
||||
// Currently supporting on CPU verification for Gemm Multi D
|
||||
// e_m_n_host_result.SetZero();
|
||||
ck_tile::reference_gemm_multiple_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ElementWiseFn>(
|
||||
a_m_k, b_k_n, {d0_m_n, d1_m_n}, e_m_n_host_result);
|
||||
}
|
||||
}
|
||||
|
||||
enum class Metric
|
||||
{
|
||||
LATENCY = 0,
|
||||
TFLOPS = 1,
|
||||
BANDWIDTH = 2
|
||||
};
|
||||
|
||||
inline constexpr auto get_metric_name(Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return "latency";
|
||||
case Metric::TFLOPS: return "tflops";
|
||||
case Metric::BANDWIDTH: return "bandwidth";
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
struct PerformanceResult
|
||||
{
|
||||
double latency_;
|
||||
double tflops_;
|
||||
double bandwidth_;
|
||||
|
||||
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return a.latency_ < b.latency_;
|
||||
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
|
||||
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
|
||||
<< ",\n"
|
||||
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
|
||||
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct KernelInstance
|
||||
{
|
||||
std::string name_;
|
||||
GemmMultiDProblem problem_;
|
||||
PerformanceResult perf_result_;
|
||||
|
||||
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
||||
{
|
||||
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"name\": \"" << "{\n"
|
||||
<< obj.name_ << "\n}" << "\",\n"
|
||||
<< " \"problem\": \"" << obj.problem_ << "\",\n"
|
||||
<< " \"perf_result\": " << obj.perf_result_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
inline std::string get_rocm_version()
|
||||
{
|
||||
std::ifstream version_file("/opt/rocm/.info/version");
|
||||
if(version_file.is_open())
|
||||
{
|
||||
std::string version;
|
||||
std::getline(version_file, version);
|
||||
return version;
|
||||
}
|
||||
return "Unknown";
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeTypeAB =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations
|
||||
bool compare(std::string instanceName,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_result.mData.begin(), e_m_n_host_result.mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(e_m_n_dev_result,
|
||||
e_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "For " << instanceName << " Relative error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
80
tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json
Normal file
80
tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json
Normal file
@@ -0,0 +1,80 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256 ]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
84
tile_engine/ops/gemm_multi_d/configs/default_config.json
Normal file
84
tile_engine/ops/gemm_multi_d/configs/default_config.json
Normal file
@@ -0,0 +1,84 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
229
tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py
Normal file
229
tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Mappings and utility functions for kernel code generation.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
DATA_TYPE_MAP = {
|
||||
"fp32": "float",
|
||||
"fp16": "ck_tile::half_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"int8": "ck_tile::int8_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"int4": "ck_tile::pk_int4_t",
|
||||
"int32": "ck_tile::int32_t",
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
|
||||
# TODO THIS IS NOT SUPPORTED FOR MULTI D AS OF NOW
|
||||
# DEFAULT_EPILOGUE = """
|
||||
# using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
# ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
# BDataType,
|
||||
# AccDataType,
|
||||
# CDataType,
|
||||
# CLayout,
|
||||
# kPadM,
|
||||
# kPadN,
|
||||
# WarpTileM,
|
||||
# WarpTileN,
|
||||
# WarpTileK,
|
||||
# UniversalGemmProblem::TransposeC,
|
||||
# true,
|
||||
# memory_operation>>;
|
||||
# """
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
WarpM,
|
||||
WarpN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
"""
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
|
||||
"compv3": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV3",
|
||||
],
|
||||
"compv4": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV4",
|
||||
],
|
||||
}
|
||||
|
||||
SCHEDULER_MAP = {
|
||||
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
}
|
||||
|
||||
# EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
|
||||
|
||||
EPILOGUE_MAP = {"cshuffle": CSHUFFLE_EPILOGUE}
|
||||
|
||||
|
||||
def BOOL_MAP(b_):
|
||||
return {True: "true", False: "false"}[bool(b_)]
|
||||
|
||||
|
||||
# Can add some more supported combinations
|
||||
warp_tile_supported_combinations = {
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_bf8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Remove some unsupported combinations
|
||||
trait_unsupported_combinations = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
}
|
||||
|
||||
|
||||
ELEMENT_SIZE_MAP = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"int8": 1,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int4": 0.5,
|
||||
"int32": 4,
|
||||
}
|
||||
|
||||
|
||||
def element_size(data_type: str) -> float:
|
||||
"""Calculate the size (in bytes) of a single element for given data type."""
|
||||
data_type = data_type.lower()
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
|
||||
|
||||
return ""
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
|
||||
except FileNotFoundError:
|
||||
print("ROCm tools not installed (requires rocminfo)")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("GPU query timeout (5s)")
|
||||
except Exception as e:
|
||||
print(f"GPU detection error: {str(e)}")
|
||||
|
||||
return ""
|
||||
250
tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py
Normal file
250
tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Handles loading, parsing, and validation of JSON and Argument configuration parameters.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, Type
|
||||
import json
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumConfigParam:
|
||||
"""Represents an enumeration-type configuration parameter"""
|
||||
|
||||
values: List[Union[int, str, bool]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeConfigParam:
|
||||
"""Represents a numeric range-type configuration parameter"""
|
||||
|
||||
min: int
|
||||
max: int
|
||||
step: int
|
||||
exclude: Optional[List[int]]
|
||||
|
||||
def generate_candidates(self) -> List[int]:
|
||||
"""Generates valid candidates after applying range constraints"""
|
||||
|
||||
if self.min > self.max:
|
||||
raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
|
||||
if self.step <= 0:
|
||||
raise ValueError(f"Step must be positive, got {self.step}")
|
||||
|
||||
candidates = list(range(self.min, self.max + 1, self.step))
|
||||
|
||||
if hasattr(self, "exclude") and self.exclude:
|
||||
if not isinstance(self.exclude, list):
|
||||
raise TypeError("exclude must be list type")
|
||||
exclude_set = set(self.exclude)
|
||||
candidates = [x for x in candidates if x not in exclude_set]
|
||||
|
||||
if not candidates:
|
||||
raise ValueError(
|
||||
f"No valid candidates for range [{self.min}-{self.max}] "
|
||||
f"with step {self.step} and excludes {self.exclude}"
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataType:
|
||||
"""Configuration class for data type parameter."""
|
||||
|
||||
a_datatype: str
|
||||
b_datatype: str
|
||||
e_datatype: str
|
||||
d0_datatype: str
|
||||
d1_datatype: str
|
||||
ds_datatype: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Layout:
|
||||
"""Configuration class for Layout parameter."""
|
||||
|
||||
a_layout: str
|
||||
b_layout: str
|
||||
e_layout: str
|
||||
d0_layout: str
|
||||
d1_layout: str
|
||||
ds_layout: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArgumentConfig:
|
||||
"""Configuration class for Argument parameter."""
|
||||
|
||||
datatypes: DataType
|
||||
layouts: Layout
|
||||
function_name: str
|
||||
|
||||
@classmethod
|
||||
def from_args(
|
||||
cls: Type["ArgumentConfig"],
|
||||
datatype: str,
|
||||
layout: str,
|
||||
elementwise_function: str,
|
||||
) -> "ArgumentConfig":
|
||||
"""configuration loader with validation controls"""
|
||||
|
||||
datatypes = DataType(
|
||||
a_datatype=datatype,
|
||||
b_datatype=datatype,
|
||||
e_datatype=datatype,
|
||||
d0_datatype=datatype,
|
||||
d1_datatype=datatype,
|
||||
ds_datatype=[datatype, datatype],
|
||||
)
|
||||
|
||||
layout_parts = layout.lower()
|
||||
assert len(layout_parts) == 4, (
|
||||
f"Invalid layout string: {layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ("r", "c"), (
|
||||
f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[1] in ("r", "c"), (
|
||||
f"Invalid matrix_b layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
f"Invalid matrix_e layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
assert layout_parts[3] == "r", (
|
||||
f"Invalid D dimension layout: {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
layouts = Layout(
|
||||
a_layout=layout[0],
|
||||
b_layout=layout[1],
|
||||
e_layout=layout[2],
|
||||
d0_layout=layout[3],
|
||||
d1_layout=layout[3],
|
||||
ds_layout=[layout[3], layout[3]],
|
||||
)
|
||||
# Elementwise function name validation
|
||||
valid_functions = ["mul", "add", "passthrough"]
|
||||
if elementwise_function not in valid_functions:
|
||||
raise ValueError(
|
||||
f"Invalid elementwise function: {elementwise_function}. "
|
||||
f"Valid options are: {', '.join(valid_functions)}"
|
||||
)
|
||||
|
||||
# Set the function name based on the elementwise function
|
||||
if elementwise_function == "mul":
|
||||
function_name = "MultiDMultiply"
|
||||
elif elementwise_function == "add":
|
||||
function_name = "MultiDAdd"
|
||||
elif elementwise_function == "passthrough":
|
||||
function_name = "PassThrough" # TODO Change this
|
||||
|
||||
return cls(datatypes=datatypes, layouts=layouts, function_name=function_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Configuration class for tile parameter."""
|
||||
|
||||
tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Configuration class for kernel traits."""
|
||||
|
||||
pipeline: EnumConfigParam
|
||||
scheduler: EnumConfigParam
|
||||
epilogue: EnumConfigParam
|
||||
pad_m: EnumConfigParam
|
||||
pad_n: EnumConfigParam
|
||||
pad_k: EnumConfigParam
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonConfig:
|
||||
"""Configuration class for JSON parameter."""
|
||||
|
||||
tile_config: TileConfig
|
||||
trait_config: TraitConfig
|
||||
|
||||
@classmethod
|
||||
def from_json(cls: Type["JsonConfig"], filepath: str) -> "JsonConfig":
|
||||
"""JSON configuration loader with validation controls"""
|
||||
config_path = Path(filepath)
|
||||
|
||||
try:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file {filepath} not found")
|
||||
|
||||
with config_path.open("r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Parse tile config
|
||||
def create_param(param_dict):
|
||||
if "values" in param_dict:
|
||||
return EnumConfigParam(values=param_dict["values"])
|
||||
else:
|
||||
return RangeConfigParam(
|
||||
min=param_dict["min"],
|
||||
max=param_dict["max"],
|
||||
step=param_dict["step"],
|
||||
exclude=param_dict.get("exclude", []),
|
||||
)
|
||||
|
||||
tile_config = TileConfig(
|
||||
tile_m=create_param(config_dict["tile_config"]["tile_m"]),
|
||||
tile_n=create_param(config_dict["tile_config"]["tile_n"]),
|
||||
tile_k=create_param(config_dict["tile_config"]["tile_k"]),
|
||||
warp_m=create_param(config_dict["tile_config"]["warp_m"]),
|
||||
warp_n=create_param(config_dict["tile_config"]["warp_n"]),
|
||||
warp_k=create_param(config_dict["tile_config"]["warp_k"]),
|
||||
warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
|
||||
warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
|
||||
warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
|
||||
)
|
||||
|
||||
# Parse trait config
|
||||
trait_config = TraitConfig(
|
||||
pipeline=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pipeline"]["values"]
|
||||
),
|
||||
scheduler=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["scheduler"]["values"]
|
||||
),
|
||||
epilogue=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["epilogue"]["values"]
|
||||
),
|
||||
pad_m=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_m"]["values"]
|
||||
),
|
||||
pad_n=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_n"]["values"]
|
||||
),
|
||||
pad_k=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_k"]["values"]
|
||||
),
|
||||
)
|
||||
|
||||
return cls(tile_config=tile_config, trait_config=trait_config)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {str(e)}")
|
||||
except KeyError as e:
|
||||
raise KeyError(f"Missing required configuration field: {str(e)}")
|
||||
164
tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp
Normal file
164
tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp
Normal file
@@ -0,0 +1,164 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_d_dispatcher.hpp"
|
||||
#include "gemm_multi_d_common.hpp"
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
|
||||
.insert("n", "4096", "The value for n dimension. Default is 4096.")
|
||||
.insert("k", "2048", "The value for k dimension. Default is 2048.")
|
||||
.insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
|
||||
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
|
||||
.insert("stride_ds", "0", "The stride value for tensor Ds Default is 0.")
|
||||
.insert("stride_e", "0", "The stride value for tensor E Default is 0.")
|
||||
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
|
||||
.insert("verify",
|
||||
"1",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
|
||||
"for validation on GPU. Default is 1, validation on CPU, as validation on GPU is "
|
||||
"not supported.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Wether output kernel instance information or not. Possible values are true or "
|
||||
"false. Default is false")
|
||||
.insert("warmup",
|
||||
"50",
|
||||
"The number of iterations before benchmarking the kernel. Default is 50.")
|
||||
.insert("repeat",
|
||||
"100",
|
||||
"The number of iterations for benchmarking the kernel. Default is 100.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Indicates whether the timer is a GPU timer. Possible values are true or false. "
|
||||
"Default is true.")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1). Default is 0, random.")
|
||||
.insert("flush_cache",
|
||||
"false",
|
||||
"To flush cache, possible values are true or false. "
|
||||
"Default is false.")
|
||||
.insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.")
|
||||
.insert("metric",
|
||||
"0",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
"tflops, or 2 for bandwidth. Default is 0, latency.")
|
||||
.insert("csv_filename",
|
||||
"gemm_multi_d_kernel",
|
||||
"The filename of benchmark result. Default is set to gemm_multi_d_kernel.")
|
||||
.insert(
|
||||
"pipeline",
|
||||
"compv3",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.")
|
||||
.insert("scheduler",
|
||||
"intrawave",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is "
|
||||
"compv3.")
|
||||
.insert(
|
||||
"epilogue",
|
||||
"cshuffle",
|
||||
"The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.")
|
||||
.insert("pad_m",
|
||||
"false",
|
||||
"Whether pad or not in m direction. Possible values are true or false. Default is "
|
||||
"false.")
|
||||
.insert("pad_n",
|
||||
"false",
|
||||
"Whether pad or not in n direction. Possible values are true or false. Default is "
|
||||
"false.")
|
||||
.insert("pad_k",
|
||||
"false",
|
||||
"Whether pad or not in k direction. Possible values are true or false. Default is "
|
||||
"false.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
KernelTraits trait;
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.pad_m = arg_parser.get_bool("pad_m");
|
||||
trait.pad_n = arg_parser.get_bool("pad_n");
|
||||
trait.pad_k = arg_parser.get_bool("pad_k");
|
||||
|
||||
return GemmMultiDDispatcher::dispatch(trait);
|
||||
}
|
||||
755
tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py
Executable file
755
tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py
Executable file
@@ -0,0 +1,755 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
generate kernel instances to speed up compilation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from gemm_multi_d_config import JsonConfig, ArgumentConfig, RangeConfigParam
|
||||
from gemm_multi_d_codegen_utils import (
|
||||
DATA_TYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
PIPELINE_MAP,
|
||||
SCHEDULER_MAP,
|
||||
EPILOGUE_MAP,
|
||||
BOOL_MAP,
|
||||
warp_tile_supported_combinations,
|
||||
trait_unsupported_combinations,
|
||||
element_size,
|
||||
get_gpu_name_by_id,
|
||||
)
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class GemmMultiDCodeGenerator:
|
||||
"""GEMM (General Matrix Multiplication) Multi D code generator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: argparse.Namespace,
|
||||
user_provided_config: Optional[JsonConfig] = None,
|
||||
):
|
||||
self.output_dir = Path(args.working_path)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if user_provided_config is not None:
|
||||
self.config = user_provided_config
|
||||
else:
|
||||
config_path = (
|
||||
Path(__file__).resolve().parent / "configs" / "default_config.json"
|
||||
)
|
||||
self.config = JsonConfig.from_json(config_path)
|
||||
|
||||
self.args = ArgumentConfig.from_args(
|
||||
args.datatype, args.layout, args.elementwise_function
|
||||
)
|
||||
|
||||
self.valid_trait_names: List[str] = []
|
||||
self.valid_trait_tile_combinations: map[str, list[tuple[int]]] = {}
|
||||
|
||||
def list_all_trait_names(self):
|
||||
"""List all possible kernel trait names into file."""
|
||||
w_p = Path(self.output_dir)
|
||||
file_path = w_p / "gemm_multi_d_instance_blobs.txt"
|
||||
self._generate_all_traits()
|
||||
self._get_valid_trait_tile_combinations()
|
||||
file_range_map = {}
|
||||
# Write all file paths to the header file
|
||||
files_listed = 0
|
||||
with file_path.open("w") as f:
|
||||
# Core files
|
||||
core_files = [
|
||||
"gemm_multi_d_common.hpp",
|
||||
"gemm_multi_d_instances.hpp",
|
||||
"gemm_multi_d_dispatcher.hpp",
|
||||
]
|
||||
for core_file in core_files:
|
||||
f.write(str(w_p / core_file) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
# Trait header files
|
||||
for trait in self.valid_trait_names:
|
||||
trait_file = f"gemm_multi_d_{trait}.hpp"
|
||||
f.write(str(w_p / trait_file) + "\n")
|
||||
files_listed += 1
|
||||
file_name = set()
|
||||
# Instance source files
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
start_idx = files_listed
|
||||
for tile in tile_valid_params:
|
||||
for (
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) in tile:
|
||||
instance_name = f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
|
||||
|
||||
if instance_name not in file_name:
|
||||
file_name.add(instance_name)
|
||||
f.write(str(w_p / instance_name) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
file_range_map[trait] = (start_idx, files_listed)
|
||||
|
||||
file_path = w_p / "gemm_multi_d_instance_blobs_range.txt"
|
||||
with file_path.open("w") as f:
|
||||
for name, ranges in file_range_map.items():
|
||||
start, last = ranges
|
||||
f.write(name + " " + f"{start}" + " " + f"{last}" + "\n")
|
||||
|
||||
def _generate_all_traits(self):
|
||||
"""Generate all possible kernel traits names."""
|
||||
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
|
||||
|
||||
# Generate all unique_combinations
|
||||
_unique = set(
|
||||
itertools.product(
|
||||
*[getattr(self.config.trait_config, param).values for param in params]
|
||||
)
|
||||
)
|
||||
|
||||
for combo in _unique:
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
|
||||
current_combination = (pipeline, epilogue, scheduler)
|
||||
|
||||
if current_combination not in trait_unsupported_combinations:
|
||||
trait_name = (
|
||||
f"{pipeline}_{epilogue}_{scheduler}_"
|
||||
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
|
||||
)
|
||||
self.valid_trait_names.append(trait_name)
|
||||
else:
|
||||
logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}")
|
||||
|
||||
def _get_valid_trait_tile_combinations(self):
|
||||
def get_tile_value(tile_param):
|
||||
return (
|
||||
tile_param.generate_candidates()
|
||||
if isinstance(tile_param, RangeConfigParam)
|
||||
else tile_param.values
|
||||
)
|
||||
|
||||
tile_group = list(
|
||||
itertools.product(
|
||||
get_tile_value(self.config.tile_config.tile_m),
|
||||
get_tile_value(self.config.tile_config.tile_n),
|
||||
get_tile_value(self.config.tile_config.tile_k),
|
||||
)
|
||||
)
|
||||
|
||||
warp_group = list(
|
||||
itertools.product(
|
||||
get_tile_value(self.config.tile_config.warp_m),
|
||||
get_tile_value(self.config.tile_config.warp_n),
|
||||
get_tile_value(self.config.tile_config.warp_k),
|
||||
)
|
||||
)
|
||||
|
||||
warp_tile_group = list(
|
||||
itertools.product(
|
||||
get_tile_value(self.config.tile_config.warp_tile_m),
|
||||
get_tile_value(self.config.tile_config.warp_tile_n),
|
||||
get_tile_value(self.config.tile_config.warp_tile_k),
|
||||
)
|
||||
)
|
||||
|
||||
tile_params = {
|
||||
t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group
|
||||
}
|
||||
|
||||
for trait in self.valid_trait_names:
|
||||
tile_valid_params = [
|
||||
tile for tile in tile_params if self.is_tile_valid(tile, trait)
|
||||
]
|
||||
|
||||
if trait not in self.valid_trait_tile_combinations:
|
||||
self.valid_trait_tile_combinations[trait] = []
|
||||
self.valid_trait_tile_combinations[trait].append(tile_valid_params)
|
||||
|
||||
def is_tile_valid(self, tile: tuple, trait: str) -> bool:
|
||||
"""Check if the tile configuration is valid for the given trait."""
|
||||
(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) = tile
|
||||
pipeline, *_ = trait.split("_")
|
||||
|
||||
# Parameter validity check
|
||||
invalid_params = []
|
||||
if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]:
|
||||
invalid_params.append(
|
||||
f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})"
|
||||
)
|
||||
if (warp_m * warp_tile_m) == 0:
|
||||
invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
|
||||
if (warp_n * warp_tile_n) == 0:
|
||||
invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
|
||||
if (warp_k * warp_tile_k) == 0:
|
||||
invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
|
||||
|
||||
if invalid_params:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], Invalid warp configuration: {', '.join(invalid_params)}. "
|
||||
f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), "
|
||||
f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})"
|
||||
)
|
||||
return False
|
||||
# Dimension alignment check
|
||||
alignment_issues = []
|
||||
if tile_m % (warp_m * warp_tile_m) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}"
|
||||
)
|
||||
if tile_n % (warp_n * warp_tile_n) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}"
|
||||
)
|
||||
if tile_k % (warp_k * warp_tile_k) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}"
|
||||
)
|
||||
|
||||
if alignment_issues:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. "
|
||||
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
|
||||
f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
)
|
||||
return False
|
||||
|
||||
# LDS capacity verification
|
||||
matrix_a_size = (tile_m * tile_k) * element_size(self.args.datatypes.a_datatype)
|
||||
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(self.args.datatypes.b_datatype)
|
||||
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
|
||||
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
logging.debug(
|
||||
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
|
||||
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
|
||||
f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
|
||||
f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
|
||||
)
|
||||
return False
|
||||
|
||||
# Warp combination validation
|
||||
warp_tile_key = f"{self.args.datatypes.a_datatype}_{self.args.datatypes.b_datatype}_{self.args.datatypes.e_datatype}"
|
||||
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
|
||||
if not gpu_warp_tile_key:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
|
||||
)
|
||||
return False
|
||||
|
||||
allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, [])
|
||||
if not allowed_combinations:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
|
||||
)
|
||||
return False
|
||||
|
||||
if current_combination not in allowed_combinations:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. "
|
||||
f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def generate_all_instance_files(self):
|
||||
"""Generate all kernel instances files."""
|
||||
self._generate_common_header_file()
|
||||
self._generate_all_trait_files()
|
||||
self._generate_dispatcher_file()
|
||||
|
||||
def _generate_common_header_file(self):
|
||||
"""Generate common header file with datatypes and layout."""
|
||||
|
||||
acc_type = "float" # As we are currently supporting only fp16
|
||||
|
||||
content = f"""
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
// Data types
|
||||
using ADataType = {DATA_TYPE_MAP[self.args.datatypes.a_datatype]};
|
||||
using BDataType = {DATA_TYPE_MAP[self.args.datatypes.b_datatype]};
|
||||
using AccDataType = {acc_type};
|
||||
using D0DataType = {DATA_TYPE_MAP[self.args.datatypes.d0_datatype]};
|
||||
using D1DataType = {DATA_TYPE_MAP[self.args.datatypes.d1_datatype]};
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
using EDataType = {DATA_TYPE_MAP[self.args.datatypes.e_datatype]};
|
||||
|
||||
|
||||
// Layout configurations
|
||||
using ALayout = {LAYOUT_MAP[self.args.layouts.a_layout]};
|
||||
using BLayout = {LAYOUT_MAP[self.args.layouts.b_layout]};
|
||||
using D0Layout = {LAYOUT_MAP[self.args.layouts.d0_layout]};
|
||||
using D1Layout = {LAYOUT_MAP[self.args.layouts.d1_layout]};
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
using ELayout = {LAYOUT_MAP[self.args.layouts.e_layout]};
|
||||
|
||||
// Element-wise function for D
|
||||
using ElementWiseFn = ck_tile::element_wise::{self.args.function_name};
|
||||
|
||||
"""
|
||||
|
||||
(self.output_dir / "gemm_multi_d_common.hpp").write_text(content)
|
||||
|
||||
def _generate_all_trait_files(self):
|
||||
"""Generate all kernel traits into files."""
|
||||
if not self.valid_trait_names:
|
||||
self._generate_all_traits()
|
||||
self._get_valid_trait_tile_combinations()
|
||||
for trait in self.valid_trait_names:
|
||||
self._generate_trait_file(trait)
|
||||
self._generate_instantiation_source_files()
|
||||
self._generate_common_instance_header_file()
|
||||
|
||||
def _generate_trait_file(self, trait: str):
|
||||
"""Generate a trait with all tile/warp combinations."""
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_")
|
||||
filename = f"gemm_multi_d_{trait}.hpp"
|
||||
|
||||
content = f"""
|
||||
#pragma once
|
||||
|
||||
#include "gemm_multi_d_common.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
namespace {trait} {{
|
||||
"""
|
||||
# Add template struct with configuration
|
||||
content += self._generate_kernel_struct(
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
|
||||
)
|
||||
|
||||
content += f"\n}} // namespace {trait}\n"
|
||||
(self.output_dir / filename).write_text(content)
|
||||
|
||||
def _generate_kernel_struct(
|
||||
self,
|
||||
pipeline: str,
|
||||
epilogue: str,
|
||||
scheduler: str,
|
||||
pad_m: str,
|
||||
pad_n: str,
|
||||
pad_k: str,
|
||||
) -> str:
|
||||
"""Generate the code block of kernel struct"""
|
||||
return f"""
|
||||
|
||||
template <int TileM, int TileN, int TileK,
|
||||
int WarpM, int WarpN, int WarpK,
|
||||
int WarpTileM, int WarpTileN, int WarpTileK,
|
||||
typename CDEElementWise = ElementWiseFn>
|
||||
struct GemmKernelMultiD {{
|
||||
static constexpr bool kPadM = {pad_m};
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
|
||||
static float launch(ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
|
||||
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<TileM, TileN, TileK>,
|
||||
ck_tile::sequence<WarpM, WarpN, WarpK>,
|
||||
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
TileParitionerGroupNum,
|
||||
TileParitionerM01>;
|
||||
|
||||
using Traits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, ELayout>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, ELayout, TransposeC>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * TileK;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{{0}};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = {SCHEDULER_MAP[scheduler]};
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = {PIPELINE_MAP[pipeline][1]}<UniversalGemmProblem>;
|
||||
{EPILOGUE_MAP[epilogue]}
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
|
||||
if(stream.log_level_ > 0)
|
||||
{{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
||||
<< std::endl;
|
||||
}}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(stream,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
|
||||
}};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
|
||||
if(args.k_batch == 1) {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}} else {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
}};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
|
||||
static std::string get_name() {{
|
||||
return std::string("gemm_multi_d_") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) +
|
||||
"_" + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + "_" +
|
||||
std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + "_" +
|
||||
"{pad_m}" + "_" +
|
||||
"{pad_n}" + "_" +
|
||||
"{pad_k}" + "_" +
|
||||
"{pipeline}" + "_" +
|
||||
"{epilogue}" + "_" +
|
||||
"{scheduler}";
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
def _generate_instantiation_source_files(self):
|
||||
"""Generate kernel instance instantiation source files"""
|
||||
tile_map = {}
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
for tile in tile_valid_params:
|
||||
for (
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) in tile:
|
||||
key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}"
|
||||
value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
if key not in tile_map:
|
||||
tile_map[key] = set()
|
||||
tile_map[key].add(value)
|
||||
|
||||
files_listed = 0
|
||||
for trait, _ in self.valid_trait_tile_combinations.items():
|
||||
for block_tile, warp_tiles in tile_map.items():
|
||||
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(
|
||||
int, block_tile.split("x")
|
||||
)
|
||||
|
||||
content = f"""
|
||||
#include "gemm_multi_d_{trait}.hpp"
|
||||
|
||||
"""
|
||||
for warp_tile in warp_tiles:
|
||||
warp_tile_m, warp_tile_n, warp_tile_k = map(
|
||||
int, warp_tile.split("x")
|
||||
)
|
||||
|
||||
files_listed = files_listed + 1
|
||||
content = (
|
||||
content
|
||||
+ f"""
|
||||
template struct {trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>;"""
|
||||
)
|
||||
content += """
|
||||
"""
|
||||
(
|
||||
self.output_dir
|
||||
/ f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
|
||||
).write_text(content)
|
||||
print(f"Generated {files_listed} kernel instances in total.")
|
||||
|
||||
def _generate_common_instance_header_file(self):
|
||||
"""Generate common instance header into file."""
|
||||
content = """
|
||||
#pragma once
|
||||
"""
|
||||
for trait in self.valid_trait_names:
|
||||
content += f'#include "gemm_multi_d_{trait}.hpp"\n'
|
||||
(self.output_dir / "gemm_multi_d_instances.hpp").write_text(content)
|
||||
|
||||
def _generate_dispatcher_file(self):
|
||||
"""Generate the code block of dispatch mechanism."""
|
||||
content = """
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "gemm_multi_d_common.hpp"
|
||||
#include "gemm_multi_d_instances.hpp"
|
||||
|
||||
/// @brief Defines the configuration parameters for a GEMM Multi D operation, enabling the selection of a
|
||||
/// specific kernel instance based on the provided settings.
|
||||
struct KernelTraits
|
||||
{
|
||||
/// @brief The name of the pipeline.
|
||||
std::string pipeline;
|
||||
/// @brief The name of the scheduler (e.g., "intrawave", "interwave").
|
||||
std::string scheduler;
|
||||
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
|
||||
std::string epilogue;
|
||||
/// @brief Indicates whether padding is applied to the M dimension.
|
||||
bool pad_m;
|
||||
/// @brief Indicates whether padding is applied to the N dimension.
|
||||
bool pad_n;
|
||||
/// @brief Indicates whether padding is applied to the K dimension.
|
||||
bool pad_k;
|
||||
};
|
||||
|
||||
struct GemmMultiDDispatcher {
|
||||
static auto& get_kernel_map() {
|
||||
// Use a static local variable
|
||||
static std::unordered_map<
|
||||
std::string,
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>>
|
||||
kernel_map;
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
static void init() {
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
content += f""" kernel_map["{trait}"] = {{"""
|
||||
for _, tile in enumerate(tile_valid_params):
|
||||
for j in range(len(tile)):
|
||||
(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) = tile[j]
|
||||
content += """[=](ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) { """
|
||||
|
||||
content += f"""
|
||||
return run_kernel<{trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>>(args, stream);"""
|
||||
|
||||
if j == len(tile) - 1:
|
||||
content += """
|
||||
} """
|
||||
else:
|
||||
content += """
|
||||
}, """
|
||||
content += """
|
||||
};\n """
|
||||
|
||||
content += """ }
|
||||
|
||||
template <typename Kernel>
|
||||
static std::tuple<std::string, float> run_kernel(ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream)
|
||||
{
|
||||
std::string name = Kernel::get_name();
|
||||
float avg_time = Kernel::launch(args, stream);
|
||||
|
||||
return std::make_tuple(name, avg_time);
|
||||
}
|
||||
|
||||
|
||||
static auto dispatch(const KernelTraits& trait) {
|
||||
init();
|
||||
const std::string key = assemble_key(trait);
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(auto it = kernel_map.find(key); it != kernel_map.end())
|
||||
{
|
||||
return it->second;
|
||||
}
|
||||
throw std::runtime_error("No suitable kernel found: " + key);
|
||||
}
|
||||
|
||||
private:
|
||||
static std::string assemble_key(const KernelTraits &trait) {
|
||||
return std::string(trait.pipeline) + "_" +
|
||||
trait.epilogue + "_" +
|
||||
trait.scheduler + "_" +
|
||||
(trait.pad_m ? "true" : "false") + "_" +
|
||||
(trait.pad_n ? "true" : "false") + "_" +
|
||||
(trait.pad_k ? "true" : "false");
|
||||
}
|
||||
};
|
||||
|
||||
"""
|
||||
(self.output_dir / "gemm_multi_d_dispatcher.hpp").write_text(content)
|
||||
|
||||
|
||||
def do_list_blobs(
|
||||
args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None
|
||||
):
|
||||
generator = GemmMultiDCodeGenerator(args, user_provide_config)
|
||||
generator.list_all_trait_names()
|
||||
|
||||
|
||||
def do_gen_blobs(
|
||||
args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None
|
||||
):
|
||||
generator = GemmMultiDCodeGenerator(args, user_provide_config)
|
||||
generator.generate_all_instance_files()
|
||||
|
||||
|
||||
def main(args):
|
||||
gemm_multi_d_config = JsonConfig.from_json(args.config_json)
|
||||
|
||||
if args.list_blobs:
|
||||
do_list_blobs(args, gemm_multi_d_config)
|
||||
elif args.gen_blobs:
|
||||
do_gen_blobs(args, gemm_multi_d_config)
|
||||
else:
|
||||
logging.warning(
|
||||
"No mode specified (use --list_blobs or --gen_blobs). Generating by default..."
|
||||
)
|
||||
do_gen_blobs(args, gemm_multi_d_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen API for CK gemm multi D kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--working_path",
|
||||
default="./",
|
||||
required=False,
|
||||
help="The path where all the blobs are going to be generated",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--config_json",
|
||||
required=False,
|
||||
help="Path to the json which contains the configurations that user provide",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--datatype",
|
||||
required=True,
|
||||
help="Specify what datatype to use for the kernel generation, e.g. fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ly",
|
||||
"--layout",
|
||||
required=True,
|
||||
help="Specify what layout to use for the kernel generation, e.g. rcrr, rrrr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ef",
|
||||
"--elementwise_function",
|
||||
required=True,
|
||||
help="Specify what element wise function for D, e.g. mul, add, passthrough",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
action="store_true",
|
||||
help="List all kernel instances to file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--gen_blobs",
|
||||
action="store_true",
|
||||
help="Generate all kernel instances into different files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
278
tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp
Normal file
278
tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp
Normal file
@@ -0,0 +1,278 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "benchmark_gemm_multi_d.hpp"
|
||||
|
||||
class GemmMultiDProfiler
|
||||
{
|
||||
public:
|
||||
static GemmMultiDProfiler& instance(Setting setting)
|
||||
{
|
||||
static GemmMultiDProfiler instance{setting};
|
||||
return instance;
|
||||
}
|
||||
|
||||
void benchmark(
|
||||
GemmMultiDProblem& gemm_multi_d_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>&
|
||||
callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const D0Layout layout_d0 = D0Layout{};
|
||||
const D1Layout layout_d1 = D1Layout{};
|
||||
const ELayout layout_e = ELayout{};
|
||||
|
||||
gemm_multi_d_problem.stride_a_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.stride_a_,
|
||||
is_row_major(layout_a));
|
||||
gemm_multi_d_problem.stride_b_ = ck_tile::get_default_stride(gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_b_,
|
||||
is_row_major(layout_b));
|
||||
gemm_multi_d_problem.stride_d0_ =
|
||||
ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d0_,
|
||||
is_row_major(layout_d0));
|
||||
gemm_multi_d_problem.stride_d1_ =
|
||||
ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d1_,
|
||||
is_row_major(layout_d1));
|
||||
gemm_multi_d_problem.stride_e_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
is_row_major(layout_e));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.stride_a_,
|
||||
is_row_major(layout_a)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_b_,
|
||||
is_row_major(layout_b)));
|
||||
ck_tile::HostTensor<D0DataType> d0_m_n(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d0_,
|
||||
is_row_major(layout_d0)));
|
||||
ck_tile::HostTensor<D1DataType> d1_m_n(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d1_,
|
||||
is_row_major(layout_d1)));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
is_row_major(layout_e)));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(d1_m_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.mData.data());
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n.mData.data());
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {
|
||||
gemm_multi_d_problem.stride_d0_, gemm_multi_d_problem.stride_d1_};
|
||||
|
||||
ck_tile::GemmMultiDHostArgs<DsDataType::size()> gemm_multi_d_args = {
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
gemm_multi_d_problem.split_k_,
|
||||
gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.stride_a_,
|
||||
gemm_multi_d_problem.stride_b_,
|
||||
stridesDs,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
};
|
||||
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_result(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
is_row_major(layout_e)));
|
||||
|
||||
if(setting_.verify_)
|
||||
{
|
||||
gemm_multi_d_host_reference(
|
||||
setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, e_m_n_host_result);
|
||||
}
|
||||
|
||||
for(auto& callable : callables)
|
||||
{
|
||||
auto kernel_run_result =
|
||||
callable(gemm_multi_d_args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_});
|
||||
|
||||
auto [kernel_name, execution_time] = kernel_run_result;
|
||||
|
||||
process_result(gemm_multi_d_problem,
|
||||
e_m_n_dev_buf,
|
||||
e_m_n_host_result,
|
||||
e_m_n_device_result,
|
||||
kernel_run_result);
|
||||
}
|
||||
}
|
||||
|
||||
void process_result(const GemmMultiDProblem& gemm_multi_d_problem,
|
||||
ck_tile::DeviceMem& e_m_n_dev_buf,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_host_result,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
|
||||
KernelInstance kernel_instance{name, gemm_multi_d_problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
flop += std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ *
|
||||
gemm_multi_d_problem.k_;
|
||||
ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
num_byte += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
|
||||
gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
|
||||
gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
});
|
||||
num_byte += sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ +
|
||||
sizeof(BDataType) * gemm_multi_d_problem.k_ * gemm_multi_d_problem.n_ +
|
||||
sizeof(EDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
|
||||
kernel_instance.perf_result_.latency_ = avg_time;
|
||||
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
|
||||
|
||||
if(setting_.log_ > 0)
|
||||
{
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_dev_result.data());
|
||||
bool verified_correct =
|
||||
!setting_.verify_ ||
|
||||
compare(name, gemm_multi_d_problem.k_, e_m_n_dev_result, e_m_n_host_result);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
||||
}
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
KernelInstance select_best_instance(Metric metric)
|
||||
{
|
||||
if(kernel_instances_.empty())
|
||||
throw std::runtime_error("Empty instances");
|
||||
|
||||
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
|
||||
kernel_instances_.end(),
|
||||
[metric](const auto& a, const auto& b) {
|
||||
return PerformanceResult::compare(
|
||||
b.perf_result_, a.perf_result_, metric);
|
||||
});
|
||||
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "The best kernel instance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
|
||||
if(!setting_.csv_filename_.empty())
|
||||
{
|
||||
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
|
||||
|
||||
if(!file.is_open())
|
||||
{
|
||||
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(file.tellp() == 0)
|
||||
{
|
||||
file << "rocm_version,device_name,"
|
||||
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
|
||||
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
|
||||
<< "structured_sparsity," << "name,"
|
||||
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
|
||||
}
|
||||
|
||||
const auto& problem = kernel_instance.problem_;
|
||||
const auto& name = kernel_instance.name_;
|
||||
const auto& perf = kernel_instance.perf_result_;
|
||||
|
||||
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
|
||||
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
|
||||
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
|
||||
<< problem.stride_d0_ << "," << problem.stride_d1_ << "," << problem.stride_e_
|
||||
<< "," << problem.dtype_a_ << "," << problem.dtype_b_ << ","
|
||||
<< problem.dtype_d0_ << "," << problem.dtype_d1_ << "," << problem.dtype_acc_
|
||||
<< "," << problem.dtype_e_ << "," << problem.layout_a_ << ","
|
||||
<< problem.layout_b_ << "," << problem.layout_d0_ << "," << problem.layout_d1_
|
||||
<< "," << problem.layout_e_ << "," << "," << name << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
|
||||
<< "\n";
|
||||
|
||||
if(!file)
|
||||
{
|
||||
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return kernel_instance;
|
||||
}
|
||||
|
||||
GemmMultiDProfiler(const GemmMultiDProfiler&) = delete;
|
||||
GemmMultiDProfiler& operator=(const GemmMultiDProfiler&) = delete;
|
||||
|
||||
private:
|
||||
~GemmMultiDProfiler() { kernel_instances_.clear(); }
|
||||
GemmMultiDProfiler(Setting setting) : setting_(setting) {}
|
||||
|
||||
Setting setting_;
|
||||
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
Reference in New Issue
Block a user