Reorganize project folders (#6)

This commit is contained in:
Joseph Macaranas
2025-04-30 13:46:39 -04:00
committed by GitHub
commit 1eb2e57380
3952 changed files with 654944 additions and 0 deletions

5
tile_engine/CMakeLists.txt Executable file
View File

@@ -0,0 +1,5 @@
# Prepend the tile_engine include/ directory to this directory's include
# search path, then descend into the operator subtree.
include_directories(BEFORE "${CMAKE_CURRENT_LIST_DIR}/include")
add_subdirectory(ops)

View File

@@ -0,0 +1 @@
# Configure-time trace only: prints a fixed string and changes no build state.
message("Add include directory")

1
tile_engine/ops/CMakeLists.txt Executable file
View File

@@ -0,0 +1 @@
# Process the build rules found in the gemm/ subdirectory.
add_subdirectory(gemm)

View File

@@ -0,0 +1,51 @@
# Build rules for the tile_engine GEMM executable.
#
# gemm_instance_builder.py is invoked twice:
#   1. at configure time with --list_blobs; afterwards the list of files to be
#      generated is read back from gemm_instance_blobs.txt in the build tree;
#   2. at build time with --gen_blobs, through a custom command whose OUTPUT is
#      exactly that list.

# generate a list of kernels, but not actually emit files at config stage
execute_process(
  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
          --working_path ${CMAKE_CURRENT_BINARY_DIR}
          --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
          --list_blobs
  RESULT_VARIABLE ret
)

# Re-run the configure step whenever the generator script or its JSON input
# changes, because either one can change the list of generated files.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
  ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
  ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
)

# `ret` is either the script's exit code or an error string from execute_process.
if(ret AND NOT ret EQUAL 0)
  message(FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
endif()

# One generated file path per line.
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)

# Build-time generation of the files listed above.
# VERBATIM makes CMake escape the command arguments itself, so the command line
# does not depend on the quoting rules of the platform's native shell.
add_custom_command(
  OUTPUT ${GEMM_CODEGEN_BLOBS}
  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
          --working_path ${CMAKE_CURRENT_BINARY_DIR}
          --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
          --gen_blobs
  DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
          ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt
          ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
  VERBATIM
)

set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
message("adding example ${EXECUTABLE_GEMM_INSTANCE}")

# use build as include directory
include_directories(${CMAKE_CURRENT_BINARY_DIR})

# EXCLUDE_FROM_ALL: the executable is only built when requested explicitly.
add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp)
target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Attaching the generated files as sources ties the custom command above to this target.
target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS})

set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS
  -Wno-undefined-func-template
  -Wno-float-equal
  --offload-compress)
target_compile_options(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS})

# NOTE(review): RULE_MESSAGES is a GLOBAL property, so this switches off
# per-rule progress messages for the whole build tree, not only this target.
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,92 @@
# GEMM Matrix Multiplication
CK Tile Engine GEMM is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues.
# Kernel Configurations
Kernel parameters are specified in the `instance_combination.json` file, including matrix layouts, data types, padding settings, pipelines, schedulers, epilogues, and numerical values for tile and warp sizes.
Given a valid set of values, tile_engine_gemm will automatically iterate over all possible combinations of BlockTile and WarpTile sizes, as well as the specified pipelines, schedulers, and epilogues from `./configs/instance_combination.json`, and build the corresponding kernels.
## Build Instructions
``` bash
# in the root of composable kernel create build directory
mkdir build && cd build
# build composable kernel
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with the appropriate architecture (example gfx942) or leave blank
# generate the executable
make tile_engine_gemm -j
```
`tile_engine_gemm` will be located in the `./bin/` directory.
_`tile_engine_gemm` must be rebuilt every time `instance_combination.json` is modified._
``` bash
rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild
```
## tile_engine_gemm inputs
```
-m m dimension (default:3840)
-n n dimension (default:4096)
-k k dimension (default:2048)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-split_k SplitK value (default:1)
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
-warmup Number of iterations before benchmark the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0)
-structured_sparsity Sparsity for tensor - 0:false, 1:true (default: 0)
-pipeline possible values are: compv3, compv4, mem (default:compv3)
-scheduler possible values are: intrawave, interwave (default:intrawave)
-epilogue possible values are: cshuffle, default (default:cshuffle)
-pad_m Pad in m direction - true/false (default:false)
-pad_n Pad in n direction - true/false (default:false)
-pad_k Pad in k direction - true/false (default:false)
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json
```
Note: In `./configs/instance_combination.json`, every value listed for pipeline, scheduler, epilogue, pad_m, pad_n, and pad_k must be one of the values specified above.
## Example
The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes.
```json
{
/// other parameters ///
"tile_m": {
"values": [256]
},
"tile_n": {
"values": [256]
},
"tile_k": {
"values": [64, 32]
},
/// other parameters ///
"pipeline": {
"values": ["compv3", "compv4", "mem"]
},
"scheduler": {
"values": ["intrawave", "interwave"]
},
"epilogue": {
"values": ["default", "cshuffle"]
}
}
```
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
``` bash
./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default
```
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.

View File

