mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[New] Build up the feature of CK Tile GEMM CodeGen (#1994)
* New branch for codegen changes
* Fix verify function for int4
* pk_int4 codegen
* Update to review comments
* Remove codegen directory and rename filenames
* Remove extra files; clean up CMake file
* New branch for codegen changes
* Fix verify function for int4
* pk_int4 codegen
* Update to review comments
* Remove codegen directory and rename filenames
* Remove extra files; clean up CMake file
* code changes for single instance
* config file rename, added few more combinations in json file
* Fix cmake file
* Addressing review comments
* Reverting files changed by merge to develop
---------
Co-authored-by: ThomasNing <thomas.ning@amd.com>
[ROCm/composable_kernel commit: fed0709121]
This commit is contained in:
@@ -610,6 +610,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
|
||||
PACKAGE_NAME examples
|
||||
)
|
||||
add_subdirectory(example)
|
||||
add_subdirectory(tile_engine)
|
||||
if(BUILD_TESTING)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
5
tile_engine/CMakeLists.txt
Executable file
5
tile_engine/CMakeLists.txt
Executable file
@@ -0,0 +1,5 @@
|
||||
# Prepend tile_engine's own include/ directory to the include path of this
# directory, so it is inherited by the subdirectories added below.
include_directories(BEFORE "${CMAKE_CURRENT_LIST_DIR}/include")

# Operator-specific build files live under ops/.
add_subdirectory(ops)
|
||||
1
tile_engine/ops/CMakeLists.txt
Executable file
1
tile_engine/ops/CMakeLists.txt
Executable file
@@ -0,0 +1 @@
|
||||
# GEMM is the only operator directory added from here.
add_subdirectory(gemm)
|
||||
45
tile_engine/ops/gemm/CMakeLists.txt
Normal file
45
tile_engine/ops/gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
|
||||
# Generate the list of kernel files at configure time without emitting the
# kernels themselves; the generator is run with --list_blobs and the list is
# then read back from gemm_instance_blobs.txt in the binary directory.
execute_process(
  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
          --working_path ${CMAKE_CURRENT_BINARY_DIR}
          --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
          --list_blobs
  RESULT_VARIABLE ret
)

if(ret AND NOT ret EQUAL 0)
  message(FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
endif()

# The blob list is a configure-time product of the generator script and its
# JSON configuration. Register both as configure dependencies so that editing
# either one re-runs CMake instead of silently building from a stale list.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
  "${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py"
  "${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json"
)

# One generated file path per line.
file(STRINGS "${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt" GEMM_CODEGEN_BLOBS)
|
||||
|
||||
# Build-time rule that produces every file listed in GEMM_CODEGEN_BLOBS by
# running the generator with --gen_blobs.
#
# The rule depends on its real inputs (the generator script and the JSON
# configuration). Listing the outputs themselves under DEPENDS would make the
# rule depend on the files it is supposed to create.
add_custom_command(
  OUTPUT ${GEMM_CODEGEN_BLOBS}
  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
          --working_path ${CMAKE_CURRENT_BINARY_DIR}
          --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
          --gen_blobs
  DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
          ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
  VERBATIM
)
|
||||
|
||||
set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
message("adding example ${EXECUTABLE_GEMM_INSTANCE}")

# Not part of the default build; must be requested explicitly.
add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp)

# Include paths are attached to the target rather than to the directory.
# The binary directory (where the generated headers land) is searched before
# the source directory, matching the previous search order.
target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE
  ${CMAKE_CURRENT_BINARY_DIR}
  ${CMAKE_CURRENT_LIST_DIR}
)

# Adding the generated files as sources ties the target to the custom command
# that produces them.
target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS})

set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS
  -Wno-undefined-func-template
  -Wno-float-equal
  --offload-compress
)
target_compile_options(${EXECUTABLE_GEMM_INSTANCE}
  PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS})

