mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Reorganize project folders (#6)
This commit is contained in:
5
tile_engine/CMakeLists.txt
Executable file
5
tile_engine/CMakeLists.txt
Executable file
@@ -0,0 +1,5 @@
|
||||
# Prepend the tile_engine header directory to the include path of this
# directory and of every subdirectory added below.
include_directories(BEFORE "${CMAKE_CURRENT_LIST_DIR}/include")

# Operator implementations live under ops/.
add_subdirectory(ops)
|
||||
1
tile_engine/include/CMakeLists.txt
Executable file
1
tile_engine/include/CMakeLists.txt
Executable file
@@ -0,0 +1 @@
|
||||
# This listfile currently only prints a notice; it defines no targets.
message("Add include directory")
|
||||
1
tile_engine/ops/CMakeLists.txt
Executable file
1
tile_engine/ops/CMakeLists.txt
Executable file
@@ -0,0 +1 @@
|
||||
# GEMM is the only operator directory added from here.
add_subdirectory(gemm)
|
||||
51
tile_engine/ops/gemm/CMakeLists.txt
Normal file
51
tile_engine/ops/gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,51 @@
|
||||
|
||||
|
||||
# Inputs of the code generator; shared by the configure-time listing step and
# the build-time generation rule below.
set(gemm_builder_script "${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")
set(gemm_instance_config "${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json")
set(gemm_blob_list "${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt")

# Generate a list of kernels, but do not actually emit files at config stage:
# with --list_blobs the generator writes gemm_instance_blobs.txt into the
# build directory.
execute_process(
  COMMAND "${Python3_EXECUTABLE}" "${gemm_builder_script}"
          --working_path "${CMAKE_CURRENT_BINARY_DIR}"
          --json "${gemm_instance_config}"
          --list_blobs
  RESULT_VARIABLE ret
)

# RESULT_VARIABLE holds either the exit code or an error string when the
# process could not be started; anything other than 0 is a failure.
if(NOT ret EQUAL 0)
  message(FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
endif()

# Re-run the configure step (and therefore the listing above) whenever the
# generator script or its JSON configuration changes.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
  "${gemm_builder_script}"
  "${gemm_instance_config}"
)

# One generated file path per line.
file(STRINGS "${gemm_blob_list}" GEMM_CODEGEN_BLOBS)

# Build-time rule that actually emits the files named in the list.
add_custom_command(
  OUTPUT ${GEMM_CODEGEN_BLOBS}
  COMMAND "${Python3_EXECUTABLE}" "${gemm_builder_script}"
          --working_path "${CMAKE_CURRENT_BINARY_DIR}"
          --json "${gemm_instance_config}"
          --gen_blobs
  DEPENDS "${gemm_builder_script}"
          "${gemm_blob_list}"
          "${gemm_instance_config}"
  VERBATIM
)

set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
message("adding example ${EXECUTABLE_GEMM_INSTANCE}")

add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp)
target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS})

# The build directory is an include directory because the generated headers
# are written there; scoped to this target instead of the whole directory.
target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE
  "${CMAKE_CURRENT_BINARY_DIR}"
  "${CMAKE_CURRENT_LIST_DIR}"
)

target_compile_options(${EXECUTABLE_GEMM_INSTANCE} PRIVATE
  -Wno-undefined-func-template
  -Wno-float-equal
  --offload-compress
)