@@ -0,0 +1,60 @@
{
"layout_a": {
"values": ["r"]
},
"layout_b": {
"values": ["c"]
},
"layout_c": {
"values": ["r"]
},
"datatype": {
"values": ["fp16"]
},
"tile_m": {
"values": [256]
},
"tile_n": {
"values": [256]
},
"tile_k": {
"values": [64, 32]
},
"warp_m": {
"values": [2]
},
"warp_n": {
"values": [2]
},
"warp_k": {
"values": [1]
},
"warp_tile_m": {
"values": [32]
},
"warp_tile_n": {
"values": [32]
},
"warp_tile_k": {
"values": [16]
},
"kPadM": {
"values": [false]
},
"kPadN": {
"values": [false]
},
"kPadK": {
"values": [false]
},
"pipeline": {
"values": ["compv3", "compv4", "mem"]
},
"scheduler": {
"values": ["intrawave", "interwave"]
},
"epilogue": {
"values": ["default", "cshuffle"]
}
}

View File

@@ -0,0 +1,192 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "gemm_common.hpp"
#include "gemm_dispatcher.hpp"
#include "gemm_host_api.hpp"
/// @brief Thin forwarding layer between the host driver and GemmDispatcher.
///
/// Every argument is passed to GemmDispatcher::dispatch unchanged and in the same order;
/// this function adds no logic of its own.
void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
                        ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                        ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                        int verify,
                        bool structured_sparsity,
                        KernelTraits& trait,
                        ck_tile::GemmHostArgs& args,
                        const ck_tile::stream_config& stream)
{
    GemmDispatcher::dispatch(c_m_n_dev_buf,
                             c_m_n_host_result,
                             c_m_n_dev_result,
                             verify,
                             structured_sparsity,
                             trait,
                             args,
                             stream);
}
/// @brief Set up one GEMM problem from the command line and hand it to the kernel dispatcher.
///
/// Reads problem sizes, strides, split-K factor, initialisation mode and the kernel-selection
/// traits from @p arg_parser, fills host tensors A and B, uploads them to the device, optionally
/// computes a reference result through gemm_host_reference, and finally calls
/// gemm_kernel_launch.
///
/// @tparam ADataType   Element type of A.
/// @tparam BDataType   Element type of B.
/// @tparam AccDataType Accumulator type forwarded to the reference implementation.
/// @tparam CDataType   Element type of C.
/// @tparam ALayout     Layout tag type of A.
/// @tparam BLayout     Layout tag type of B.
/// @tparam CLayout     Layout tag type of C.
/// @param arg_parser   Parsed command line (see create_args for the registered options).
///
/// NOTE(review): the `timer` option registered by create_args is not read here.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
void run(const ck_tile::ArgParser& arg_parser)
{
    const ALayout a_layout = ALayout{};
    const BLayout b_layout = BLayout{};

    ck_tile::index_t kbatch      = arg_parser.get_int("split_k");
    ck_tile::index_t M           = arg_parser.get_int("m");
    ck_tile::index_t N           = arg_parser.get_int("n");
    ck_tile::index_t K           = arg_parser.get_int("k");
    ck_tile::index_t stride_A    = arg_parser.get_int("stride_a");
    ck_tile::index_t stride_B    = arg_parser.get_int("stride_b");
    ck_tile::index_t stride_C    = arg_parser.get_int("stride_c");
    int n_warmup                 = arg_parser.get_int("warmup");
    int n_repeat                 = arg_parser.get_int("repeat");
    int verify                   = arg_parser.get_int("v");
    ck_tile::index_t init_method = arg_parser.get_int("init");
    bool structured_sparsity     = arg_parser.get_bool("structured_sparsity");

    // The command-line default for each stride is 0; get_default_stride resolves the value
    // actually used from the matrix extents and layout.
    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));

    ck_tile::HostTensor<ADataType> a_m_k(
        ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
    ck_tile::HostTensor<BDataType> b_k_n(
        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

    // init: 0 = uniform random in [-1, 1], 1 = monotonic sequence, 2 = constant 1,
    // anything else = all zeros.
    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
        ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
    }
    else if(init_method == 1)
    {
        ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
        ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
    }
    else if(init_method == 2)
    {
        ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
        ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
    }
    else
    {
        a_m_k.SetZero();
        b_k_n.SetZero();
    }

    if(structured_sparsity)
    {
        ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
    }

    ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

    // Uploads B in the form the device kernel consumes. For pk_int4_t this is a permuted copy
    // of the host tensor; every other type is uploaded as-is. Kept as a lambda because the
    // upload has to be repeated after the GPU reference run (see below).
    const auto upload_b_for_kernel = [&]() {
        if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
        {
            // Permute vector pk_i4x4 data for device implementation
            ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
            permute_vectors_i4x4_b(b_k_n_dev);
            b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
        }
        else
        {
            b_k_n_dev_buf.ToDevice(b_k_n.data());
        }
    };
    upload_b_for_kernel();

    a_m_k_dev_buf.ToDevice(a_m_k.data());
    c_m_n_dev_buf.SetZero();
    c_m_n_dev_result.SetZero();

    ck_tile::GemmHostArgs gemm_args;
    gemm_args.a_ptr    = a_m_k_dev_buf.GetDeviceBuffer();
    gemm_args.b_ptr    = b_k_n_dev_buf.GetDeviceBuffer();
    gemm_args.c_ptr    = c_m_n_dev_buf.GetDeviceBuffer();
    gemm_args.k_batch  = kbatch;
    gemm_args.M        = M;
    gemm_args.N        = N;
    gemm_args.K        = K;
    gemm_args.stride_A = stride_A;
    gemm_args.stride_B = stride_B;
    gemm_args.stride_C = stride_C;

    // Run-time selection of which generated kernel group to execute.
    KernelTraits trait;
    trait.pipeline  = arg_parser.get_str("pipeline");
    trait.scheduler = arg_parser.get_str("scheduler");
    trait.epilogue  = arg_parser.get_str("epilogue");
    trait.kPadM     = arg_parser.get_bool("pad_m");
    trait.kPadN     = arg_parser.get_bool("pad_n");
    trait.kPadK     = arg_parser.get_bool("pad_k");

    std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
              << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
              << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
              << " B Type = " << DataTypeTraits<BDataType>::name
              << " C Type = " << DataTypeTraits<CDataType>::name << std::endl;

    ck_tile::HostTensor<CDataType> c_m_n_host_result(
        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

    if(verify)
    {
        gemm_host_reference<ADataType,
                            BDataType,
                            AccDataType,
                            CDataType,
                            ALayout,
                            BLayout,
                            CLayout>(verify,
                                     a_m_k,
                                     b_k_n,
                                     c_m_n_host_result,
                                     a_m_k_dev_buf,
                                     b_k_n_dev_buf,
                                     M,
                                     N,
                                     K,
                                     stride_A,
                                     stride_B,
                                     stride_C);

        if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
        {
            // For verify == 2, gemm_host_reference overwrites the device copy of B with the
            // un-permuted host data so the GPU reference can read it. The kernels launched
            // below use the same buffer, so put the permuted data back first.
            if(verify == 2)
            {
                upload_b_for_kernel();
            }
        }
    }

    // NOTE(review): the stream_config fields are presumably {stream, time_kernel, log_level,
    // warmup iterations, repeat count} -- confirm against ck_tile::stream_config.
    gemm_kernel_launch(c_m_n_dev_buf,
                       c_m_n_host_result,
                       c_m_n_dev_result,
                       verify,
                       structured_sparsity,
                       trait,
                       gemm_args,
                       ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
}
/// @brief Program entry point.
///
/// Parses the command line with create_args and, when parsing succeeds, runs the GEMM driver
/// for the data types and layouts declared by the ADataType/.../CLayout aliases.
///
/// @return 0 on success; EXIT_FAILURE when argument parsing fails or a std::exception escapes.
int main(int argc, char* argv[])
{
    try
    {
        auto [parsed_ok, arg_parser] = create_args(argc, argv);
        if(parsed_ok)
        {
            run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
                arg_parser);
            return 0;
        }
        return EXIT_FAILURE;
    }
    catch(const std::exception& error)
    {
        std::cerr << "Error: " << error.what() << "\n";
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,263 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include "ck_tile/ops/gemm.hpp"
#pragma once
/// @brief Maps an element type to the short name printed in the run banner.
///
/// Only the specializations below exist; using any other type is a compile error because the
/// primary template is declared but never defined.
template <typename T>
struct DataTypeTraits;

template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name = "fp32";
};

template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name = "fp64";
};

template <>
struct DataTypeTraits<ck_tile::half_t>
{
    static constexpr const char* name = "fp16";
};

template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
    static constexpr const char* name = "bf16";
};