# NOTE(review): this is a GLOBAL property, so it affects the whole build tree,
# not only this target -- confirm that is intended.
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
60
tile_engine/ops/gemm/configs/instance_combination.json
Normal file
60
tile_engine/ops/gemm/configs/instance_combination.json
Normal file
@@ -0,0 +1,60 @@
|
||||
{
|
||||
|
||||
"layout_a": {
|
||||
"values": ["r"]
|
||||
},
|
||||
"layout_b": {
|
||||
"values": ["c"]
|
||||
},
|
||||
"layout_c": {
|
||||
"values": ["r"]
|
||||
},
|
||||
"datatype": {
|
||||
"values": ["fp16"]
|
||||
},
|
||||
"tile_m": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [64]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [2]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [2]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [1]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [32]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [32]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [16]
|
||||
},
|
||||
"kPadM": {
|
||||
"values": [false]
|
||||
},
|
||||
"kPadN": {
|
||||
"values": [false]
|
||||
},
|
||||
"kPadK": {
|
||||
"values": [false]
|
||||
},
|
||||
"pipeline": {
|
||||
"values": ["compv3", "mem"]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": ["intrawave", "interwave"]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": ["default", "cshuffle"]
|
||||
}
|
||||
}
|
||||
169
tile_engine/ops/gemm/gemm_host_api.cpp
Normal file
169
tile_engine/ops/gemm/gemm_host_api.cpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
#include "gemm_dispatcher.hpp"
|
||||
#include "gemm_host_api.hpp"
|
||||
|
||||
// Host-side entry point: forwards the requested kernel traits, the GEMM
// arguments and the stream configuration to the generated dispatcher and
// returns the dispatcher's result unchanged.
float gemm_kernel_launch(KernelTraits& trait,
                         ck_tile::GemmHostArgs& args,
                         const ck_tile::stream_config& s)
{
    const float dispatch_result = GemmDispatcher::dispatch(trait, args, s);
    return dispatch_result;
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
const ALayout a_layout = ALayout{};
|
||||
const BLayout b_layout = BLayout{};
|
||||
// const CLayout c_layout = CLayout{};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
int verify = arg_parser.get_int("v");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
// permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs gemm_args;
|
||||
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.k_batch = kbatch;
|
||||
gemm_args.M = M;
|
||||
gemm_args.N = N;
|
||||
gemm_args.K = K;
|
||||
gemm_args.stride_A = stride_A;
|
||||
gemm_args.stride_B = stride_B;
|
||||
gemm_args.stride_C = stride_C;
|
||||
|
||||
KernelTraits trait;
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.kPadM = arg_parser.get_bool("pad_m");
|
||||
trait.kPadN = arg_parser.get_bool("pad_n");
|
||||
trait.kPadK = arg_parser.get_bool("pad_k");
|
||||
|
||||
float ave_time = gemm_kernel_launch(
|
||||
trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
if(verify)
|
||||
{
|
||||
pass = gemm_verify<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
verify,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_dev_result,
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch);
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
return run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
287
tile_engine/ops/gemm/gemm_host_api.hpp
Normal file
287
tile_engine/ops/gemm/gemm_host_api.hpp
Normal file
@@ -0,0 +1,287 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
// Run-time description of the kernel variant to dispatch.
// The strings carry the values given for the "pipeline", "scheduler" and
// "epilogue" options; the flags carry the "pad_m"/"pad_n"/"pad_k" options.
struct KernelTraits
{
    std::string pipeline;
    std::string scheduler;
    std::string epilogue;
    // Default member initialisers keep a default-constructed KernelTraits in a
    // defined state; without them the three flags would be indeterminate.
    bool kPadM = false;
    bool kPadN = false;
    bool kPadK = false;
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("pipeline", "compv3", "compv3, compv4, mem")
|
||||
.insert("scheduler", "intrawave", "intrawave, interwave")
|
||||
.insert("epilogue", "cshuffle", "cshuffle, default")
|
||||
.insert("pad_m", "false", "true, false")
|
||||
.insert("pad_n", "false", "true, false")
|
||||
.insert("pad_k", "false", "true, false");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename Tensor>
|
||||
void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
{
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int8_t input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int8_t i4x2 = tensor(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int8_t hi = input[2];
|
||||
int8_t lo = input[0];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[6];
|
||||
int8_t lo = input[4];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[3];
|
||||
int8_t lo = input[1];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[7];
|
||||
int8_t lo = input[5];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// verification code
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool gemm_verify(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch)
|
||||
{
|
||||
bool pass = true;
|
||||
if(verify == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(verify == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(
|
||||
hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_A,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
a_m_k.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
|
||||
d_C,
|
||||
c_m_n_dev_result.get_element_space_size_in_bytes(),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_A));
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_gpu_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
596
tile_engine/ops/gemm/gemm_instance_builder.py
Executable file
596
tile_engine/ops/gemm/gemm_instance_builder.py
Executable file
@@ -0,0 +1,596 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import List, Optional, Dict, Any
|
||||
import functools
|
||||
import itertools
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Datatype token used in the JSON configuration -> C++ type name written into
# the generated sources.
DATA_TYPE_MAP = dict(
    fp32='float',
    fp16='ck_tile::half_t',
    bf16='ck_tile::bf16_t',
    int8='ck_tile::int8_t',
    fp8='ck_tile::fp8_t',
    bf8='ck_tile::bf8_t',
    int4='ck_tile::pk_int4_t',
)

# Layout token ('r' / 'c') -> C++ layout tag written into the generated sources.
LAYOUT_MAP = dict(
    r='ck_tile::tensor_layout::gemm::RowMajor',
    c='ck_tile::tensor_layout::gemm::ColumnMajor',
)
|
||||
|
||||
DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
"""
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
WarpM,
|
||||
WarpN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
"""
|
||||
# C++ snippet emitted for the "no hot loop" case: it dispatches on the run-time
# tail number to the matching compile-time Run(...) instantiation and throws
# for any other tail number.
# Each branch forwards the tail number it tested for; the Even branch
# dispatches TailNumber::Even (it previously dispatched TailNumber::Odd).
HOT_LOOP_FALSE = """
            if(tail_num == ck_tile::TailNumber::Full)
            {
                Run(ck_tile::bool_constant<false>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }
            else if(tail_num == ck_tile::TailNumber::Odd)
            {
                Run(ck_tile::bool_constant<false>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
            }
            else if(tail_num == ck_tile::TailNumber::Even)
            {
                Run(ck_tile::bool_constant<false>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
            }
            else
            {
                throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
            }
"""
|
||||
# C++ snippet emitted for the 'mem' pipeline when a hot loop is present.
# Tail numbers One and Full are always dispatched; Two..Seven are dispatched
# only when the pipeline has more than two prefetch stages.
# Inside that block the checks form a single else-if chain, and the error is
# thrown only for a tail number that no branch handled. Previously the throw
# was unconditional, so it was reached even after a successful dispatch.
RUN_MEM = """
            if(tail_num == ck_tile::TailNumber::One)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
            }
            else if(tail_num == ck_tile::TailNumber::Full)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }

            if constexpr(BaseGemmPipeline::PrefetchStages > 2)
            {
                if(tail_num == ck_tile::TailNumber::Two)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
                }
                else if(tail_num == ck_tile::TailNumber::Three)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
                }
                else if(tail_num == ck_tile::TailNumber::Four)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
                }
                else if(tail_num == ck_tile::TailNumber::Five)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
                }
                else if(tail_num == ck_tile::TailNumber::Six)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
                }
                else if(tail_num == ck_tile::TailNumber::Seven)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
                }
                else if(tail_num != ck_tile::TailNumber::One &&
                        tail_num != ck_tile::TailNumber::Full)
                {
                    throw std::runtime_error("The tile number is wrong! It should not exceed the prefetch stage numbers");
                }
            }
"""
|
||||
|
||||
RUN_COMPV3 = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
|
||||
}
|
||||
"""
|
||||
|
||||
RUN_COMPV4 = """
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
# Pipeline token -> [base pipeline class, pipeline class] C++ names.
PIPELINE_MAP = {
    'mem': ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
    'compv3': ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
    'compv4': ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4'],
}

# Scheduler token -> C++ scheduler enumerator.
SCHEDULER_MAP = {
    'interwave': 'ck_tile::GemmPipelineScheduler::Interwave',
    'intrawave': 'ck_tile::GemmPipelineScheduler::Intrawave',
}

# Epilogue token -> C++ snippet that defines GemmEpilogue.
EPILOGUE_MAP = {
    'default': DEFAULT_EPILOGUE,
    'cshuffle': CSHUFFLE_EPILOGUE,
}

# Pipeline token -> C++ tail-number dispatch snippet for the hot-loop case.
HOT_LOOP_TRUE = {
    'mem': RUN_MEM,
    'compv3': RUN_COMPV3,
    'compv4': RUN_COMPV4,
}
|
||||
|
||||
|
||||
def BOOL_MAP(b_) -> str:
    """Return the C++ boolean literal ('true' or 'false') for the truthiness of ``b_``."""
    return 'true' if b_ else 'false'
|
||||
|
||||
@dataclass
class GemmConfig:
    """Splits a loaded configuration mapping into matrix and implementation parts.

    Entries keyed "datatype", "layout_a", "layout_b" or "layout_c" are stored in
    ``matrix_cfg``; every other entry is stored in ``impl_cfg``.
    """
    # NOTE(review): no class-level fields are declared, so @dataclass does not
    # generate any fields here; the hand-written __init__ below is the one used.

    def __init__(self, config_data):
        matrix_keys = ("datatype", "layout_a", "layout_b", "layout_c")
        self.matrix_cfg: Dict[str, Any] = {
            key: value for key, value in config_data.items() if key in matrix_keys
        }
        self.impl_cfg: Dict[str, Any] = {
            key: value for key, value in config_data.items() if key not in matrix_keys
        }

    @property
    def datatype(self) -> str:
        """First (and, after validation, only) configured datatype token."""
        return self.matrix_cfg["datatype"]["values"][0]

    @property
    def layouts(self) -> List[str]:
        """First configured layout token of A, B and C, in that order."""
        return [self.matrix_cfg[f"layout_{operand}"]["values"][0] for operand in "abc"]
|
||||
|
||||
|
||||
class GemmCodeGenerator:
|
||||
def __init__(self, output_dir: str, config: GemmConfig):
|
||||
self.output_dir = Path(output_dir)
|
||||
if not self.output_dir.exists():
|
||||
self.output_dir.mkdir()
|
||||
|
||||
self.config = config
|
||||
self.all_kernels = []
|
||||
self.unique_configs = []
|
||||
# Validate configurations
|
||||
self._validate_config()
|
||||
|
||||
def _validate_config(self):
|
||||
"""Validate matrix and implementation configurations"""
|
||||
# Matrix config validation
|
||||
for param in ["datatype", "layout_a", "layout_b", "layout_c"]:
|
||||
if len(self.config.matrix_cfg[param]["values"]) != 1:
|
||||
raise ValueError(f"Matrix config {param} must have exactly one value")
|
||||
|
||||
# Implementation traits validation
|
||||
required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k",
|
||||
"warp_tile_m", "warp_tile_n", "warp_tile_k", "pipeline",
|
||||
"epilogue", "scheduler", "kPadM", "kPadN", "kPadK"]
|
||||
for param in required_params:
|
||||
if not self.config.impl_cfg.get(param, {}).get("values"):
|
||||
raise ValueError(f"Missing implementation parameter: {param}")
|
||||
|
||||
def list_all(self):
|
||||
"""List all possible kernel configurations"""
|
||||
w_p = Path(self.output_dir)
|
||||
list_p = w_p / 'gemm_instance_blobs.txt'
|
||||
self._list_config_groups()
|
||||
with list_p.open('w') as list_f:
|
||||
list_f.write(str(w_p / ("gemm_common.hpp")) + "\n")
|
||||
list_f.write(str(w_p / ("gemm_instances.hpp")) + "\n")
|
||||
list_f.write(str(w_p / ("gemm_dispatcher.hpp")) + "\n")
|
||||
for group in self.all_kernels:
|
||||
list_f.write(str(w_p / ("gemm_" + group + ".hpp")) + "\n")
|
||||
|
||||
|
||||
|
||||
def _list_config_groups(self):
|
||||
params = [
|
||||
("pipeline", "pipeline"),
|
||||
("epilogue", "epilogue"),
|
||||
("scheduler", "scheduler"),
|
||||
("kPadM", "kPadM"),
|
||||
("kPadN", "kPadN"),
|
||||
("kPadK", "kPadK")
|
||||
]
|
||||
|
||||
# Generate all unique_combinations
|
||||
_unique = set(itertools.product(*[self.config.impl_cfg[p]["values"] for (p, _) in params]))
|
||||
for combo in _unique:
|
||||
config = {name: value for (_, name), value in zip(params, combo)}
|
||||
pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = config.values()
|
||||
# To remove some unsupported combinations
|
||||
unsupported_combination = [("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave")]
|
||||
if (pipeline, epilogue, scheduler) not in unsupported_combination:
|
||||
group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}"
|
||||
self.all_kernels.append(group_name)
|
||||
self.unique_configs.append(config)
|
||||
|
||||
def generate_all(self):
|
||||
self._generate_common_header()
|
||||
self._generate_config_groups()
|
||||
self._generate_dispatcher()
|
||||
|
||||
|
||||
def _generate_common_header(self):
    """Write ``gemm_common.hpp`` with the data type and layout aliases.

    A, B and C all start from ``self.config.datatype``. For ``fp8``/``bf8``
    the C type is switched to ``fp16``; for ``int4`` both the A and C types
    are switched to ``fp16`` while B keeps the configured type. The
    accumulator type is always ``float``.
    """
    dtype = self.config.datatype
    a_key = b_key = c_key = dtype
    if dtype in ('fp8', 'bf8'):
        c_key = 'fp16'
    elif dtype in ('int4',):
        a_key = c_key = 'fp16'

    # Resolve the C++ spellings up front (A, B, C order; then the layouts).
    a_cpp = DATA_TYPE_MAP[a_key]
    b_cpp = DATA_TYPE_MAP[b_key]
    c_cpp = DATA_TYPE_MAP[c_key]
    a_layout = LAYOUT_MAP[self.config.layouts[0]]
    b_layout = LAYOUT_MAP[self.config.layouts[1]]
    c_layout = LAYOUT_MAP[self.config.layouts[2]]

    header_text = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include "ck_tile/core.hpp"

// Data types
using ADataType = {a_cpp};
using BDataType = {b_cpp};
using AccDataType = float;
using CDataType = {c_cpp};

// Layout configurations
using ALayout = {a_layout};
using BLayout = {b_layout};
using CLayout = {c_layout};
"""

    target = self.output_dir / "gemm_common.hpp"
    target.write_text(header_text)
|
||||
|
||||
def _generate_config_groups(self):
|
||||
"""Generate implementation configuration groups"""
|
||||
if not self.unique_configs: # Check if the list is empty
|
||||
self._list_config_groups()
|
||||
for config in self.unique_configs:
|
||||
self._generate_config_group(**config)
|
||||
self.generate_common_instances_header()
|
||||
|
||||
|
||||
def _generate_config_group(self, pipeline: str, epilogue: str, scheduler: str,
                           kPadM: bool, kPadN: bool, kPadK: bool):
    """Write the header ``gemm_<group>.hpp`` for one trait combination.

    The file wraps the kernel struct returned by ``_generate_kernel_struct``
    in a namespace named after the group, which is derived from the
    pipeline, epilogue, scheduler and the three padding flags.
    """
    pad_m, pad_n, pad_k = BOOL_MAP(kPadM), BOOL_MAP(kPadN), BOOL_MAP(kPadK)
    group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{pad_m}_{pad_n}_{pad_k}"
    filename = f"gemm_{group_name}.hpp"

    prologue = (
        "// SPDX-License-Identifier: MIT\n"
        "// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.\n"
        "\n"
        "#include \"gemm_common.hpp\"\n"
        "#include \"ck_tile/ops/gemm.hpp\"\n"
        "#include \"ck_tile/ops/epilogue.hpp\"\n"
        "#include \"ck_tile/host.hpp\"\n"
        "\n"
        f"namespace {group_name} {{\n"
    )

    # Template struct carrying this group's configuration.
    kernel_struct = self._generate_kernel_struct(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK)

    epilogue_text = f"\n}} // namespace {group_name}\n"

    (self.output_dir / filename).write_text("".join((prologue, kernel_struct, epilogue_text)))
|
||||
|
||||
def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str,
                            kPadM: bool, kPadN: bool, kPadK: bool) -> str:
    """Return the C++ source of the ``GemmKernel`` struct template.

    The struct is templated on the block tile, warp and warp-tile sizes and
    exposes a static ``launch(args, s)`` function that returns the value
    produced by ``ck_tile::launch_kernel``.

    Args:
        pipeline: Key into ``PIPELINE_MAP`` (element 0 is used as the base
            pipeline, element 1 as the pipeline that is run) and into
            ``HOT_LOOP_TRUE``.
        epilogue: Key into ``EPILOGUE_MAP``; the mapped text is pasted into
            the body of the ``Run`` lambda.
        scheduler: Key into ``SCHEDULER_MAP``.
        kPadM: Padding flag rendered through ``BOOL_MAP`` into ``kPadM``.
        kPadN: Padding flag rendered through ``BOOL_MAP`` into ``kPadN``.
        kPadK: Padding flag rendered through ``BOOL_MAP`` into ``kPadK``.

    Returns:
        The generated C++ text. It opens and closes with a newline and does
        not include the surrounding namespace.
    """
    # Doubled braces below are f-string escapes for literal C++ braces; single
    # braces are Python interpolations.
    return f"""
template <int TileM, int TileN, int TileK,
int WarpM, int WarpN, int WarpK,
int WarpTileM, int WarpTileN, int WarpTileK>
struct GemmKernel {{
static constexpr bool kPadM = {BOOL_MAP(kPadM)};
static constexpr bool kPadN = {BOOL_MAP(kPadN)};
static constexpr bool kPadK = {BOOL_MAP(kPadK)};

static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TransposeC = false;

static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;

using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<TileM, TileN, TileK>,
ck_tile::sequence<WarpM, WarpN, WarpK>,
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>,
permuteA,
permuteB>;


using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>;

using Traits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;

using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC>;

using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}<GemmPipelineProblem>;

const ck_tile::index_t k_grain = args.k_batch * TileK;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

float ave_time{{0}};

const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {SCHEDULER_MAP[scheduler]};

using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;

using GemmPipeline = {PIPELINE_MAP[pipeline][1]}<UniversalGemmProblem>;
{EPILOGUE_MAP[epilogue]}
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);

const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();

if(!Kernel::IsSupportedArgument(kargs))
{{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}

if(s.log_level_ > 0)
{{
std::cout << "Launching kernel with args:"
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}

ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
Kernel{{}}, grids, blocks, 0, kargs));
return ave_time;

}};

if(has_hot_loop) {{
{HOT_LOOP_TRUE[pipeline]}
}} else {{
{HOT_LOOP_FALSE}
}}

return ave_time;
}}
}};
"""
|
||||
|
||||
def generate_common_instances_header(self):
    """Write ``gemm_instances.hpp``, which includes every group header.

    One ``#include "gemm_<group>.hpp"`` line is emitted per entry of
    ``self.all_kernels``, in list order, below the license banner.
    """
    banner = (
        "// SPDX-License-Identifier: MIT\n"
        "// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.\n"
        "#pragma once\n"
    )
    include_lines = "".join(
        f"#include \"gemm_{kernel_group}.hpp\"\n" for kernel_group in self.all_kernels
    )
    destination = self.output_dir / "gemm_instances.hpp"
    destination.write_text(banner + include_lines)
|
||||
|
||||
def _generate_dispatcher(self):
|
||||
"""Generate dispatch mechanism"""
|
||||
content = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_common.hpp"
|
||||
#include "gemm_instances.hpp"
|
||||
#include "gemm_host_api.hpp"
|
||||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
struct GemmDispatcher {
|
||||
static auto& get_kernel_map() {
|
||||
// Use a static local variable
|
||||
static std::unordered_map<std::string,
|
||||
std::function<float(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
static void init() {
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
# Add tile/warp instantiations
|
||||
tile_params = set(itertools.product(
|
||||
self.config.impl_cfg["tile_m"]["values"],
|
||||
self.config.impl_cfg["tile_n"]["values"],
|
||||
self.config.impl_cfg["tile_k"]["values"],
|
||||
self.config.impl_cfg["warp_m"]["values"],
|
||||
self.config.impl_cfg["warp_n"]["values"],
|
||||
self.config.impl_cfg["warp_k"]["values"],
|
||||
self.config.impl_cfg["warp_tile_m"]["values"],
|
||||
self.config.impl_cfg["warp_tile_n"]["values"],
|
||||
self.config.impl_cfg["warp_tile_k"]["values"]
|
||||
))
|
||||
|
||||
|
||||
for group in self.all_kernels:
|
||||
content += f""" kernel_map["{group}"] = [](ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s) {{
|
||||
std::vector<float> results;"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
|
||||
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
content += f"""
|
||||
//we can have multiple tiles config for the one kernel_trait
|
||||
return {group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>::launch(args, s);"""
|
||||
content += """
|
||||
};\n"""
|
||||
|
||||
content += """ }
|
||||
|
||||
|
||||
static float dispatch(const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
const ck_tile::stream_config& s) {
|
||||
init();
|
||||
const std::string key = assemble_key(trait);
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
|
||||
return it->second(gemm_args, s); //Running single instance
|
||||
}
|
||||
throw std::runtime_error("No suitable kernel found: " + key);
|
||||
}
|
||||
|
||||
private:
|
||||
static std::string assemble_key(const KernelTraits &trait) {
|
||||
return std::string(trait.pipeline) + "_" +
|
||||
trait.epilogue + "_" +
|
||||
trait.scheduler + "_" +
|
||||
"pad_" +
|
||||
(trait.kPadM ? "true" : "false") + "_" +
|
||||
(trait.kPadN ? "true" : "false") + "_" +
|
||||
(trait.kPadK ? "true" : "false");
|
||||
}
|
||||
};
|
||||
|
||||
"""
|
||||
(self.output_dir / "gemm_dispatcher.hpp").write_text(content)
|
||||
|
||||
|
||||
def do_list_blobs(args, gemm_config):
    """Run the generator's ``list_all`` step.

    The generator is created from ``args.working_path`` and ``gemm_config``.
    """
    GemmCodeGenerator(args.working_path, gemm_config).list_all()
|
||||
|
||||
def do_gen_blobs(args, gemm_config):
    """Run the generator's ``generate_all`` step.

    The generator is created from ``args.working_path`` and ``gemm_config``.
    """
    GemmCodeGenerator(args.working_path, gemm_config).generate_all()
|
||||
|
||||
|
||||
|
||||
def main(args):
    """Load the JSON configuration and run the requested mode.

    ``--list_blobs`` takes precedence over ``--gen_blobs``. When neither
    flag is given a notice is printed and generation runs by default.
    """
    # Read the json file
    with open(args.json, 'r') as config_stream:
        raw_config = json.load(config_stream)

    # Validate and parse configuration
    gemm_config = GemmConfig(raw_config)

    if args.list_blobs:
        do_list_blobs(args, gemm_config)
        return

    if not args.gen_blobs:
        # Neither mode was requested: fall back to generating.
        print("No mode specified (use --list_blobs or --gen_blobs). Generating by default...")
    do_gen_blobs(args, gemm_config)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(prog="generate", description="gen API for CK gemm kernel")

    # (flags, keyword arguments) for each option, in registration order.
    cli_options = (
        (("-w", "--working_path"),
         dict(default="./", required=False, help="the path where all the blobs are going to be generated")),
        (("-j", "--json"),
         dict(required=True, help="Path to the json which contains the kernel configurations")),
        (("-l", "--list_blobs"),
         dict(action='store_true', help="List all kernel to file")),
        (("-g", "--gen_blobs"),
         dict(action='store_true', help="Generate all kernels into different files")),
    )
    for option_flags, option_kwargs in cli_options:
        parser.add_argument(*option_flags, **option_kwargs)

    args = parser.parse_args()

    main(args)
|
||||
Reference in New Issue
Block a user