# NOTE: this is a GLOBAL property, so it silences rule progress messages for
# the entire build tree, not only for this target.
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
92
tile_engine/ops/gemm/README.md
Normal file
92
tile_engine/ops/gemm/README.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# GEMM Matrix Multiplication
|
||||
|
||||
CK Tile Engine GEMM is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues.
|
||||
|
||||
# Kernel Configurations
|
||||
|
||||
Kernel parameters are specified in the `instance_combination.json` file, including matrix layouts, data types, padding settings, pipelines, schedulers, epilogues, and numerical values for tile and warp sizes.
|
||||
|
||||
Given a valid set of values, tile_engine_gemm will automatically iterate over all possible combinations of BlockTile and WarpTile sizes, as well as the specified pipelines, schedulers, and epilogues from `./configs/instance_combination.json`, and build the corresponding kernels.
|
||||
|
||||
|
||||
## Build Instructions
|
||||
``` bash
|
||||
# in the root of composable kernel create build directory
|
||||
mkdir build && cd build
|
||||
# build composable kernel
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with the appropriate architecture (example gfx942) or leave blank
|
||||
# generate the executable
|
||||
make tile_engine_gemm -j
|
||||
```
|
||||
`tile_engine_gemm` will be located in the `./bin/` directory.
|
||||
|
||||
_`tile_engine_gemm` must be rebuilt every time `instance_combination.json` is modified._
|
||||
``` bash
|
||||
rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild
|
||||
```
|
||||
|
||||
## tile_engine_gemm inputs
|
||||
```
|
||||
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:2048)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-split_k SplitK value (default:1)
|
||||
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
|
||||
-warmup Number of iterations before benchmark the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0)
|
||||
-structured_sparsity Sparsity for tensor - 0:false, 1:true (default: 0)
|
||||
-pipeline possible values are: compv3, compv4, mem (default:compv3)
|
||||
-scheduler possible values are: intrawave, interwave (default:intrawave)
|
||||
-epilogue possible values are: cshuffle, default (default:cshuffle)
|
||||
-pad_m Pad in m direction - true/false (default:false)
|
||||
-pad_n Pad in n direction - true/false (default:false)
|
||||
-pad_k Pad in k direction - true/false (default:false)
|
||||
|
||||
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json
|
||||
```
|
||||
Note: In `./configs/instance_combination.json`, the values given for pipeline, scheduler, epilogue, pad_m, pad_n and pad_k (the `kPadM`, `kPadN` and `kPadK` keys in the JSON) must each be one of the options listed above.
|
||||
|
||||
## Example
|
||||
|
||||
The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes.
|
||||
|
||||
```json
|
||||
{
|
||||
/// other parameters ///
|
||||
|
||||
"tile_m": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [64, 32]
|
||||
},
|
||||
|
||||
/// other parameters ///
|
||||
|
||||
"pipeline": {
|
||||
"values": ["compv3", "compv4", "mem"]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": ["intrawave", "interwave"]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": ["default", "cshuffle"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
|
||||
``` bash
|
||||
./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default
|
||||
```
|
||||
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.
|
||||
|
||||
60
tile_engine/ops/gemm/configs/instance_combination.json
Normal file
60
tile_engine/ops/gemm/configs/instance_combination.json
Normal file
@@ -0,0 +1,60 @@
|
||||
{
|
||||
|
||||
"layout_a": {
|
||||
"values": ["r"]
|
||||
},
|
||||
"layout_b": {
|
||||
"values": ["c"]
|
||||
},
|
||||
"layout_c": {
|
||||
"values": ["r"]
|
||||
},
|
||||
"datatype": {
|
||||
"values": ["fp16"]
|
||||
},
|
||||
"tile_m": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [64, 32]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [2]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [2]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [1]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [32]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [32]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [16]
|
||||
},
|
||||
"kPadM": {
|
||||
"values": [false]
|
||||
},
|
||||
"kPadN": {
|
||||
"values": [false]
|
||||
},
|
||||
"kPadK": {
|
||||
"values": [false]
|
||||
},
|
||||
"pipeline": {
|
||||
"values": ["compv3", "compv4", "mem"]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": ["intrawave", "interwave"]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": ["default", "cshuffle"]
|
||||
}
|
||||
}
|
||||
192
tile_engine/ops/gemm/gemm_host_api.cpp
Executable file
192
tile_engine/ops/gemm/gemm_host_api.cpp
Executable file
@@ -0,0 +1,192 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
#include "gemm_dispatcher.hpp"
|
||||
#include "gemm_host_api.hpp"
|
||||
|
||||
/// @brief Forward a GEMM launch request, with all arguments unchanged, to
///        GemmDispatcher::dispatch.
void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
                        ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                        ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                        int verify,
                        bool structured_sparsity,
                        KernelTraits& trait,
                        ck_tile::GemmHostArgs& args,
                        const ck_tile::stream_config& stream)
{
    GemmDispatcher::dispatch(c_m_n_dev_buf,
                             c_m_n_host_result,
                             c_m_n_dev_result,
                             verify,
                             structured_sparsity,
                             trait,
                             args,
                             stream);
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
const ALayout a_layout = ALayout{};
|
||||
const BLayout b_layout = BLayout{};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
int verify = arg_parser.get_int("v");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
if(structured_sparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
// permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs gemm_args;
|
||||
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.k_batch = kbatch;
|
||||
gemm_args.M = M;
|
||||
gemm_args.N = N;
|
||||
gemm_args.K = K;
|
||||
gemm_args.stride_A = stride_A;
|
||||
gemm_args.stride_B = stride_B;
|
||||
gemm_args.stride_C = stride_C;
|
||||
|
||||
KernelTraits trait;
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.kPadM = arg_parser.get_bool("pad_m");
|
||||
trait.kPadN = arg_parser.get_bool("pad_n");
|
||||
trait.kPadK = arg_parser.get_bool("pad_k");
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name << std::endl;
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
if(verify)
|
||||
{
|
||||
gemm_host_reference<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(verify,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_host_result,
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
}
|
||||
|
||||
gemm_kernel_launch(c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
structured_sparsity,
|
||||
trait,
|
||||
gemm_args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
263
tile_engine/ops/gemm/gemm_host_api.hpp
Executable file
263
tile_engine/ops/gemm/gemm_host_api.hpp
Executable file
@@ -0,0 +1,263 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a
/// specific kernel instance based on the provided settings.
///
/// The padding flags default to false so that a default-constructed object
/// never carries indeterminate values.
struct KernelTraits
{
    /// @brief The name of the pipeline.
    std::string pipeline;
    /// @brief The name of the scheduler (e.g., "intrawave", "interwave").
    std::string scheduler;
    /// @brief The name of the epilogue (e.g., "cshuffle", "default").
    std::string epilogue;
    /// @brief Indicates whether padding is applied to the M dimension.
    bool kPadM = false;
    /// @brief Indicates whether padding is applied to the N dimension.
    bool kPadN = false;
    /// @brief Indicates whether padding is applied to the K dimension.
    bool kPadK = false;
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("structured_sparsity", "0", "0:false, 1:true")
|
||||
.insert("pipeline", "compv3", "compv3, compv4, mem")
|
||||
.insert("scheduler", "intrawave", "intrawave, interwave")
|
||||
.insert("epilogue", "cshuffle", "cshuffle, default")
|
||||
.insert("pad_m", "false", "true, false")
|
||||
.insert("pad_n", "false", "true, false")
|
||||
.insert("pad_k", "false", "true, false");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename Tensor>
|
||||
void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
{
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int8_t input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int8_t i4x2 = tensor(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int8_t hi = input[2];
|
||||
int8_t lo = input[0];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[6];
|
||||
int8_t lo = input[4];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[3];
|
||||
int8_t lo = input[1];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[7];
|
||||
int8_t lo = input[5];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations
|
||||
void compare(ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void gemm_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C)
|
||||
{
|
||||
if(verify == 1)
|
||||
{
|
||||
c_m_n_host_result.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_result);
|
||||
}
|
||||
else if(verify == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
|
||||
c_m_n_host_result.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
|
||||
}
|
||||
}
|
||||
644
tile_engine/ops/gemm/gemm_instance_builder.py
Executable file
644
tile_engine/ops/gemm/gemm_instance_builder.py
Executable file
@@ -0,0 +1,644 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import List, Optional, Dict, Any
|
||||
import functools
|
||||
import itertools
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Short data-type name (as used in the JSON config) -> C++ type spelling.
DATA_TYPE_MAP = dict(
    fp32='float',
    fp16='ck_tile::half_t',
    bf16='ck_tile::bf16_t',
    int8='ck_tile::int8_t',
    fp8='ck_tile::fp8_t',
    bf8='ck_tile::bf8_t',
    int4='ck_tile::pk_int4_t',
)

# One-letter layout code -> C++ layout tag.
LAYOUT_MAP = {
    code: f'ck_tile::tensor_layout::gemm::{tag}'
    for code, tag in (('r', 'RowMajor'), ('c', 'ColumnMajor'))
}
|
||||
|
||||
DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
"""
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
WarpM,
|
||||
WarpN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
"""
|
||||
# C++ snippet that dispatches on tail_num and calls Run with
# bool_constant<false>; any tail number other than Full, Odd or Even throws.
# (Presumably the "no hot loop" branch of the generated launcher - the code
# that splices this snippet in is not shown here.)
HOT_LOOP_FALSE = """
            if(tail_num == ck_tile::TailNumber::Full)
            {
                Run(ck_tile::bool_constant<false>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }
            else if(tail_num == ck_tile::TailNumber::Odd)
            {
                Run(ck_tile::bool_constant<false>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
            }
            else if(tail_num == ck_tile::TailNumber::Even)
            {
                Run(ck_tile::bool_constant<false>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
            }
            else
            {
                throw std::runtime_error("Num K loop must be larger than number of prefetch stages.");
            }
"""
|
||||
# C++ snippet for the 'mem' pipeline: dispatches on tail_num and calls Run
# with bool_constant<true>.
#
# Inside the PrefetchStages > 2 branch the tail numbers Two..Seven form a
# single else-if chain, and the error is raised only when tail_num matched
# none of the handled values (One and Full are handled just above). The throw
# used to be unconditional, so it also fired after a successful Run.
RUN_MEM = """
            if(tail_num == ck_tile::TailNumber::One)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
            }
            else if(tail_num == ck_tile::TailNumber::Full)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }

            if constexpr(BaseGemmPipeline::PrefetchStages > 2)
            {
                if(tail_num == ck_tile::TailNumber::Two)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
                }
                else if(tail_num == ck_tile::TailNumber::Three)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
                }
                else if(tail_num == ck_tile::TailNumber::Four)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
                }
                else if(tail_num == ck_tile::TailNumber::Five)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
                }
                else if(tail_num == ck_tile::TailNumber::Six)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
                }
                else if(tail_num == ck_tile::TailNumber::Seven)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
                }
                else if(tail_num != ck_tile::TailNumber::One && tail_num != ck_tile::TailNumber::Full)
                {
                    throw std::runtime_error("The tail number is wrong! It should not exceed the prefetch stage numbers");
                }
            }
"""
|
||||
|
||||
RUN_COMPV3 = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
|
||||
}
|
||||
"""
|
||||
|
||||
RUN_COMPV4 = """
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
# Pipeline key -> [base pipeline class, pipeline class] (C++ type names).
PIPELINE_MAP = {
    key: [f'ck_tile::BaseGemmPipelineAgBgCr{suffix}', f'ck_tile::GemmPipelineAgBgCr{suffix}']
    for key, suffix in (('mem', 'Mem'), ('compv3', 'CompV3'), ('compv4', 'CompV4'))
}

# Scheduler key -> C++ enumerator.
SCHEDULER_MAP = {
    enumerator.lower(): f'ck_tile::GemmPipelineScheduler::{enumerator}'
    for enumerator in ('Interwave', 'Intrawave')
}
|
||||
|
||||
# Epilogue key -> C++ snippet defining GemmEpilogue.
EPILOGUE_MAP = dict(
    default=DEFAULT_EPILOGUE,
    cshuffle=CSHUFFLE_EPILOGUE,
)

# Pipeline key -> C++ tail-number dispatch snippet that calls Run with
# bool_constant<true>.
HOT_LOOP_TRUE = dict(
    mem=RUN_MEM,
    compv3=RUN_COMPV3,
    compv4=RUN_COMPV4,
)
|
||||
|
||||
|
||||
def BOOL_MAP(b_) -> str:
    """Return the C++ literal 'true' when b_ is truthy, otherwise 'false'."""
    return 'true' if b_ else 'false'
|
||||
|
||||
@dataclass
class GemmConfig:
    """Configuration split into matrix-level and implementation-level parts.

    Entries keyed "datatype", "layout_a", "layout_b" or "layout_c" are stored
    in ``matrix_cfg``; every other entry is stored in ``impl_cfg``.
    """

    def __init__(self, config_data):
        matrix_keys = ("datatype", "layout_a", "layout_b", "layout_c")
        self.matrix_cfg : Dict[str, Any] = {}
        self.impl_cfg : Dict[str, Any] = {}
        for name, spec in config_data.items():
            bucket = self.matrix_cfg if name in matrix_keys else self.impl_cfg
            bucket[name] = spec

    @property
    def datatype(self) -> str:
        """First entry of the "datatype" value list."""
        return self.matrix_cfg["datatype"]["values"][0]

    @property
    def layouts(self) -> List[str]:
        """First entry of the A, B and C layout value lists, in that order."""
        return [self.matrix_cfg[f"layout_{operand}"]["values"][0] for operand in "abc"]
|
||||
|
||||
|
||||
class GemmCodeGenerator:
    """Emits the C++ headers for every GEMM kernel group described by a config.

    One header (``gemm_<group>.hpp``) is written per supported combination of
    pipeline / epilogue / scheduler / padding flags, together with
    ``gemm_common.hpp``, ``gemm_instances.hpp`` and ``gemm_dispatcher.hpp``.
    ``list_all`` only writes the list of those paths to
    ``gemm_instance_blobs.txt``; ``generate_all`` writes the headers themselves.
    """

    # Implementation parameters whose combinations each define one kernel group.
    _GROUP_PARAMS = ("pipeline", "epilogue", "scheduler", "kPadM", "kPadN", "kPadK")

    # (pipeline, epilogue, scheduler) triples that are never generated.
    _UNSUPPORTED_COMBINATIONS = (
        ("compv3", "cshuffle", "interwave"),
        ("compv3", "default", "interwave"),
        ("compv4", "cshuffle", "interwave"),
        ("compv4", "default", "interwave"),
    )

    # Order matches the template parameters of the generated GemmKernel struct.
    _TILE_PARAMS = ("tile_m", "tile_n", "tile_k",
                    "warp_m", "warp_n", "warp_k",
                    "warp_tile_m", "warp_tile_n", "warp_tile_k")

    def __init__(self, output_dir: str, config: "GemmConfig"):
        """Create the output directory if needed and validate ``config``.

        Raises:
            ValueError: if the configuration fails ``_validate_config``.
        """
        self.output_dir = Path(output_dir)
        # parents/exist_ok: nested build directories work and a concurrent
        # creation of the same directory is not an error.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.config = config
        self.all_kernels = []
        self.unique_configs = []
        # Validate configurations
        self._validate_config()

    @staticmethod
    def _ordered_product(value_lists):
        """Cartesian product of ``value_lists`` with duplicates removed.

        First-seen order is kept so that generated files are identical from run
        to run (iterating a ``set`` of string tuples is not).
        """
        return list(dict.fromkeys(itertools.product(*value_lists)))

    @staticmethod
    def _is_valid_tile(tile) -> bool:
        """Return True when the block tile is a whole number of warp tiles in M and N.

        ``tile`` is ordered as ``_TILE_PARAMS``. The conditions are
        ``tile_m % (warp_m * warp_tile_m) == 0`` and
        ``tile_n % (warp_n * warp_tile_n) == 0``; non-positive divisors are
        treated as invalid instead of raising.
        """
        tile_m, tile_n, _, warp_m, warp_n, _, warp_tile_m, warp_tile_n, _ = tile
        span_m = warp_m * warp_tile_m
        span_n = warp_n * warp_tile_n
        if span_m <= 0 or span_n <= 0:
            return False
        return tile_m % span_m == 0 and tile_n % span_n == 0

    @staticmethod
    def _group_name(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK) -> str:
        """Name shared by a group's header file, C++ namespace and dispatcher key."""
        return (f"{pipeline}_{epilogue}_{scheduler}_pad_"
                f"{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}")

    def _validate_config(self):
        """Validate matrix and implementation configurations.

        Raises:
            ValueError: if a matrix parameter is missing or does not have exactly
                one value, or if a required implementation parameter is missing
                or has no values.
        """
        # Matrix config validation
        for param in ["datatype", "layout_a", "layout_b", "layout_c"]:
            values = self.config.matrix_cfg.get(param, {}).get("values") or []
            if len(values) != 1:
                raise ValueError(f"Matrix config {param} must have exactly one value")

        # Implementation traits validation
        required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k",
                           "warp_tile_m", "warp_tile_n", "warp_tile_k", "pipeline",
                           "epilogue", "scheduler", "kPadM", "kPadN", "kPadK"]
        for param in required_params:
            if not self.config.impl_cfg.get(param, {}).get("values"):
                raise ValueError(f"Missing implementation parameter: {param}")

    def list_all(self):
        """Write the paths of all files ``generate_all`` would emit.

        The list goes to ``gemm_instance_blobs.txt`` in the output directory, one
        path per line.
        """
        w_p = self.output_dir
        list_p = w_p / 'gemm_instance_blobs.txt'
        self._list_config_groups()
        with list_p.open('w') as list_f:
            list_f.write(str(w_p / ("gemm_common.hpp")) + "\n")
            list_f.write(str(w_p / ("gemm_instances.hpp")) + "\n")
            list_f.write(str(w_p / ("gemm_dispatcher.hpp")) + "\n")
            for group in self.all_kernels:
                list_f.write(str(w_p / ("gemm_" + group + ".hpp")) + "\n")

    def _list_config_groups(self):
        """Rebuild ``all_kernels`` / ``unique_configs`` from the configuration.

        Both lists are reset first, so calling this more than once does not
        accumulate duplicate groups.
        """
        self.all_kernels = []
        self.unique_configs = []

        value_lists = [self.config.impl_cfg[p]["values"] for p in self._GROUP_PARAMS]
        for combo in self._ordered_product(value_lists):
            pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = combo
            # To remove some unsupported combinations
            if (pipeline, epilogue, scheduler) in self._UNSUPPORTED_COMBINATIONS:
                continue
            self.all_kernels.append(
                self._group_name(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK))
            self.unique_configs.append(dict(zip(self._GROUP_PARAMS, combo)))

    def generate_all(self):
        """Write the common header, every group header and the dispatcher."""
        self._generate_common_header()
        self._generate_config_groups()
        self._generate_dispatcher()

    def _generate_common_header(self):
        """Generate common header with datatypes and layout"""
        ctype = self.config.datatype
        atype = self.config.datatype
        btype = self.config.datatype
        if self.config.datatype in ['fp8', 'bf8']:
            # fp8/bf8 inputs produce an fp16 C matrix.
            ctype = 'fp16'
        elif self.config.datatype in ['int4']:
            # int4 is used for B only; A and C are fp16.
            atype = 'fp16'
            ctype = 'fp16'

        content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include "ck_tile/core.hpp"

// Data types
using ADataType = {DATA_TYPE_MAP[atype]};
using BDataType = {DATA_TYPE_MAP[btype]};
using AccDataType = float;
using CDataType = {DATA_TYPE_MAP[ctype]};

// Layout configurations
using ALayout = {LAYOUT_MAP[self.config.layouts[0]]};
using BLayout = {LAYOUT_MAP[self.config.layouts[1]]};
using CLayout = {LAYOUT_MAP[self.config.layouts[2]]};
"""

        (self.output_dir / "gemm_common.hpp").write_text(content)

    def _generate_config_groups(self):
        """Generate implementation configuration groups"""
        if not self.unique_configs:  # Check if the list is empty
            self._list_config_groups()
        for config in self.unique_configs:
            self._generate_config_group(**config)
        self.generate_common_instances_header()

    def _generate_config_group(self, pipeline: str, epilogue: str, scheduler: str,
                               kPadM: bool, kPadN: bool, kPadK: bool):
        """Generate a configuration group with all tile/warp combinations"""
        group_name = self._group_name(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK)
        filename = f"gemm_{group_name}.hpp"

        content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gemm_common.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/host.hpp"

namespace {group_name} {{
"""
        # Add template struct with configuration
        content += self._generate_kernel_struct(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK)

        content += f"\n}} // namespace {group_name}\n"
        (self.output_dir / filename).write_text(content)

    def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str,
                                kPadM: bool, kPadN: bool, kPadK: bool) -> str:
        """Generate kernel struct template

        Returns the C++ text of a ``GemmKernel`` struct template whose pipeline,
        epilogue, scheduler and padding flags are fixed by the arguments, while
        tile and warp sizes remain C++ template parameters.
        """
        return f"""
template <int TileM, int TileN, int TileK,
          int WarpM, int WarpN, int WarpK,
          int WarpTileM, int WarpTileN, int WarpTileK,
          bool structured_sparsity>
struct GemmKernel {{
    static constexpr bool kPadM = {BOOL_MAP(kPadM)};
    static constexpr bool kPadN = {BOOL_MAP(kPadN)};
    static constexpr bool kPadK = {BOOL_MAP(kPadK)};

    static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{
        static constexpr bool permuteA = false;
        static constexpr bool permuteB = false;
        static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
        static constexpr bool TransposeC = false;

        static constexpr int kBlockPerCu                         = 1;
        static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
        static constexpr ck_tile::index_t TileParitionerM01      = 4;

        using GemmShape =
            ck_tile::TileGemmShape<ck_tile::sequence<TileM, TileN, TileK>,
                                   ck_tile::sequence<WarpM, WarpN, WarpK>,
                                   ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>,
                                   permuteA,
                                   permuteB>;


        using TilePartitioner =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                       TileParitionerGroupNum,
                                                       TileParitionerM01>;

        using Traits =
            ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;

        using GemmUniversalTraits =
            ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
                                             ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;

        using GemmPipelineProblem =
            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

        using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}<GemmPipelineProblem>;

        const ck_tile::index_t k_grain     = args.k_batch * TileK;
        const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * TileK;
        const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        float ave_time{{0}};

        const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
            constexpr auto tail_number_v  = tail_number_.value;
            constexpr auto scheduler      = {SCHEDULER_MAP[scheduler]};

            using UniversalGemmProblem =
                ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                      BDataType,
                                                      AccDataType,
                                                      GemmShape,
                                                      GemmUniversalTraits,
                                                      scheduler,
                                                      has_hot_loop_v,
                                                      tail_number_v>;

            using GemmPipeline = {PIPELINE_MAP[pipeline][1]}<UniversalGemmProblem>;
            {EPILOGUE_MAP[epilogue]}
            using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKernelArgs(args);

            const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
            constexpr dim3 blocks = Kernel::BlockSize();

            if(!Kernel::IsSupportedArgument(kargs))
            {{
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
            }}

            if(s.log_level_ > 0)
            {{
                std::cout << "Launching kernel with args:"
                          << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
                          << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
                          << std::endl;
            }}

            ave_time = ck_tile::launch_kernel(s,
                                              ck_tile::make_kernel<blocks.x, kBlockPerCu>(
                                                  Kernel{{}}, grids, blocks, 0, kargs));
            return ave_time;

        }};

        if(has_hot_loop) {{
            {HOT_LOOP_TRUE[pipeline]}
        }} else {{
            {HOT_LOOP_FALSE}
        }}

        return ave_time;
    }}
    static std::string get_name() {{
        return std::string("GemmKernel<Bllktile: ") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + ", " +
                "WaveMap: " + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + ", " +
                "WarpTile: " + std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + ", " +
                "PadidngM: " + "{kPadM}" + ", " +
                "PaddingN: " + "{kPadN}" + ", " +
                "PaddingK: " + "{kPadK}" + ", " +
                "Pipeline: " + "{pipeline}" + ", " +
                "Epilogue: " + "{epilogue}" + ", " +
                "Scheduler: " + "{scheduler}";
    }}
}};
"""

    def generate_common_instances_header(self):
        """Generate common instances header"""
        content = """// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
"""
        for group in self.all_kernels:
            content += f"#include \"gemm_{group}.hpp\"\n"
        (self.output_dir / "gemm_instances.hpp").write_text(content)

    def _generate_dispatcher(self):
        """Generate dispatch mechanism

        Writes ``gemm_dispatcher.hpp``. For every kernel group it registers one
        entry in the C++ kernel map; each entry runs every valid tile/warp
        combination for that group.
        """
        if not self.all_kernels:
            # Allow the dispatcher to be generated on its own.
            self._list_config_groups()

        content = """// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gemm_common.hpp"
#include "gemm_instances.hpp"
#include "gemm_host_api.hpp"
#include <unordered_map>
#include <functional>
#include <vector>

struct GemmDispatcher {
    static auto& get_kernel_map() {
        // Use a static local variable
        static std::unordered_map<std::string,
            std::function<void(ck_tile::DeviceMem& c_m_n_dev_buf,
                               ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                               ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                               int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
        return kernel_map;
    }

    static void init(bool structured_sparsity) {
        auto& kernel_map = get_kernel_map();
        if(!kernel_map.empty()) return;
        \n"""
        # Add tile/warp instantiations.
        # Only keep combinations where the block tile is a whole multiple of
        # warp count * warp tile in M and N:
        # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
        # The filter does not depend on the group, so it is computed once.
        tile_value_lists = [self.config.impl_cfg[p]["values"] for p in self._TILE_PARAMS]
        valid_tiles = [tile for tile in self._ordered_product(tile_value_lists)
                       if self._is_valid_tile(tile)]

        call_args = "c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream"

        for group in self.all_kernels:
            content += f"""        kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
                                           ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                                           ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                                           int verify, ck_tile::GemmHostArgs& args,
                                           const ck_tile::stream_config& stream) {{
"""
            for tile in valid_tiles:
                tile_args = ", ".join(str(value) for value in tile)
                content += f"""
            if(structured_sparsity) {{
                run_kernel<{group}::GemmKernel<{tile_args}, {1}>>({call_args});
            }} else {{
                run_kernel<{group}::GemmKernel<{tile_args}, {0}>>({call_args});
            }}"""
            content += f"""
        }};\n"""

        content += """    }

    template <typename Kernel>
    static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
                           ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                           ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                           int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
    {
        float avg_time = Kernel::launch(args, stream);
        std::string description = Kernel::get_name();
        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

        std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
        std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N;
        float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
        float gb_per_sec = num_byte / 1.E6 / avg_time;

        std::cout << "Performance for " << description << " : " << avg_time << " ms, "
                  << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;

        if(verify)
            compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();
    }

    static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
                         ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                         ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                         int verify, bool structured_sparsity, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
                         const ck_tile::stream_config& stream) {
        init(structured_sparsity);
        const std::string key = assemble_key(trait);
        auto& kernel_map = get_kernel_map();
        if(auto it = kernel_map.find(key); it != kernel_map.end()) {
            return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream);
        }
        throw std::runtime_error("No suitable kernel found: " + key);
    }

private:
    static std::string assemble_key(const KernelTraits &trait) {
        return std::string(trait.pipeline) + "_" +
               trait.epilogue + "_" +
               trait.scheduler + "_" +
               "pad_" +
               (trait.kPadM ? "true" : "false") + "_" +
               (trait.kPadN ? "true" : "false") + "_" +
               (trait.kPadK ? "true" : "false");
    }
};

"""
        (self.output_dir / "gemm_dispatcher.hpp").write_text(content)
|
||||
|
||||
|
||||
def do_list_blobs(args, gemm_config):
    """Write the list of files to be generated into ``args.working_path``."""
    GemmCodeGenerator(args.working_path, gemm_config).list_all()
|
||||
|
||||
def do_gen_blobs(args, gemm_config):
    """Generate every kernel header into ``args.working_path``."""
    GemmCodeGenerator(args.working_path, gemm_config).generate_all()
|
||||
|
||||
|
||||
|
||||
def main(args):
    """Load the JSON configuration named by ``args.json`` and run the selected mode.

    ``--list_blobs`` takes precedence over ``--gen_blobs``; when neither flag is
    set a notice is printed and the blobs are generated.
    """
    # Read json file
    with open(args.json, 'r') as json_file:
        config_data = json.load(json_file)

    gemm_config = GemmConfig(config_data)

    if args.list_blobs:
        do_list_blobs(args, gemm_config)
        return

    if not args.gen_blobs:
        # If neither was specified, default to gen_blobs
        print("No mode specified (use --list_blobs or --gen_blobs). Generating by default...")
    do_gen_blobs(args, gemm_config)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(prog="generate",
                                     description="gen API for CK gemm kernel")
    parser.add_argument("-w", "--working_path",
                        default="./",
                        required=False,
                        help="the path where all the blobs are going to be generated")
    parser.add_argument("-j", "--json",
                        required=True,
                        help="Path to the json which contains the kernel configurations")
    parser.add_argument("-l", "--list_blobs",
                        action='store_true',
                        help="List all kernel to file")
    parser.add_argument("-g", "--gen_blobs",
                        action='store_true',
                        help="Generate all kernels into different files")

    main(parser.parse_args())
|
||||
Reference in New Issue
Block a user