template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
    static constexpr const char* name = "fp8";
};

template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
    static constexpr const char* name = "bf8";
};

// The instance generator accepts an "int8" datatype and maps it to ck_tile::int8_t, so that
// type needs a name here as well.
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
    static constexpr const char* name = "int8";
};

template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
    static constexpr const char* name = "pk_int4_t";
};
/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a
/// specific kernel instance based on the provided settings.
///
/// A default-constructed object has empty strings and all padding flags set to false.
struct KernelTraits
{
    /// @brief The name of the pipeline.
    std::string pipeline;
    /// @brief The name of the scheduler (e.g., "intrawave", "interwave").
    std::string scheduler;
    /// @brief The name of the epilogue (e.g., "cshuffle", "default").
    std::string epilogue;
    /// @brief Indicates whether padding is applied to the M dimension.
    bool kPadM = false;
    /// @brief Indicates whether padding is applied to the N dimension.
    bool kPadN = false;
    /// @brief Indicates whether padding is applied to the K dimension.
    bool kPadK = false;
};
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
/// @brief Derive the relative/absolute tolerances used when comparing device and host results.
///
/// Two pairs of thresholds are computed -- one for the accumulation inside a single K split and
/// one for summing the kbatch partial results -- and the larger value of each kind is returned.
///
/// @param K                     Full reduction length.
/// @param kbatch                Split-K factor.
/// @param max_accumulated_value Largest value found in the reference output.
/// @return ck_tile tuple (rtol, atol).
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // The operand type with the smaller size is used as the compute type.
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;

    // Number of K elements accumulated inside one split.
    const auto k_per_split = ck_tile::integer_divide_ceil(K, kbatch);

    // Thresholds for the in-split accumulation.
    const auto rtol_gemm =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_split);
    const auto atol_gemm = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_split);

    // Thresholds for adding up the kbatch partial results.
    const auto rtol_reduce =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_reduce = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);

    // Keep the looser bound of each kind.
    const auto rtol = std::max(rtol_gemm, rtol_reduce);
    const auto atol = std::max(atol_gemm, atol_reduce);
    return ck_tile::make_tuple(rtol, atol);
}
/// @brief Register every command-line option of tile_engine_gemm and parse argv.
///
/// @param argc Argument count as passed to main.
/// @param argv Argument vector as passed to main.
/// @return std::tuple of (parse succeeded, populated ck_tile::ArgParser).
inline auto create_args(int argc, char* argv[])
{
    // One row per option: name, default value, help text.
    struct OptionSpec
    {
        const char* name;
        const char* default_value;
        const char* help;
    };
    static constexpr OptionSpec option_table[] = {
        {"m", "3840", "m dimension"},
        {"n", "4096", "n dimension"},
        {"k", "2048", "k dimension"},
        {"stride_a", "0", "Tensor A stride"},
        {"stride_b", "0", "Tensor B stride"},
        {"stride_c", "0", "Tensor C stride"},
        {"split_k", "1", "splitK value"},
        {"v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU"},
        {"warmup", "50", "number of iterations before benchmark the kernel"},
        {"repeat", "100", "number of iterations to benchmark the kernel"},
        {"timer", "gpu", "gpu:gpu timer, cpu:cpu timer"},
        {"init", "0", "0:random, 1:linear, 2:constant(1)"},
        {"structured_sparsity", "0", "0:false, 1:true"},
        {"pipeline", "compv3", "compv3, compv4, mem"},
        {"scheduler", "intrawave", "intrawave, interwave"},
        {"epilogue", "cshuffle", "cshuffle, default"},
        {"pad_m", "false", "true, false"},
        {"pad_n", "false", "true, false"},
        {"pad_k", "false", "true, false"},
    };

    ck_tile::ArgParser arg_parser;
    for(const OptionSpec& option : option_table)
    {
        arg_parser.insert(option.name, option.default_value, option.help);
    }

    const bool parsed_ok = arg_parser.parse(argc, argv);
    return std::make_tuple(parsed_ok, arg_parser);
}
/// @brief Reorder the 4-bit values of a packed-int4 B tensor in place.
///
/// For every column, rows are processed in groups starting at every 8th row index. Within a
/// group the four elements at row offsets 0, 2, 4 and 6 are split into eight nibbles (high
/// nibble first), and the nibbles are written back in the order 01234567 -> 20643175.
///
/// @param tensor Tensor whose elements expose a `data` member and accept assignment from
///               int8_t; get_length(0) is used as K and get_length(1) as N.
template <typename Tensor>
void permute_vectors_i4x4_b(Tensor& tensor)
{
    const ck_tile::index_t K = tensor.get_length(0);
    const ck_tile::index_t N = tensor.get_length(1);

    // Output slot s (row offset 2 * s) takes nibble hi_src[s] as its upper four bits and nibble
    // lo_src[s] as its lower four bits.
    constexpr int hi_src[4] = {2, 6, 3, 7};
    constexpr int lo_src[4] = {0, 4, 1, 5};

    // vector pk_i4x4 permute
    for(int col = 0; col < N; col++)
    {
        for(int row = 0; row < K; row += 8)
        {
            // Unpack all eight nibbles before writing anything back, since the writes below
            // target the same elements that are read here.
            int8_t nibbles[8];
            for(int pair = 0; pair < 4; pair++)
            {
                int8_t packed         = tensor(row + pair * 2, col).data;
                nibbles[pair * 2 + 0] = (packed >> 4) & 0xf;
                nibbles[pair * 2 + 1] = (packed >> 0) & 0xf;
            }

            // permute 01234567->20643175
            for(int slot = 0; slot < 4; slot++)
            {
                int8_t hi                   = nibbles[hi_src[slot]];
                int8_t lo                   = nibbles[lo_src[slot]];
                int8_t repacked             = (hi << 4) | lo;
                tensor(row + slot * 2, col) = repacked;
            }
        }
    }
}
/// @brief Function to compare the results of the device and host computations
///
/// Derives the tolerances from the largest value in the host result, runs ck_tile::check_err
/// and prints the thresholds together with the verdict. The outcome is only printed, not
/// returned.
///
/// Declared `inline` because this is a non-template function defined in a header.
/// The ADataType / BDataType / AccDataType / CDataType aliases must already be declared when
/// this header is included.
///
/// @param K                 Reduction length of the GEMM.
/// @param kbatch            Split-K factor used for the device run.
/// @param c_m_n_dev_result  Result produced by the kernel under test.
/// @param c_m_n_host_result Reference result.
inline void compare(ck_tile::index_t K,
                    ck_tile::index_t kbatch,
                    ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                    ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
    const float max_accumulated_value =
        *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
    const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
        K, kbatch, max_accumulated_value);
    bool pass = ck_tile::check_err(c_m_n_dev_result,
                                   c_m_n_host_result,
                                   "Error: Incorrect results!",
                                   rtol_atol.at(ck_tile::number<0>{}),
                                   rtol_atol.at(ck_tile::number<1>{}));

    std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
              << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
    std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
///
/// verify == 1 computes the reference on the host with ck_tile::reference_gemm.
/// verify == 2 computes it on the device with ck_tile::reference_gemm_gpu into a temporary
/// buffer and copies the result into @p c_m_n_host_result.
/// Any other value leaves every argument untouched.
///
/// @note For BDataType == ck_tile::pk_int4_t and verify == 2, @p b_k_n_dev_buf is overwritten
///       with the contents of @p b_k_n. A caller that needs the device buffer in a different
///       layout afterwards has to upload it again.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
void gemm_host_reference(int verify,
                         ck_tile::HostTensor<ADataType>& a_m_k,
                         ck_tile::HostTensor<BDataType>& b_k_n,
                         ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                         ck_tile::DeviceMem& a_m_k_dev_buf,
                         ck_tile::DeviceMem& b_k_n_dev_buf,
                         ck_tile::index_t M,
                         ck_tile::index_t N,
                         ck_tile::index_t K,
                         ck_tile::index_t stride_A,
                         ck_tile::index_t stride_B,
                         ck_tile::index_t stride_C)
{
    // Host reference.
    if(verify == 1)
    {
        c_m_n_host_result.SetZero();
        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_m_k, b_k_n, c_m_n_host_result);
        return;
    }

    if(verify != 2)
    {
        return;
    }

    // Device reference.
    if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
    {
        // Restore input for B for gpu reference
        b_k_n_dev_buf.ToDevice(b_k_n.data());
    }

    ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
    c_m_n_host_result.SetZero();
    c_m_n_gpu_buf_ref.SetZero();

    auto* a_dev = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
    auto* b_dev = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
    auto* c_dev = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());

    ck_tile::reference_gemm_gpu<ADataType,
                                BDataType,
                                AccDataType,
                                CDataType,
                                ALayout,
                                BLayout,
                                CLayout>(
        a_dev, b_dev, c_dev, M, N, K, stride_A, stride_B, stride_C);

    // Bring the reference result back into the host tensor.
    c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
}

View File

@@ -0,0 +1,644 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
import sys
from typing import List, Optional, Dict, Any
import functools
import itertools
import copy
import json
from dataclasses import dataclass
# Short datatype name (as written in the JSON config) -> C++ type name that is
# spelled into the generated gemm_common.hpp.
DATA_TYPE_MAP = {'fp32' : 'float',
'fp16' : 'ck_tile::half_t',
'bf16' : 'ck_tile::bf16_t',
'int8' : 'ck_tile::int8_t',
'fp8' : 'ck_tile::fp8_t',
'bf8' : 'ck_tile::bf8_t',
'int4' : 'ck_tile::pk_int4_t'
}
# One-letter layout code from the JSON config -> ck_tile layout tag type.
LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor',
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
# C++ snippet that defines `GemmEpilogue` as ck_tile::DefaultGemm2DEpilogue.
# It is pasted verbatim into the generated launcher, so every identifier it uses
# (kPadM, WarpTileM, UniversalGemmProblem, ...) must exist at the paste site.
DEFAULT_EPILOGUE = """
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
kPadM,
kPadN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC>>;
"""
# C++ snippet that defines `GemmEpilogue` as ck_tile::CShuffleEpilogue.
# Same paste-site requirement as DEFAULT_EPILOGUE; it additionally refers to
# GemmPipelineProblem, TilePartitioner, WarpM and WarpN.
CSHUFFLE_EPILOGUE = """
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
WarpM,
WarpN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC>>;
"""
# C++ dispatch snippet emitted into the `else` branch of `if(has_hot_loop)`:
# instantiates Run with has_hot_loop = false for the Full, Odd or Even tail
# number and throws for any other value.
HOT_LOOP_FALSE = """
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Num K loop must be larger than number of prefetch stages.");
}
"""
# C++ dispatch snippet for the 'mem' pipeline when the hot loop is taken.
# One and Full are always dispatched; Two..Seven are only compiled in when
# BaseGemmPipeline::PrefetchStages > 2. Inside that block the tail numbers form
# one else-if chain, and the error is thrown only when the tail number matched
# none of One, Full, Two..Seven. An unconditional throw at the end of the block
# would fire even after a successful Run.
RUN_MEM = """
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
else if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
else if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
else if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
else if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
else if(tail_num != ck_tile::TailNumber::One && tail_num != ck_tile::TailNumber::Full)
{
throw std::runtime_error("The tail number is wrong! It should not exceed the prefetch stage numbers");
}
}
"""
# C++ dispatch snippet for the 'compv3' pipeline when the hot loop is taken:
# handles the Full, Odd and Even tail numbers and throws for anything else.
RUN_COMPV3 = """
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
}
"""
# C++ dispatch snippet for the 'compv4' pipeline when the hot loop is taken:
# TailNumber::Three gets its own instantiation; every other value is run with
# the TailNumber::Two instantiation.
RUN_COMPV4 = """
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
"""
# pipeline name -> [base pipeline template, pipeline template], both as C++ names.
PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
# scheduler name -> ck_tile::GemmPipelineScheduler enumerator.
SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave',
'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'}
# epilogue name -> C++ snippet that defines GemmEpilogue.
EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE,
'cshuffle' : CSHUFFLE_EPILOGUE}
# pipeline name -> tail-number dispatch snippet used in the has_hot_loop branch.
HOT_LOOP_TRUE = {'mem' : RUN_MEM,
'compv3' : RUN_COMPV3,
'compv4' : RUN_COMPV4}
def BOOL_MAP(b_) -> str:
    """Render a Python truth value as the C++ literal ``'true'`` or ``'false'``."""
    return 'true' if b_ else 'false'
@dataclass
class GemmConfig:
    """Parsed instance-combination config, split into two dictionaries.

    ``matrix_cfg`` holds the entries keyed "datatype", "layout_a", "layout_b"
    and "layout_c"; every other entry of the input goes to ``impl_cfg``.
    """

    def __init__(self, config_data):
        matrix_keys = ("datatype", "layout_a", "layout_b", "layout_c")
        self.matrix_cfg: Dict[str, Any] = {
            key: value for key, value in config_data.items() if key in matrix_keys
        }
        self.impl_cfg: Dict[str, Any] = {
            key: value for key, value in config_data.items() if key not in matrix_keys
        }

    @property
    def datatype(self) -> str:
        """First (index 0) value listed under "datatype"."""
        values = self.matrix_cfg["datatype"]["values"]
        return values[0]

    @property
    def layouts(self) -> List[str]:
        """First value of layout_a, layout_b and layout_c, in that order."""
        return [self.matrix_cfg[f"layout_{operand}"]["values"][0] for operand in "abc"]
class GemmCodeGenerator:
def __init__(self, output_dir: str, config: GemmConfig):
    """Remember the output directory and config, then validate the config.

    Args:
        output_dir: Directory that receives the generated files. It is created,
            together with any missing parent directories, when it does not
            exist yet.
        config: Parsed configuration exposing ``matrix_cfg`` and ``impl_cfg``.

    Raises:
        ValueError: Propagated from ``_validate_config`` for an invalid config.
    """
    self.output_dir = Path(output_dir)
    # parents=True: a nested working path must not fail on a missing parent.
    # exist_ok=True: no error if the directory already exists, including when
    # something else creates it between a check and the mkdir call.
    self.output_dir.mkdir(parents=True, exist_ok=True)
    self.config = config
    # Both lists are filled by _list_config_groups().
    self.all_kernels = []
    self.unique_configs = []
    # Validate configurations
    self._validate_config()
def _validate_config(self):
"""Validate matrix and implementation configurations"""
# Matrix config validation
for param in ["datatype", "layout_a", "layout_b", "layout_c"]:
if len(self.config.matrix_cfg[param]["values"]) != 1:
raise ValueError(f"Matrix config {param} must have exactly one value")
# Implementation traits validation
required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k",
"warp_tile_m", "warp_tile_n", "warp_tile_k", "pipeline",
"epilogue", "scheduler", "kPadM", "kPadN", "kPadK"]
for param in required_params:
if not self.config.impl_cfg.get(param, {}).get("values"):
raise ValueError(f"Missing implementation parameter: {param}")
def list_all(self):
    """Write the paths of all files that generation would produce.

    The list goes to ``gemm_instance_blobs.txt`` inside the output directory,
    one absolute-or-relative path per line (whatever ``output_dir`` was given
    as): the three fixed headers first, then one header per kernel group.
    """
    out_dir = Path(self.output_dir)
    self._list_config_groups()

    fixed_headers = ["gemm_common.hpp", "gemm_instances.hpp", "gemm_dispatcher.hpp"]
    group_headers = ["gemm_" + group + ".hpp" for group in self.all_kernels]

    with (out_dir / 'gemm_instance_blobs.txt').open('w') as blob_list:
        for header in fixed_headers + group_headers:
            blob_list.write(str(out_dir / header) + "\n")
def _list_config_groups(self):
    """Enumerate the supported (pipeline, epilogue, scheduler, padding) groups.

    For every supported combination, the group name is appended to
    ``self.all_kernels`` and the matching keyword dictionary to
    ``self.unique_configs``. Combinations are visited in the order of the
    values in the config, so the result is identical on every run. Iterating a
    ``set`` of string tuples instead would make the order depend on the
    per-process string hash seed.
    """
    names = ("pipeline", "epilogue", "scheduler", "kPadM", "kPadN", "kPadK")

    # To remove some unsupported combinations
    unsupported_combination = {
        ("compv3", "cshuffle", "interwave"),
        ("compv3", "default", "interwave"),
        ("compv4", "cshuffle", "interwave"),
        ("compv4", "default", "interwave"),
    }

    value_lists = [self.config.impl_cfg[name]["values"] for name in names]

    # dict.fromkeys drops duplicate combinations while keeping first-seen order.
    for combo in dict.fromkeys(itertools.product(*value_lists)):
        pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = combo
        if (pipeline, epilogue, scheduler) in unsupported_combination:
            continue
        group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}"
        self.all_kernels.append(group_name)
        self.unique_configs.append(dict(zip(names, combo)))
def generate_all(self):
    """Run the three generation steps in order: common header, config groups, dispatcher."""
    steps = (
        self._generate_common_header,
        self._generate_config_groups,
        self._generate_dispatcher,
    )
    for step in steps:
        step()
def _generate_common_header(self):
    """Write gemm_common.hpp with the data type and layout aliases.

    A, B and C all start out as the configured datatype. For fp8/bf8 the C type
    becomes fp16; for int4 both the A and C types become fp16 while B stays int4.
    """
    dtype = self.config.datatype
    atype = btype = ctype = dtype
    if dtype in ('fp8', 'bf8'):
        ctype = 'fp16'
    elif dtype == 'int4':
        atype = ctype = 'fp16'

    header = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
// Data types
using ADataType = {DATA_TYPE_MAP[atype]};
using BDataType = {DATA_TYPE_MAP[btype]};
using AccDataType = float;
using CDataType = {DATA_TYPE_MAP[ctype]};
// Layout configurations
using ALayout = {LAYOUT_MAP[self.config.layouts[0]]};
using BLayout = {LAYOUT_MAP[self.config.layouts[1]]};
using CLayout = {LAYOUT_MAP[self.config.layouts[2]]};
"""
    target = self.output_dir / "gemm_common.hpp"
    target.write_text(header)
def _generate_config_groups(self):
    """Write one header per kernel group, then the header that includes them all."""
    # Enumerate the groups first unless that has already been done.
    already_listed = bool(self.unique_configs)
    if not already_listed:
        self._list_config_groups()

    for group_cfg in self.unique_configs:
        self._generate_config_group(**group_cfg)

    self.generate_common_instances_header()
def _generate_config_group(self, pipeline: str, epilogue: str, scheduler: str,
                           kPadM: bool, kPadN: bool, kPadK: bool):
    """Write the header for one kernel group.

    The file is named ``gemm_<group>.hpp`` and wraps the kernel struct template
    in a namespace named after the group.
    """
    group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}"

    opening = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_common.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/host.hpp"
namespace {group_name} {{
"""
    # Template struct specialised for this group.
    kernel_struct = self._generate_kernel_struct(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK)
    closing = f"\n}} // namespace {group_name}\n"

    target = self.output_dir / f"gemm_{group_name}.hpp"
    target.write_text(opening + kernel_struct + closing)
# Returns the C++ source of a `GemmKernel` struct template as one string.
# The tile, warp and warp-tile sizes plus structured_sparsity stay C++ template
# parameters; pipeline, epilogue, scheduler and the three padding flags are
# substituted into the text here:
#   - DoubleSmemBuffer is "true" only for the 'compv4' pipeline;
#   - the has_hot_loop branch receives HOT_LOOP_TRUE[pipeline], the other
#     branch receives HOT_LOOP_FALSE;
#   - the epilogue definition comes from EPILOGUE_MAP[epilogue].
# Literal C++ braces are doubled ({{ }}) because the template is an f-string.
# NOTE(review): the text returned by the emitted get_name() contains the
# misspellings "Bllktile" and "PadidngM", and "{kPadM}"/"{kPadN}"/"{kPadK}"
# insert Python's str() of the flags (e.g. "True"/"False" for bools) rather
# than BOOL_MAP output -- confirm nothing parses that string before changing it.
def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str,
kPadM: bool, kPadN: bool, kPadK: bool) -> str:
"""Generate kernel struct template"""
return f"""
template <int TileM, int TileN, int TileK,
int WarpM, int WarpN, int WarpK,
int WarpTileM, int WarpTileN, int WarpTileK,
bool structured_sparsity>
struct GemmKernel {{
static constexpr bool kPadM = {BOOL_MAP(kPadM)};
static constexpr bool kPadN = {BOOL_MAP(kPadN)};
static constexpr bool kPadK = {BOOL_MAP(kPadK)};
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
static constexpr bool TransposeC = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<TileM, TileN, TileK>,
ck_tile::sequence<WarpM, WarpN, WarpK>,
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>,
permuteA,
permuteB>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>;
using Traits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * TileK;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{{0}};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {SCHEDULER_MAP[scheduler]};
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = {PIPELINE_MAP[pipeline][1]}<UniversalGemmProblem>;
{EPILOGUE_MAP[epilogue]}
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
if(s.log_level_ > 0)
{{
std::cout << "Launching kernel with args:"
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}
ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
Kernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}};
if(has_hot_loop) {{
{HOT_LOOP_TRUE[pipeline]}
}} else {{
{HOT_LOOP_FALSE}
}}
return ave_time;
}}
static std::string get_name() {{
return std::string("GemmKernel<Bllktile: ") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + ", " +
"WaveMap: " + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + ", " +
"WarpTile: " + std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + ", " +
"PadidngM: " + "{kPadM}" + ", " +
"PaddingN: " + "{kPadN}" + ", " +
"PaddingK: " + "{kPadK}" + ", " +
"Pipeline: " + "{pipeline}" + ", " +
"Epilogue: " + "{epilogue}" + ", " +
"Scheduler: " + "{scheduler}";
}}
}};
"""
def generate_common_instances_header(self):
    """Write the umbrella header ``gemm_instances.hpp`` into ``self.output_dir``.

    The file starts with the licence banner and ``#pragma once`` and then
    contains one ``#include "gemm_<group>.hpp"`` line for every entry
    obtained by iterating ``self.all_kernels``, in iteration order.
    """
    header_lines = [
        "// SPDX-License-Identifier: MIT",
        "// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.",
        "#pragma once",
    ]
    # One include per kernel group.
    header_lines.extend(f"#include \"gemm_{group}.hpp\"" for group in self.all_kernels)
    destination = self.output_dir / "gemm_instances.hpp"
    # Every line, including the last one, is newline-terminated.
    destination.write_text("\n".join(header_lines) + "\n")
def _generate_dispatcher(self):
"""Generate dispatch mechanism"""
content = """// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_common.hpp"
#include "gemm_instances.hpp"
#include "gemm_host_api.hpp"
#include <unordered_map>
#include <functional>
#include <vector>
struct GemmDispatcher {
static auto& get_kernel_map() {
// Use a static local variable
static std::unordered_map<std::string,
std::function<void(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
return kernel_map;
}
static void init(bool structured_sparsity) {
auto& kernel_map = get_kernel_map();
if(!kernel_map.empty()) return;
\n"""
# Add tile/warp instantiations
tile_params = set(itertools.product(
self.config.impl_cfg["tile_m"]["values"],
self.config.impl_cfg["tile_n"]["values"],
self.config.impl_cfg["tile_k"]["values"],
self.config.impl_cfg["warp_m"]["values"],
self.config.impl_cfg["warp_n"]["values"],
self.config.impl_cfg["warp_k"]["values"],
self.config.impl_cfg["warp_tile_m"]["values"],
self.config.impl_cfg["warp_tile_n"]["values"],
self.config.impl_cfg["warp_tile_k"]["values"]
))
for group in self.all_kernels:
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& stream) {{
"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
content += f"""
if(structured_sparsity) {{
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
}} else {{
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
}}"""
content += f"""
}};\n"""
content += """ }
template <typename Kernel>
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
{
float avg_time = Kernel::launch(args, stream);
std::string description = Kernel::get_name();
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_byte / 1.E6 / avg_time;
std::cout << "Performance for " << description << " : " << avg_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
if(verify)
compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
}
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, bool structured_sparsity, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
const ck_tile::stream_config& stream) {
init(structured_sparsity);
const std::string key = assemble_key(trait);
auto& kernel_map = get_kernel_map();
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream);
}
throw std::runtime_error("No suitable kernel found: " + key);
}
private:
static std::string assemble_key(const KernelTraits &trait) {
return std::string(trait.pipeline) + "_" +
trait.epilogue + "_" +
trait.scheduler + "_" +
"pad_" +
(trait.kPadM ? "true" : "false") + "_" +
(trait.kPadN ? "true" : "false") + "_" +
(trait.kPadK ? "true" : "false");
}
};
"""
(self.output_dir / "gemm_dispatcher.hpp").write_text(content)
def do_list_blobs(args, gemm_config):
    """Build a ``GemmCodeGenerator`` for ``args.working_path`` and call its ``list_all()``."""
    GemmCodeGenerator(args.working_path, gemm_config).list_all()
def do_gen_blobs(args, gemm_config):
    """Build a ``GemmCodeGenerator`` for ``args.working_path`` and call its ``generate_all()``."""
    GemmCodeGenerator(args.working_path, gemm_config).generate_all()
def main(args):
    """Load the JSON kernel configuration and run the requested mode.

    ``args`` is the parsed command line; this function reads ``args.json``,
    ``args.list_blobs`` and ``args.gen_blobs``.  ``--list_blobs`` takes
    precedence over ``--gen_blobs``; when neither flag is set a notice is
    printed and generation runs as the default.
    """
    # Parse the configuration file, then wrap it once the file is closed.
    with open(args.json, 'r') as config_file:
        raw_config = json.load(config_file)
    gemm_config = GemmConfig(raw_config)

    if args.list_blobs:
        handler = do_list_blobs
    else:
        if not args.gen_blobs:
            # If neither was specified, default to generating the blobs.
            print("No mode specified (use --list_blobs or --gen_blobs). Generating by default...")
        handler = do_gen_blobs
    handler(args, gemm_config)
# Script entry point: declare the command-line interface, parse it, run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog="generate", description="gen API for CK gemm kernel")

    # Directory that receives the generated files; defaults to "./".
    parser.add_argument(
        "-w",
        "--working_path",
        default="./",
        required=False,
        help="the path where all the blobs are going to be generated",
    )
    # Mandatory JSON file describing the kernel configurations.
    parser.add_argument(
        "-j",
        "--json",
        required=True,
        help="Path to the json which contains the kernel configurations",
    )
    # Mode switches consumed by main().
    parser.add_argument(
        "-l",
        "--list_blobs",
        action="store_true",
        help="List all kernel to file",
    )
    parser.add_argument(
        "-g",
        "--gen_blobs",
        action="store_true",
        help="Generate all kernels into different files",
    )

    args = parser.parse_args()
    main(args)