Merge remote-tracking branch 'origin/develop' into samremes/bmatrix_2d_blockscale

This commit is contained in:
Sami Remes
2025-10-28 17:49:53 +00:00
142 changed files with 12769 additions and 1430 deletions

3
.gitignore vendored
View File

@@ -36,6 +36,9 @@ tags
# Editors
.vscode
# Cline
.cline*
# build-in-source directory (see exceptions below)
build*

View File

@@ -1,9 +1,8 @@
ARG BASE_DOCKER="rocm/composable_kernel-private:ck_aiter_base"
ARG BASE_DOCKER="rocm/pytorch:latest"
FROM $BASE_DOCKER
ARG AITER_BRANCH="main"
ARG CK_AITER_BRANCH="develop"
RUN groupadd irc && \
pip install pandas zmq einops && \
RUN pip install pandas zmq einops ninja && \
pip install numpy==1.26.2 && \
sudo mkdir /home/jenkins && \
sudo mkdir /home/jenkins/workspace && \
@@ -14,6 +13,8 @@ RUN groupadd irc && \
rm -rf 3rdparty/composable_kernel/ && \
git clone -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \
python3 setup.py develop && \
chown -R jenkins:jenkins /home/jenkins/workspace && \
chmod -R a+rwx /home/jenkins/workspace && \
groupadd -g 1001 jenkins && \
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \
chown -R jenkins:jenkins /home/jenkins && \
chmod -R a+rwx /home/jenkins && \
sudo usermod -aG irc jenkins

101
Jenkinsfile vendored
View File

@@ -194,6 +194,33 @@ def check_arch(){
return arch_type
}
def check_arch_name(){
def arch_name = ""
sh 'rocminfo | tee rocminfo.log'
if ( runShell('grep -n "gfx90a" rocminfo.log') ){
arch_name = "gfx90a"
}
else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
arch_name = "gfx942"
}
else if ( runShell('grep -n "gfx10" rocminfo.log') ) {
arch_name = "gfx10"
}
else if ( runShell('grep -n "gfx11" rocminfo.log') ) {
arch_name = "gfx11"
}
else if ( runShell('grep -n "gfx12" rocminfo.log') ) {
arch_name = "gfx12"
}
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
arch_name = "gfx908"
}
else if ( runShell('grep -n "gfx950" rocminfo.log') ) {
arch_name = "gfx950"
}
return arch_name
}
def getDockerImage(Map conf=[:]){
env.DOCKER_BUILDKIT=1
def prefixpath = conf.get("prefixpath", "/opt/rocm")
@@ -302,12 +329,6 @@ def cmake_build(Map conf=[:]){
//cmake_env can overwrite default CXX variables.
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
def package_build = (conf.get("package_build","") == "true")
if (package_build == true) {
config_targets = "package"
}
if(conf.get("build_install","") == "true")
{
config_targets = 'install ' + config_targets
@@ -455,15 +476,20 @@ def cmake_build(Map conf=[:]){
else{
sh "ninja check"
}
if(params.BUILD_PACKAGES){
echo "Build ckProfiler packages"
sh 'ninja -j64 package'
def arch_name = check_arch_name()
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
}
}
if(params.BUILD_INSTANCES_ONLY){
// build deb packages
echo "Build packages"
echo "Build library package"
sh 'ninja -j64 package'
archiveArtifacts artifacts: 'composablekernel-dev*.deb'
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb'
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64.deb'
stash includes: "composablekernel-**.deb", name: "packages"
stash includes: "composablekernel-dev**.deb", name: "lib_package"
}
}
else{
@@ -475,15 +501,18 @@ def cmake_build(Map conf=[:]){
else{
sh "ninja check"
}
if(params.BUILD_PACKAGES){
echo "Build ckProfiler packages"
sh 'ninja -j64 package'
def arch_name = check_arch_name()
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
}
}
}
}
}
// Only archive from develop
if (package_build == true && env.BRANCH_NAME == "develop") {
archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
}
//check the node gpu architecture
def arch = check_arch()
if (params.RUN_CK_TILE_FMHA_TESTS){
@@ -823,9 +852,42 @@ def process_results(Map conf=[:]){
}
if (params.BUILD_INSTANCES_ONLY){
// unstash deb packages
unstash "packages"
try{
unstash "lib_package"
}
catch(Exception err){
echo "could not locate lib_package."
}
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
}
if (params.BUILD_PACKAGES){
// unstash deb packages
try{
unstash "profiler_package_gfx90a"
}
catch(Exception err){
echo "could not locate profiler_package_gfx90a."
}
try{
unstash "profiler_package_gfx942"
}
catch(Exception err){
echo "could not locate profiler_package_gfx942."
}
try{
unstash "profiler_package_gfx950"
}
catch(Exception err){
echo "could not locate profiler_package_gfx950."
}
try{
unstash "profiler_package_gfx12"
}
catch(Exception err){
echo "could not locate profiler_package_gfx12."
}
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-ckprofiler*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
}
else{
// unstash perf files to master
try{
@@ -993,7 +1055,7 @@ def run_pytorch_tests(Map conf=[:]){
//launch develop branch daily jobs
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true
0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
@@ -1085,6 +1147,10 @@ pipeline {
name: "BUILD_INSTANCES_ONLY",
defaultValue: false,
description: "Test building instances for various architectures simultaneously (default: OFF)")
booleanParam(
name: "BUILD_PACKAGES",
defaultValue: false,
description: "Build packages for the libraries and/or ckProfiler (default: OFF)")
booleanParam(
name: "BUILD_GFX908",
defaultValue: false,
@@ -1574,7 +1640,6 @@ pipeline {
-D GPU_TARGETS="gfx1201" \
-D GEMM_DATATYPE="fp16" \
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
-DGEMM_CONFIG_FILE=gfx120x_config.json \
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
ninja -j64 benchmark_gemm_all && \
python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \
@@ -1830,7 +1895,7 @@ pipeline {
stage("Process results"){
when {
beforeAgent true
expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() }
expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()|| params.BUILD_PACKAGES.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent { label 'mici' }
steps{

View File

@@ -52,7 +52,7 @@ struct kernel
template <class... Ts>
auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
{
return [=](auto&&... xs) {
return [=, this](auto&&... xs) {
launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
};
}

View File

@@ -59,4 +59,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return 0;
}
return !run_grouped_gemm_example(argc, argv);
}

View File

@@ -278,19 +278,20 @@ bool run_grouped_gemm_example(int argc, char* argv[])
problem_size.group_count = 16;
if(argc == 4)
if(argc == 1)
{
// use default cases
}
else if(argc == 4 || argc == 6)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.async_hargs = std::stoi(argv[4]);
problem_size.group_count = std::stoi(argv[5]);
if(argc == 6)
{
config.async_hargs = std::stoi(argv[4]);
problem_size.group_count = std::stoi(argv[5]);
}
}
else
{
@@ -299,18 +300,33 @@ bool run_grouped_gemm_example(int argc, char* argv[])
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: async hargs (0=n0, 1=yes)\n");
printf("arg5: group count (default=16)");
exit(0);
exit(1);
}
// Lambda to get stride based on layout
auto get_stride = [](auto layout, auto row_dim, auto col_dim) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return col_dim;
}
else
{
return row_dim;
}
};
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ns.push_back(128 + 128 * i);
problem_size.Ks.push_back(128 + 64 * i);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
problem_size.stride_As.push_back(
get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i]));
problem_size.stride_Bs.push_back(
get_stride(BLayout{}, problem_size.Ks[i], problem_size.Ns[i]));
problem_size.stride_Cs.push_back(
get_stride(ELayout{}, problem_size.Ms[i], problem_size.Ns[i]));
}
return run_grouped_gemm(problem_size, config);

View File

@@ -82,37 +82,29 @@ int main(int argc, char* argv[])
bool do_verification = true;
bool time_kernel = true;
ck::index_t M = 48 * 256;
ck::index_t N = 1024;
if(argc == 1)
{
// use default
}
else if(argc == 3)
else if(argc == 3 || argc == 5)
{
do_verification = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
if(argc == 5)
{
M = std::stoi(argv[3]);
N = std::stoi(argv[4]);
}
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: time kernel (0=no, 1=yes)\n");
exit(0);
}
ck::index_t M = 48 * 256;
ck::index_t N = 1024;
if(argc == 1)
{
// use default case
}
else if(argc == 3)
{
M = std::stoi(argv[1]);
N = std::stoi(argv[2]);
}
else
{
std::cerr << "arg1 to 2: M, N" << std::endl;
return 1;
printf("arg3-4: M, N\n");
exit(1);
}
ck::index_t Stride = N;

View File

@@ -182,6 +182,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = true;
};

View File

@@ -167,6 +167,113 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
GemmConfig::Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
if(splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
return ave_time;
}
#include "run_grouped_gemm_example.inc"
template <typename GemmConfig, typename PrecType>

View File

@@ -29,7 +29,7 @@ template <typename GemmConfig,
typename BQDataType,
typename AccDataType,
typename CDataType,
ck_tile::QuantType QuantMode>
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr)
@@ -48,8 +48,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false,
false,
false, // PreshuffleQuant
false, // PreshuffleB
ALayout,
BLayout,
CLayout,
@@ -67,18 +67,29 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using QuantGemmProblem = typename std::conditional<
QuantMode == ck_tile::QuantType::BQuantGrouped,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
128>, // QuantGroupSize
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>>::type;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;
using GemmPipeline =
typename std::conditional<QuantMode == ck_tile::QuantType::BQuantGrouped,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,

View File

@@ -11,6 +11,7 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -41,6 +42,14 @@ struct GemmTypeConfig<ck_tile::fp8_t>
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
struct GemmConfigBase
{
@@ -77,24 +86,11 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int kBlockPerCu = 1;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
@@ -122,8 +118,7 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("quant_mode", "tensor", "Choose tensor (default), or rowcol");
;
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -43,8 +43,8 @@ template <typename GemmConfig,
typename BLayout,
typename BQLayout,
typename CLayout,
ck_tile::QuantType QuantMode,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
@@ -159,11 +159,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
};
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
bool validate = arg_parser.get_bool("validate");
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
bool validate = arg_parser.get_bool("validate");
const ck_tile::index_t QuantGroupSize = 128;
if(kbatch > 1 && validate && warmup + repeat > 1)
{
@@ -172,9 +173,11 @@ int run_grouped_gemm_example_with_layouts(int argc,
validate = false;
}
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> AQs; // dimension of AQ tensor is calculated from A tensor
std::vector<ck_tile::index_t> BQs; // dimension of BQ tensor is calculated from B tensor
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
@@ -252,6 +255,15 @@ int run_grouped_gemm_example_with_layouts(int argc,
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
}
}
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
@@ -289,6 +301,13 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
}
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
@@ -394,6 +413,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors[i],
c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
@@ -441,42 +471,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
QuantMode>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType,
QuantMode>(
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType,
QuantMode>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType,
QuantMode>(
argc, argv, Col{}, Col{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
@@ -513,6 +507,41 @@ int run_grouped_gemm_example(int argc, char* argv[])
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported quantization mode!");
}
}
if(data_type == "bf8")
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported quantization mode!");

View File

@@ -70,23 +70,13 @@ float invoke_gemm(int n_warmup,
}
else
{
if(GemmConfig::Preshuffle)
{
// not supported yet
throw std::runtime_error(
"Persistent grouped gemm with preshuffle is not supported yet");
}
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse
// commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5,
// 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve
// commentCode has comments. Press enter to view. the gemm problems known on the host.
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have the gemm
// problems known on the host. Instead, we can just pass the pointer to the kernel and let
// the workgroups figure out which tiles to work on. This is useful when the gemm problems
// are generated dynamically. In this example however, we generate the `kargs` using the
// known gemm_descs, and copy the gemm descriptions to the device memory. The contents of
// the memory pointed to by `kargs_ptr` pointer could be written by e.g. another kernel from
// earlier stage.
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();

View File

@@ -4,6 +4,9 @@ list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_fwd_bias_clamp EXCLUDE_FROM_ALL grouped_convolution_forward_bias_clamp.cpp)
target_compile_options(tile_example_grouped_conv_fwd_bias_clamp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_bias_clamp_example.inc"
template <template <typename PrecType> typename GemmConfig>
int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");
std::string wei_layout = arg_parser.get_str("wei_layout");
std::string out_layout = arg_parser.get_str("out_layout");
if(data_type == "fp16")
{
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_fwd_bias_clamp_example<GemmConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_fwd_bias_clamp_example<GemmConfigComputeV3>(argc, argv);
#endif
}

View File

@@ -15,10 +15,10 @@ struct GroupedConvolutionForwardInvoker
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDElementWise = ck_tile::element_wise::PassThrough>
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDElementWise>& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
@@ -49,7 +49,8 @@ struct GroupedConvolutionForwardInvoker
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
VectorSizeC,
CDElementWise>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
@@ -128,7 +129,7 @@ struct GroupedConvolutionForwardInvoker
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,

View File

@@ -0,0 +1,301 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
using BiasAndClamp = ck_tile::element_wise::
Compose<ck_tile::element_wise::MultiDAdd, ck_tile::element_wise::Clamp, true>;
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
float invoke_grouped_conv_fwd_bias_clamp(const ck_tile::GroupedConvFwdHostArgs<BiasAndClamp>& args,
int n_warmup,
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_fwd<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
ck_tile::tuple<OutDataType>,
ck_tile::tuple<OutLayout>,
BiasAndClamp>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = args.GetFlops();
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
int run_grouped_conv_fwd_bias_clamp_example_with_layouts(
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using AccDataType = float;
std::vector<ck_tile::index_t> filter_spatial_lengths;
std::vector<ck_tile::index_t> image_spatial_lengths;
std::vector<ck_tile::index_t> strides;
std::vector<ck_tile::index_t> dilations;
std::vector<ck_tile::index_t> lpads;
std::vector<ck_tile::index_t> rpads;
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads,
arg_parser);
ck_tile::conv::ConvParam conv_param{num_dim_sp,
arg_parser.get_int("g"),
arg_parser.get_int("n"),
arg_parser.get_int("k"),
arg_parser.get_int("c"),
filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads};
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
const float floor = -100.f;
const float ceil = 100.f;
const ck_tile::element_wise::MultiDAdd bias_op{};
const ck_tile::element_wise::Clamp clamp_op{floor, ceil};
const BiasAndClamp bias_clamp_op{bias_op, clamp_op};
const auto in_g_n_c_wis_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
ck_tile::HostTensor<OutDataType> bias(out_g_n_k_wos_desc);
std::string bias_str = "";
if(init_method == 0)
{
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{-5.f, 5.f}(weight);
ck_tile::FillUniformDistribution<OutDataType>{-5.f, 5.f}(bias);
bias_str = " (Uniform(-5,5))";
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<InDataType>{}(input);
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
ck_tile::FillMonotonicSeq<OutDataType>{}(bias);
bias_str = " (Monotonic)";
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(bias);
bias_str = " (Constant 1)";
}
else
{
input.SetZero();
weight.SetZero();
bias.SetZero();
}
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_dev_buf(bias.get_element_space_size_in_bytes());
input_dev_buf.ToDevice(input.data());
weight_dev_buf.ToDevice(weight.data());
output_dev_buf.SetZero();
bias_dev_buf.ToDevice(bias.data());
ck_tile::GroupedConvFwdHostArgs<BiasAndClamp> args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{bias_dev_buf.GetDeviceBuffer()},
output_dev_buf.GetDeviceBuffer(),
kbatch,
bias_clamp_op);
std::cout << "Run Grouped Conv Fwd kernel with bias" << bias_str << " and clamp (" << floor
<< ", " << ceil << ")." << std::endl;
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_fwd_bias_clamp<NDimSpatial,
GemmWarpConfig,
Invoker,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
output_dev_buf.FromDevice(output.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
// FIXME: Address this issue
if(arg_parser.get_int("g") > 1 && init_method == 0)
std::cerr << "Adding different bias to different groups yield incorrect results"
<< std::endl;
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
output_host_ref.SetZero();
auto bias_clamp_host = [floor,
ceil](float& y, const float& x, const OutDataType& element_bias) {
float x_float = ck_tile::type_convert<float>(x);
x_float += ck_tile::type_convert<float>(element_bias);
if(x_float < floor)
x_float = floor;
else if(x_float > ceil)
x_float = ceil;
y = x_float;
};
auto bias_tuple = ck_tile::make_tuple(bias);
ck_tile::reference_grouped_conv_fwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
decltype(bias_clamp_host)>(
input,
weight,
output_host_ref,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
bias_clamp_host,
bias_tuple);
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const float max_accumulated_value =
*std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end());
const auto rtol_atol =
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
GemmK, kbatch, max_accumulated_value);
pass = ck_tile::check_err(output,
output_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
throw std::runtime_error("Unsupported gpu verification !!!");
}
return pass;
}
template <typename Invoker,
typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_fwd_bias_clamp_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
// using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
// using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
// using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
// FIXME: Fix crash in 1D convolution whem using Ds tensor.
throw std::runtime_error("1D Convolution does not support bias.");
// return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
// GemmWarpConfig,
// Invoker,
// InPrecType,
// WeiPrecType,
// OutPrecType>(
// argc, argv, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
}
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}

View File

@@ -12,7 +12,7 @@ template <ck_tile::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout>
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
int n_warmup,
int n_repeat)
{
@@ -128,12 +128,12 @@ int run_grouped_conv_fwd_example_with_layouts(
weight_dev_buf.ToDevice(weight.data());
output_dev_buf.SetZero();
ck_tile::GroupedConvFwdHostArgs args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);
ck_tile::GroupedConvFwdHostArgs<> args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);
std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
std::cout << "input: " << input.mDesc << std::endl;

View File

@@ -5,7 +5,7 @@
#include <random>
#include <stdexcept>
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/shuffle_utils.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
template <typename GemmConfig,
typename TypeConfig,

View File

@@ -1,5 +1,10 @@
if(GPU_TARGETS MATCHES "gfx9")
add_executable(tile_example_streamk_gemm_basic EXCLUDE_FROM_ALL streamk_gemm_basic.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(tile_example_streamk_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile streamk gemm tests for current target")
endif()

View File

@@ -28,10 +28,10 @@ args:
-stride_b tensor B stride (default:0)
-stride_c tensor C stride (default:0)
-v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
-prec data type. fp16/bf16 (default:fp16)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmarking the kernel (default:50)
-repeat number of iterations to benchmark the kernel (default:100)
-timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu)
-init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0)
-flush_cache flush the cache before running the kernel (default:true)
```
```

View File

@@ -75,6 +75,18 @@ struct DataTypeTraits<ck_tile::bf16_t>
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -94,7 +106,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmarking the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -56,7 +56,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
GemmUniversalTraits,
GemmConfig::Scheduler>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -187,6 +187,18 @@ int run_gemm_example(int argc, char* argv[])
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");

View File

@@ -28,3 +28,4 @@ add_subdirectory(38_block_scale_gemm)
add_subdirectory(39_copy)
add_subdirectory(40_streamk_gemm)
add_subdirectory(41_batched_contraction)

View File

@@ -23,9 +23,18 @@ This project is a prototype for a more general builder pattern for all of compos
To enable the experimental builder, configure your build with:
```sh
cmake -DCK_EXPERIMENTAL_BUILDER=ON -DCMAKE_CXX_STANDARD=20 ...
```bash
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx942;gfx950" \
-D CK_EXPERIMENTAL_BUILDER=ON \
-D CMAKE_CXX_STANDARD=20 \
-G Ninja \
..
```
## Building and testing
During development, build and test from the CK build directory with

View File

@@ -0,0 +1,143 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/sequence.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
// Convert a static array to a sequence
// Usage example:
// static constexpr std::vector arr {1, 2, 3};
// using seq = to_sequence_v<arr>; // seq is ck::Sequence<1, 2, 3>
template <typename T, const T& Arr>
struct to_sequence_t
{
private:
template <std::size_t... Is>
static auto get_sequence_type(std::index_sequence<Is...>) -> ck::Sequence<Arr[Is]...>;
// Helper method to handler the unusual .Size() method name in ck::Array.
static constexpr auto get_size(const auto& arr)
{
if constexpr(requires { arr.size(); })
{
return arr.size();
}
else
{
return arr.Size();
}
}
public:
using value = decltype(get_sequence_type(std::make_index_sequence<get_size(Arr)>{}));
};
template <auto& Arr>
using to_sequence_v = typename to_sequence_t<std::remove_cvref_t<decltype(Arr)>, Arr>::value;
// Wrapper function to make constexpr strings a structural type for NTTP.
template <size_t N>
struct StringLiteral
{
char data[N];
constexpr StringLiteral(const char (&str)[N])
{
for(size_t i = 0; i < N; ++i)
data[i] = str[i];
}
constexpr bool operator==(const StringLiteral<N>& other) const
{
for(size_t i = 0; i < N; ++i)
{
if(data[i] != other.data[i])
{
return false;
}
}
return true;
}
};
// This is a C++17 deduction guide. It allows the compiler to automatically
// deduce the template argument `N` for `StringLiteral` from a string literal
// constructor argument. For example, you can write `StringLiteral s{"foo"};`
// instead of `StringLiteral<4> s{"foo"};`.
template <size_t N>
StringLiteral(const char (&)[N]) -> StringLiteral<N>;
// Helper to provide a readable error for unsupported enum values.
// The compiler will print the name of this struct in the error message, so
// the name of the enum value will appear instead of just its integer value.
template <auto T>
struct UnsupportedEnumValue
{
};
// Helper functions to convert enums to strings
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
{
switch(dir)
{
case ConvDirection::FORWARD: return "Forward";
case ConvDirection::BACKWARD_DATA: return "Backward Data";
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
default: return "Unknown";
}
}
constexpr std::string_view DataTypeToString(DataType dt)
{
switch(dt)
{
case DataType::FP16: return "FP16";
case DataType::FP32: return "FP32";
case DataType::BF16: return "BF16";
case DataType::FP8: return "FP8";
case DataType::I8: return "I8";
case DataType::U8: return "U8";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
{
switch(layout)
{
case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
{
switch(layout)
{
case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
{
switch(layout)
{
case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
default: return "Unknown";
}
}
} // namespace ck_tile::builder

View File

@@ -0,0 +1,141 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include <concepts>
#include <array>
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
/********************************************************************/
/* Descriptors for individual elements of the algorithm description */
/********************************************************************/
// Concept for thread block dimensions for a GEMM problem.
template <typename T>
concept ThreadBlockDescriptor = requires(T t) {
{ t.block_size } -> std::convertible_to<size_t>;
{ t.tile_size.m } -> std::convertible_to<size_t>;
{ t.tile_size.n } -> std::convertible_to<size_t>;
{ t.tile_size.k } -> std::convertible_to<size_t>;
};
// Concept for parameters that describe a gridwise GEMM problem.
template <typename T>
concept GridwiseGemmDescriptor = requires(T t) {
{ t.ak1 } -> std::convertible_to<size_t>;
{ t.bk1 } -> std::convertible_to<size_t>;
{ t.m_per_xdl } -> std::convertible_to<size_t>;
{ t.n_per_xdl } -> std::convertible_to<size_t>;
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
};
// Concept for vectorized data transfer for convolution input tensors.
template <typename T>
concept BlockTransferDescriptor = requires(T t) {
{ t.k0 } -> std::convertible_to<size_t>;
{ t.m_n } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
};
// Concept for thread cluster dimensions for GEMM output tensor.
template <typename T>
concept ThreadClusterDescriptor = requires(T t) {
{ t.m_block } -> std::convertible_to<size_t>;
{ t.m_wave_per_xdl } -> std::convertible_to<size_t>;
{ t.n_block } -> std::convertible_to<size_t>;
{ t.n_wave_per_xdl } -> std::convertible_to<size_t>;
};
// Concept for the LDS transfer for the convolution input tensors.
template <typename T>
concept LdsTransferDescriptor = requires(T t) {
{ t.src_vector_dim } -> std::convertible_to<size_t>;
{ t.src_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.is_direct_load } -> std::convertible_to<bool>;
{ t.lds_padding } -> std::convertible_to<bool>;
};
// Concept for the convolution output tensor epilogue (copy from registers to global memory via
// LDS).
template <typename T>
concept EpilogueDescriptor = requires(T t) {
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.n_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept for the thread cluster access order
template <typename T>
concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
};
// No requirements yet for a ConvAlogorithm concept.
template <typename T>
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
/******************************************** */
/* Requirements for the algorithm description */
/******************************************** */
// Concept to check if struct specifies thread block info.
template <typename T>
concept SpecifiesThreadBlock = requires {
{ T::thread_block } -> ThreadBlockDescriptor;
};
// Concept to check if a struct specifies gridwise GEMM info.
template <typename T>
concept SpecifiesGridwiseGemm = requires {
{ T::gridwise_gemm } -> GridwiseGemmDescriptor;
};
// Concept to check if a struct specifies convolution input and output block transfer info.
template <typename T>
concept SpecifiesBlockTransfer = requires(T t) {
{ T::block_transfer.block_transfer_a } -> BlockTransferDescriptor;
{ T::block_transfer.block_transfer_b } -> BlockTransferDescriptor;
{ T::block_transfer.thread_cluster_dims_c } -> ThreadClusterDescriptor;
};
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
template <typename T>
concept SpecifiesLdsTransfer = requires(T t) {
{ T::block_transfer.lds_transfer_a } -> LdsTransferDescriptor;
{ T::block_transfer.lds_transfer_b } -> LdsTransferDescriptor;
{ T::block_transfer.epilogue_c } -> EpilogueDescriptor;
};
// Concept to check if a struct specifies thread cluster access order info.
template <typename T>
concept SpecifiesThreadClusterAccessOrder = requires(T t) {
{ T::block_transfer.block_transfer_access_order_a } -> AccessOrderDescriptor;
{ T::block_transfer.block_transfer_access_order_b } -> AccessOrderDescriptor;
};
// Concept to check if a struct specifies source access order info.
template <typename T>
concept SpecifiesSourceAccessOrder = requires(T t) {
{ T::block_transfer.src_access_order_a } -> AccessOrderDescriptor;
{ T::block_transfer.src_access_order_b } -> AccessOrderDescriptor;
};
// Concept to check if struct specifies block_gemm_pipeline_version.
template <typename T>
concept SpecifiesGemmPipelineVersion = requires {
{ T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
};
template <typename T>
concept SpecifiesFwdConcSpecialization = requires {
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include <concepts>
namespace ck_tile::builder {
// Limits for input vector transfer.
template <auto Value>
concept InputVectorTransferLimits = requires {
requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
Value.lds_dst_scalar_per_vector > 0;
};
// Limits for output vector transfer.
template <auto Value>
concept OutputVectorTransferLimits = requires {
requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
Value.n_xdl_per_wave_per_shuffle > 0;
};
// Limits for access order. Must be a permutation of {0, 1, 2}.
template <auto Value>
concept AccessOrderLimits = requires {
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
(Value[2] >= 0 && Value[2] < 3));
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/conv_factory.hpp"
#include "ck_tile/builder/versions.hpp"
namespace ck_tile::builder {
/**
* @brief Top-level builder for creating convolution kernel instances.
*
* This struct serves as the main entry point for generating a convolution kernel.
* It uses a factory pattern based on the provided signature, algorithm, and version
* to construct the appropriate kernel instance.
*
* @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of
* the algorithm (e.g., data types, layouts, direction).
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
* @tparam VERSION The version of the builder implementation.
*/
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION = LATEST_API_VERSION>
requires SupportedVersion<VERSION> && ValidConvSignature<SIGNATURE>
struct ConvBuilder
{
static constexpr auto kVersion = VERSION;
using Factory = ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
// Output: The kernel class.
using Instance = Factory::Instance;
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,539 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// A factory for instantiating CK convolution kernels.
//
// This file translates a semantic description of a convolution operation
// (`ConvSignatureDescriptor` and `ConvAlgorithmDescriptor`) into specific,
// low-level template arguments required by the underlying CK device-level
// kernel implementations. This abstraction enables more complex build
// time logic and simplifies the kernel specification.
//
// Key Components:
//
// Template Metaprogram:
// - ConvFactory: The main factory, with specializations for different
// convolution directions (currently only forward).
//
// Template Metaprogram Helpers:
// - ConvTensorLayouts: Maps layout enums to CK layout types for different
// spatial dimensions (2D/3D) and directions.
// - ConvTensorTypes: Maps data type enums (FP16, BF16, FP32) to C++ types used by CK.
// - ConvPassThroughOps: Hard-coded pass-through element-wise operations.
// - ConvSpec: Encapsulates convolution and GEMM specialization enums.
//
// `constexpr` Helper Functions:
// - SetThreadBlockInfo: Determines thread block dimensions and tile sizes.
// - SetConvTuningInfo: Sets XDL and AK1/BK1 tuning parameters.
// - SetFwdConvABlockTransfer: Configures A tensor block transfer parameters.
// - SetFwdConvBBlockTransfer: Configures B tensor block transfer parameters.
// - SetCBlockTransfer: Configures C tensor block transfer parameters.
// - SetBlockGemmPipelineVersion: Maps pipeline version enum to CK types.
//
// The primary entry point is the `ConvFactory` struct, which is currently
// specialized for forward convolutions and produces instances of
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/versions.hpp"
namespace ck_tile::builder::factory_internal {
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
struct ConvTensorLayouts
{
// This will trigger if a specialization for the given layout is not found.
// We should always catch this in an earlier validation check.
using Layout = decltype(LayoutValue);
static_assert(sizeof(Layout) == 0,
"Internal error. Unsupported layout for convolution factory.");
};
// 1D Forward Convolution Layout Specializations
template <>
struct ConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NWGC;
using BLayout = ck::tensor_layout::convolution::GKXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NWGK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKXC_NGKW, 1, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCW;
using BLayout = ck::tensor_layout::convolution::GKXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::GNWC;
using BLayout = ck::tensor_layout::convolution::GKXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::GNWK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKCX_NGKW, 1, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCW;
using BLayout = ck::tensor_layout::convolution::GKCX;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKYXC_NGKHW, 2, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCHW;
using BLayout = ck::tensor_layout::convolution::GKYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKHW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NHWGC;
using BLayout = ck::tensor_layout::convolution::GKYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NHWGK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::GNHWC;
using BLayout = ck::tensor_layout::convolution::GKYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::GNHWK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKCYX_NGKHW, 2, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCHW;
using BLayout = ck::tensor_layout::convolution::GKCYX;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKHW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, 3, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCDHW;
using BLayout = ck::tensor_layout::convolution::GKCZYX;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKDHW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NDHWGC;
using BLayout = ck::tensor_layout::convolution::GKZYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NDHWGK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::GNDHWC;
using BLayout = ck::tensor_layout::convolution::GKZYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::GNDHWK;
};
// Type mappings from builder convolution data type to CK tensor types.
template <DataType T>
struct ConvTensorTypes
{
// This will trigger if a specialization for the given DataType is not found.
// We should always catch this in an earlier validation check.
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
"Internal error. Unsupported data type for convolution factory.");
};
template <>
struct ConvTensorTypes<DataType::FP16>
{
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CShuffleDataType = ck::half_t;
using DsDataTypes = ck::Tuple<>;
using AccDataType = float;
using EDataType = ck::half_t;
};
template <>
struct ConvTensorTypes<DataType::BF16>
{
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CShuffleDataType = ck::bhalf_t;
using DsDataTypes = ck::Tuple<>;
using AccDataType = float;
using EDataType = ck::bhalf_t;
};
template <>
struct ConvTensorTypes<DataType::FP32>
{
using ADataType = float;
using BDataType = float;
using CShuffleDataType = float;
using DsDataTypes = ck::Tuple<>;
using AccDataType = float;
using EDataType = float;
};
template <ElementwiseOperation T>
struct ElementwiseOps
{
// This will trigger if a specialization for the given DataType is not found.
// We should always catch this in an earlier validation check.
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
"Internal error. Unsupported elementwise operation for convolution factory.");
};
template <>
struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
{
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
};
// The algorithm specializations for the convolution and GEMM.
template <typename CONV_ENUM>
requires(
std::is_same_v<CONV_ENUM, ck::tensor_operation::device::ConvolutionForwardSpecialization>)
struct ConvSpec
{
CONV_ENUM conv_spec;
ck::tensor_operation::device::GemmSpecialization gemm_spec;
};
// Deduction guide for ConvSpec to simplify brace initialization.
template <typename CONV_ENUM, typename GEMM_ENUM>
ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec<CONV_ENUM>;
// Block info for a convolution.
struct MNK
{
size_t m{};
size_t n{};
size_t k{};
};
struct ConvBlock
{
size_t block_size = 0;
MNK per_block = {};
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr ConvBlock SetThreadBlockInfo()
{
constexpr auto& TB = ALGORITHM.thread_block;
return ConvBlock{.block_size = TB.block_size,
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}};
}
// Convolution tuning parameters.
struct GridwiseGemm
{
size_t ak1 = 0;
size_t bk1 = 0;
size_t m_per_xdl = 0;
size_t n_per_xdl = 0;
size_t m_xdl_per_wave = 0;
size_t n_xdl_per_wave = 0;
};
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr GridwiseGemm SetGridwiseGemmInfo()
{
constexpr auto& TP = ALGORITHM.gridwise_gemm;
return GridwiseGemm{
.ak1 = TP.ak1,
.bk1 = TP.bk1,
.m_per_xdl = TP.m_per_xdl,
.n_per_xdl = TP.n_per_xdl,
.m_xdl_per_wave = TP.m_xdl_per_wave,
.n_xdl_per_wave = TP.n_xdl_per_wave,
};
}
// Block transfer parameters for A or B tensor.
struct BlockTransfer
{
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
bool is_direct_load = false;
bool lds_padding = false;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetFwdConvABlockTransfer()
{
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_a;
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_a;
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_a;
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_a;
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
.src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
.src_vector_dim = LDS.src_vector_dim,
.src_scalar_per_vector = LDS.src_scalar_per_vector,
.lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
.is_direct_load = LDS.is_direct_load,
.lds_padding = LDS.lds_padding};
return block_transfer;
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetFwdConvBBlockTransfer()
{
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_b;
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_b;
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_b;
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_b;
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
.src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
.src_vector_dim = LDS.src_vector_dim,
.src_scalar_per_vector = LDS.src_scalar_per_vector,
.lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
.is_direct_load = LDS.is_direct_load,
.lds_padding = LDS.lds_padding};
return block_transfer;
}
// Block transfer parameters for C tensor.
struct CBlockTransfer
{
size_t m_xdl_per_wave_per_shuffle = 0;
size_t n_xdl_per_wave_per_shuffle = 0;
ck::Array<size_t, 4> thread_cluster_dims = {0, 0, 0, 0};
size_t scalar_per_vector = 0;
};
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr CBlockTransfer SetCBlockTransfer()
{
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c;
constexpr auto& EPC = ALGORITHM.block_transfer.epilogue_c;
CBlockTransfer block_transfer{.m_xdl_per_wave_per_shuffle = EPC.m_xdl_per_wave_per_shuffle,
.n_xdl_per_wave_per_shuffle = EPC.n_xdl_per_wave_per_shuffle,
.thread_cluster_dims =
{
TCL.m_block,
TCL.m_wave_per_xdl,
TCL.n_block,
TCL.n_wave_per_xdl,
},
.scalar_per_vector = EPC.scalar_per_vector};
return block_transfer;
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.pipeline_version;
if constexpr(version == BlockGemmPipelineVersion::V1)
{
return ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(version == BlockGemmPipelineVersion::V3)
{
return ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(version == BlockGemmPipelineVersion::V4)
{
return ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(version == BlockGemmPipelineVersion::V5)
{
return ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
}
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
{
constexpr auto specialization = ALGORITHM.fwd_specialization;
if constexpr(specialization == ConvFwdSpecialization::DEFAULT)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
}
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
}
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
}
else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3;
}
else
{
static_assert(false, "Unknown ConvFwdSpecialization");
}
}
} // namespace ck_tile::builder::factory_internal
namespace ck_tile::builder {
// Primary template for the convolution factory.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
auto VERSION>
struct ConvFactory;
// Factory specialization for an instance of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts =
factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
// Check preconditions for the algorithm description.
static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
"Only 2D and 3D convolutions are supported in this factory.");
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseGemm<AlgorithmType>,
"The convolution algorithm descriptor must specify gridwise GEMM info.");
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify block transfer info.");
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify LDS transfer info.");
static_assert(
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify thread cluster access order info.");
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify source access order info.");
static_assert(SpecifiesGemmPipelineVersion<AlgorithmType>,
"The convolution algorithm descriptor must specify block gemm pipeline version.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr factory_internal::ConvSpec SPECIALIZATION{
.conv_spec = FWD_CONV_SPECIALIZATION,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM =
factory_internal::SetGridwiseGemmInfo<SIGNATURE, ALGORITHM>();
static constexpr auto A_BLOCK_TRANSFER =
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
static constexpr auto B_BLOCK_TRANSFER =
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
static constexpr auto C_BLOCK_TRANSFER =
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION =
factory_internal::SetBlockGemmPipelineVersion<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
PIPELINE_SCHEDULER,
PIPELINE_VERSION>;
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,74 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This file defines the compile-time "signature" for grouped convolution operations.
// A signature is a collection of properties that fully describe a convolution kernel's
// mathematical characteristics. It uses C++20 concepts and enums to specify these
// properties, enabling compile-time validation and specialization.
//
// The core components of a signature are:
// - Spatial dimensionality (1D, 2D, 3D)
// - Operational direction (Forward, Backward Data, Backward Weight)
// - Tensor memory layout (Channels First/Last)
// - Data type (FP32, FP16, BF16)
// - Fused element-wise operation (e.g., Bias, Clamp)
//
// The file also provides predicate concepts to query the properties of a given
// signature at compile time.
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
// Constrains convolution to 1D, 2D, or 3D spatial dimensions.
template <auto N>
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);
// Constraints for forward convolution layouts.
template <auto LayoutValue, size_t SpatialDim>
concept ValidConvLayoutForSpatialDim =
(SpatialDim == 1 && std::same_as<decltype(LayoutValue), GroupConvLayout1D>) ||
(SpatialDim == 2 && std::same_as<decltype(LayoutValue), GroupConvLayout2D>) ||
(SpatialDim == 3 && std::same_as<decltype(LayoutValue), GroupConvLayout3D>);
// Constrains convolution data types to common floating-point types.
template <DataType T>
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
// Concept for a type that defines a convolution's operational signature.
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
{ t.direction } -> std::convertible_to<ConvDirection>;
requires std::convertible_to<decltype(t.layout), GroupConvLayout1D> ||
std::convertible_to<decltype(t.layout), GroupConvLayout2D> ||
std::convertible_to<decltype(t.layout), GroupConvLayout3D>;
{ t.data_type } -> std::convertible_to<DataType>;
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
};
// Concept to validate a convolution signature's values.
template <auto Sig>
concept ValidConvSignature = requires {
requires ConvSpatialDim<Sig.spatial_dim>;
requires ConvDataType<Sig.data_type>;
};
// Predicate for forward convolution.
template <auto Sig>
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
// Predicate for backward data convolution.
template <auto Sig>
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
// Predicate for backward weight convolution.
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
} // namespace ck_tile::builder

View File

@@ -16,25 +16,19 @@
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck_tile/ops/common/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
namespace ck_tile::reflect::detail {
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
template <typename Seq>
struct SequenceToArray;
template <ck::index_t... Is>
struct SequenceToArray<ck::Sequence<Is...>>
{
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
};
// Convert data types to string names
// Implementation detail for type name mapping
// This is the single source of truth for supported data types that
// returns an empty string to indicate an unsupported type.
namespace impl {
template <typename T>
consteval std::string_view type_name()
consteval std::string_view type_name_impl()
{
if constexpr(std::is_same_v<T, ck::half_t>)
return "fp16";
@@ -55,20 +49,38 @@ consteval std::string_view type_name()
else if constexpr(std::is_same_v<T, ck::bf8_t>)
return "bf8";
else
static_assert(false, "unknown_type");
return std::string_view{}; // Return empty for supported types
}
} // namespace impl
// Convert data types to string names
// Fails at compile time for unsupported types
template <typename T>
consteval std::string_view type_name()
{
constexpr auto name = impl::type_name_impl<T>();
static_assert(!name.empty(), "Unsupported data type");
return name;
}
// Convert layout types to string names
// Concept that checks if a type is a valid data type
// Uses the impl directly to avoid triggering static_assert during concept evaluation
template <typename T>
concept IsDataType = !impl::type_name_impl<T>().empty();
// Concept that checks valid layout types
template <typename T>
concept IsLayoutType = (std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
requires {
{ T::name } -> std::convertible_to<std::string_view>;
};
// Convert layout types to string names
template <IsLayoutType T>
constexpr std::string_view layout_name()
{
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
{ T::name } -> std::convertible_to<std::string_view>;
})
return T::name;
else
static_assert(false,
"Layout type must derive from BaseTensorLayout and have name attribute");
return T::name;
}
// Convert element-wise operation types to string names
@@ -87,64 +99,64 @@ constexpr std::string_view elementwise_op_name()
constexpr std::string_view
conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec)
{
using ck::tensor_operation::device::ConvolutionForwardSpecialization;
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
switch(spec)
{
case ConvolutionForwardSpecialization::Default: return "Default";
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
case ConvolutionForwardSpecialization::OddC: return "OddC";
case Default: return "Default";
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case Filter1x1Pad0: return "Filter1x1Pad0";
case Filter3x3: return "Filter3x3";
case OddC: return "OddC";
}
}
// Convert GemmSpecialization enum to string
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
{
using ck::tensor_operation::device::GemmSpecialization;
using enum ck::tensor_operation::device::GemmSpecialization;
switch(spec)
{
case GemmSpecialization::Default: return "Default";
case GemmSpecialization::MPadding: return "MPadding";
case GemmSpecialization::NPadding: return "NPadding";
case GemmSpecialization::KPadding: return "KPadding";
case GemmSpecialization::MNPadding: return "MNPadding";
case GemmSpecialization::MKPadding: return "MKPadding";
case GemmSpecialization::NKPadding: return "NKPadding";
case GemmSpecialization::MNKPadding: return "MNKPadding";
case GemmSpecialization::OPadding: return "OPadding";
case GemmSpecialization::MOPadding: return "MOPadding";
case GemmSpecialization::NOPadding: return "NOPadding";
case GemmSpecialization::KOPadding: return "KOPadding";
case GemmSpecialization::MNOPadding: return "MNOPadding";
case GemmSpecialization::MKOPadding: return "MKOPadding";
case GemmSpecialization::NKOPadding: return "NKOPadding";
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
case Default: return "Default";
case MPadding: return "MPadding";
case NPadding: return "NPadding";
case KPadding: return "KPadding";
case MNPadding: return "MNPadding";
case MKPadding: return "MKPadding";
case NKPadding: return "NKPadding";
case MNKPadding: return "MNKPadding";
case OPadding: return "OPadding";
case MOPadding: return "MOPadding";
case NOPadding: return "NOPadding";
case KOPadding: return "KOPadding";
case MNOPadding: return "MNOPadding";
case MKOPadding: return "MKOPadding";
case NKOPadding: return "NKOPadding";
case MNKOPadding: return "MNKOPadding";
}
}
// Convert BlockGemmPipelineScheduler enum to string
constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched)
{
using ck::BlockGemmPipelineScheduler;
using enum ck::BlockGemmPipelineScheduler;
switch(sched)
{
case BlockGemmPipelineScheduler::Intrawave: return "Intrawave";
case BlockGemmPipelineScheduler::Interwave: return "Interwave";
case Intrawave: return "Intrawave";
case Interwave: return "Interwave";
}
}
// Convert BlockGemmPipelineVersion enum to string
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
{
using ck::BlockGemmPipelineVersion;
using enum ck::BlockGemmPipelineVersion;
switch(ver)
{
case BlockGemmPipelineVersion::v1: return "v1";
case BlockGemmPipelineVersion::v2: return "v2";
case BlockGemmPipelineVersion::v3: return "v3";
case BlockGemmPipelineVersion::v4: return "v4";
case BlockGemmPipelineVersion::v5: return "v5";
case v1: return "v1";
case v2: return "v2";
case v3: return "v3";
case v4: return "v4";
case v5: return "v5";
}
}
@@ -164,12 +176,138 @@ inline std::string array_to_string(const std::array<T, N>& arr)
return oss.str();
}
// Handle ck::Tuple (empty tuple for DsLayout/DsDataType)
template <typename T>
constexpr std::string_view tuple_name()
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
template <typename Seq>
struct SequenceToArray;
template <ck::index_t... Is>
struct SequenceToArray<ck::Sequence<Is...>>
{
// For now, just check if it's an empty tuple
return "EmptyTuple";
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
};
namespace detail {
// Generic helper to build list-like strings (Tuple, Seq, etc.)
//
// Example output: "Seq(1,2,3)"
//
// prefix: The list-like container name (e.g. "Tuple" or "Seq")
// converter_fn: A callable that converts each element to a string representation
// For types: converter_fn should be a template lambda like []<typename U>() { return
// type_name<U>(); } For values: converter_fn should be a regular lambda like [](auto value) {
// return std::to_string(value); }
template <typename ConverterFn, typename... Elements>
constexpr std::string build_list_string(std::string_view prefix, const ConverterFn& converter_fn)
{
if constexpr(sizeof...(Elements) == 0)
{
return std::string(prefix) + "()";
}
else
{
std::string result = std::string(prefix) + "(";
std::size_t index = 0;
((result +=
(index++ > 0 ? "," : "") + std::string(converter_fn.template operator()<Elements>())),
...);
result += ")";
return result;
}
}
// Overload for value-based lists (sequences)
template <typename ConverterFn, auto... Values>
constexpr std::string build_list_string_values(std::string_view prefix,
const ConverterFn& converter_fn)
{
if constexpr(sizeof...(Values) == 0)
{
return std::string(prefix) + "()";
}
else
{
std::string result = std::string(prefix) + "(";
std::size_t index = 0;
((result += (index++ > 0 ? "," : "") + converter_fn(Values)), ...);
result += ")";
return result;
}
}
} // namespace detail
// Convert ck::Sequence to string representation
// Converts a ck::Sequence type to a string in the format "Seq(v1,v2,...,vn)"
// where each value is converted using std::to_string.
//
// Template parameter:
// T: Must be a ck::Sequence<...> type
//
// Constraints:
// - Sequence elements must support std::to_string (typically ck::index_t)
//
// Examples:
// sequence_name<ck::Sequence<>>() returns "Seq()"
// sequence_name<ck::Sequence<42>>() returns "Seq(42)"
// sequence_name<ck::Sequence<1,2,3>>() returns "Seq(1,2,3)"
// sequence_name<ck::Sequence<256,128,64>>() returns "Seq(256,128,64)"
template <typename T>
requires requires { []<ck::index_t... Is>(ck::Sequence<Is...>*) {}(static_cast<T*>(nullptr)); }
constexpr std::string sequence_name()
{
return []<ck::index_t... Is>(ck::Sequence<Is...>*) constexpr {
auto to_string_fn = [](auto value) { return std::to_string(value); };
return detail::build_list_string_values<decltype(to_string_fn), Is...>("Seq", to_string_fn);
}(static_cast<T*>(nullptr));
}
// Convert ck::Tuple to string representation
// Converts a ck::Tuple type to a string in the format "Tuple(e1,e2,...,en)"
// where each element is converted based on its type (layout names or data type names).
//
// Template parameter:
// T: Must be a ck::Tuple<...> type
//
// Constraints:
// - Empty tuples are supported and return "EmptyTuple"
// - All tuple elements must be homogeneous: either all layouts (IsLayoutType) or all data types
// (IsDataType)
// - Mixed layouts and data types in the same tuple will cause a compile-time error
//
// Examples:
// tuple_name<ck::Tuple<>>() returns "EmptyTuple"
// tuple_name<ck::Tuple<ck::tensor_layout::gemm::RowMajor>>() returns "Tuple(RowMajor)"
// tuple_name<ck::Tuple<NCHW,NHWC>>() returns "Tuple(NCHW,NHWC)"
// tuple_name<ck::Tuple<ck::half_t>>() returns "Tuple(fp16)"
// tuple_name<ck::Tuple<ck::half_t,float,double>>() returns "Tuple(fp16,fp32,fp64)"
template <typename T>
requires requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
constexpr std::string tuple_name()
{
return []<typename... Ts>(ck::Tuple<Ts...>*) constexpr {
if constexpr(sizeof...(Ts) == 0)
{
return std::string("EmptyTuple");
}
else if constexpr((IsLayoutType<Ts> && ...))
{
// Lambda wrapper for layout_name
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
return detail::build_list_string<decltype(layout_name_fn), Ts...>("Tuple",
layout_name_fn);
}
else if constexpr((IsDataType<Ts> && ...))
{
// Lambda wrapper for type_name
auto type_name_fn = []<typename U>() { return type_name<U>(); };
return detail::build_list_string<decltype(type_name_fn), Ts...>("Tuple", type_name_fn);
}
else
{
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
"Tuple elements must be all layouts or all data types, not mixed");
return std::string{}; // unreachable
}
}(static_cast<T*>(nullptr));
}
} // namespace ck_tile::reflect::detail

View File

@@ -0,0 +1,90 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile::builder {
enum class DataType
{
FP32,
FP16,
BF16,
FP8,
I8,
U8
};
// Memory layouts for 1D convolution tensors.
// G: Group, N: Batch, K: Output Channel, C: Input Channel, W: Width
// Enum defines Input, Weight, and Output tensor layouts respectively.
enum class GroupConvLayout1D
{
GNWC_GKXC_GNWK,
NWGC_GKXC_NWGK,
NGCW_GKXC_NGKW,
NGCW_GKCX_NGKW
};
// Memory layouts for 2D convolution tensors.
// G: Group, N: Batch, K: Output Channel, C: Input Channel, Y: Height, X: Width, H: Height
// Enum defines Input, Weight, and Output tensor layouts respectively.
enum class GroupConvLayout2D
{
GNHWC_GKYXC_GNHWK,
NHWGC_GKYXC_NHWGK,
NGCHW_GKYXC_NGKHW,
NGCHW_GKCYX_NGKHW
};
// Memory layouts for 3D convolution tensors.
// G: Group, N: Batch, K: Output Channel, C: Input Channel, Z: Depth, Y: Height, X: Width, D: Depth,
// H: Height Enum defines Input, Weight, and Output tensor layouts respectively.
enum class GroupConvLayout3D
{
GNDHWC_GKZYXC_GNDHWK,
NDHWGC_GKZYXC_NDHWGK,
NGCDHW_GKZYXC_NGKDHW,
NGCDHW_GKCZYX_NGKDHW,
};
// Direction of the convolution operation.
enum class ConvDirection
{
FORWARD,
BACKWARD_DATA,
BACKWARD_WEIGHT
};
// Fused element-wise operations.
enum class ElementwiseOperation
{
BIAS,
BIAS_CLAMP,
BIAS_BNORM_CLAMP,
BILINEAR,
CLAMP,
SCALE,
PASS_THROUGH
};
// Enums for the current block GEMM pipeline versions.
enum class BlockGemmPipelineVersion
{
V1,
V2,
V3,
V4,
V5
};
// Enums for the forward convolution specialization.
enum class ConvFwdSpecialization
{
DEFAULT,
FILTER_1X1_PAD0,
FILTER_1X1_STRIDE1_PAD0,
FILTER_3x3
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,18 @@
#pragma once
#include <concepts>
#include <string_view>
#include "ck_tile/builder/builder_utils.hpp"
namespace ck_tile::builder {
static constexpr StringLiteral V0_0_0 = "0.0.0";
static constexpr StringLiteral V0_1_0 = "0.1.0";
static constexpr StringLiteral LATEST_API_VERSION = V0_1_0;
template <StringLiteral V>
concept SupportedVersion = (V == V0_0_0) || (V == V0_1_0);
} // namespace ck_tile::builder

View File

@@ -2,11 +2,12 @@ include(gtest)
# Helper function to create a gtest executable with common properties
function(add_ck_builder_test test_name)
add_executable(${test_name} ${ARGN})
add_executable(${test_name} ${ARGN} testing_utils.cpp)
target_compile_features(${test_name} PRIVATE cxx_std_20)
target_include_directories(${test_name} PRIVATE
"${PROJECT_SOURCE_DIR}/experimental/builder/include"
"${PROJECT_SOURCE_DIR}/include"
"${CMAKE_CURRENT_SOURCE_DIR}"
)
target_compile_options(${test_name} PRIVATE
-Wno-global-constructors
@@ -15,12 +16,32 @@ function(add_ck_builder_test test_name)
target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock)
endfunction()
# The test_conv_builder target has all the unit tests (each test should run < 10 ms)
add_ck_builder_test(test_conv_builder
test_conv_builder.cpp
test_instance_traits.cpp
testing_utils.cpp)
test_instance_traits_util.cpp)
# Testing the virtual GetInstanceString methods requires kernel compilation.
add_ck_builder_test(test_get_instance_string
test_get_instance_string.cpp)
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
add_ck_builder_test(test_inline_diff test_inline_diff.cpp)
function(add_ck_factory_test test_name)
add_ck_builder_test(${test_name} ${ARGN})
target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations)
endfunction()
add_ck_factory_test(test_testing_utils test_testing_utils.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward test_ck_factory_grouped_convolution_forward.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_clamp test_ck_factory_grouped_convolution_forward_clamp.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_convscale test_ck_factory_grouped_convolution_forward_convscale.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_bilinear test_ck_factory_grouped_convolution_forward_bilinear.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_scale test_ck_factory_grouped_convolution_forward_scale.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_scaleadd_ab test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_bias_clamp test_ck_factory_grouped_convolution_forward_bias_clamp.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_bias_bnorm_clamp test_ck_factory_grouped_convolution_forward_bias_bnorm_clamp.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)

View File

@@ -0,0 +1,47 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv2DBF16Test : public FwdConvBuilderTestBase
{
};
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
TEST_F(FwdConv2DBF16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
ConvFwdSpecialization::DEFAULT>();
}
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
TEST_F(FwdConv2DBF16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V5,
ConvFwdSpecialization::FILTER_3x3>();
}

View File

@@ -0,0 +1,26 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv2DFP16Test : public FwdConvBuilderTestBase
{
};
TEST_F(FwdConv2DFP16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -0,0 +1,26 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv2DFP32Test : public FwdConvBuilderTestBase
{
};
TEST_F(FwdConv2DFP32Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
.data_type = DataType::FP32,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}

View File

@@ -0,0 +1,27 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv3DBF16Test : public FwdConvBuilderTestBase
{
};
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
TEST_F(FwdConv3DBF16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::DEFAULT>();
}

View File

@@ -0,0 +1,27 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv3DFP16Test : public FwdConvBuilderTestBase
{
};
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
TEST_F(FwdConv3DFP16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -0,0 +1,27 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv3DFP32Test : public FwdConvBuilderTestBase
{
};
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
TEST_F(FwdConv3DFP32Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
.data_type = DataType::FP32,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -0,0 +1,119 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::test {
namespace ckb = ck_tile::builder;
// Convenience struct for a tuple of m, n, and k values.
template <typename T>
struct MNK
{
T m{};
T n{};
T k{};
};
// Specify thread block dimensions for a GEMM.
struct ThreadBlock
{
// Thread block size.
size_t block_size;
// Size of the submatrix problem in a thread block.
MNK<size_t> tile_size;
};
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
// Describe gridwise GEMM parameters.
struct GridwiseGemm
{
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
size_t ak1 = 0;
size_t bk1 = 0;
size_t m_per_xdl = 0;
size_t n_per_xdl = 0;
size_t m_xdl_per_wave = 0;
size_t n_xdl_per_wave = 0;
};
static_assert(ckb::GridwiseGemmDescriptor<GridwiseGemm>);
// Describe Aand B block transfer thread cluster lengths.
struct BlockTransfer
{
size_t k0;
size_t m_n;
size_t k1;
};
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
// Describe C block transfer thread cluster lengths.
struct ThreadCluster
{
size_t m_block;
size_t m_wave_per_xdl;
size_t n_block;
size_t n_wave_per_xdl;
};
static_assert(ThreadClusterDescriptor<ThreadCluster>);
struct LdsTransfer
{
size_t src_vector_dim;
size_t src_scalar_per_vector;
size_t lds_dst_scalar_per_vector;
bool is_direct_load;
bool lds_padding;
};
static_assert(LdsTransferDescriptor<LdsTransfer>);
struct Epilogue
{
size_t m_xdl_per_wave_per_shuffle;
size_t n_xdl_per_wave_per_shuffle;
size_t scalar_per_vector;
};
static_assert(EpilogueDescriptor<Epilogue>);
struct AccessOrder
{
std::array<size_t, 3> order;
};
static_assert(AccessOrderDescriptor<AccessOrder>);
struct BlockTransferABC
{
BlockTransfer block_transfer_a;
BlockTransfer block_transfer_b;
ThreadCluster thread_cluster_dims_c;
LdsTransfer lds_transfer_a;
LdsTransfer lds_transfer_b;
Epilogue epilogue_c;
AccessOrder block_transfer_access_order_a;
AccessOrder block_transfer_access_order_b;
AccessOrder src_access_order_a;
AccessOrder src_access_order_b;
};
struct ConvAlgorithm
{
ThreadBlock thread_block;
GridwiseGemm gridwise_gemm;
BlockTransferABC block_transfer;
BlockGemmPipelineVersion pipeline_version;
ConvFwdSpecialization fwd_specialization;
};
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
static_assert(ckb::SpecifiesGridwiseGemm<ConvAlgorithm>);
static_assert(ckb::SpecifiesBlockTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<ConvAlgorithm>);
static_assert(ckb::SpecifiesSourceAccessOrder<ConvAlgorithm>);
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
static_assert(ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm>);
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::test {
template <typename GroupConvLayout>
struct ConvSignature
{
int spatial_dim;
ConvDirection direction;
GroupConvLayout layout;
DataType data_type;
ElementwiseOperation elementwise_operation;
};
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout1D>>);
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout2D>>);
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout3D>>);
} // namespace ck_tile::builder::test

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,118 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp>
#include "ck/utility/data_type.hpp"
#include "testing_utils.hpp"
using ck_tile::test::InstanceSet;
using ck_tile::test::InstancesMatch;
namespace {
constexpr static auto NumDimSpatial = 3;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using DsLayout = ck::Tuple<ck::tensor_layout::convolution::NDHWGK>;
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
using ck::tensor_operation::element_wise::Bilinear;
using ck::tensor_operation::element_wise::PassThrough;
template <typename type, typename computeType = type>
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
type, // InDataType
type, // WeiDataType
ck::Tuple<type>,
type, // OutDataType
PassThrough,
PassThrough,
Bilinear,
computeType,
computeType>;
} // namespace
template <typename Case>
struct CkFactoryTestBilinearFwd : public testing::Test
{
static auto get_actual_instances()
{
return InstanceSet::from_factory<typename Case::DeviceOp>();
}
static auto get_expected_instances() { return InstanceSet(Case::expected); }
};
struct Bilinear_F32
{
using DeviceOp = ::DeviceOp<float>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct Bilinear_F32_TF32
{
using DeviceOp = ::DeviceOp<float, ck::tf32_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct Bilinear_F16
{
using DeviceOp = ::DeviceOp<ck::half_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct Bilinear_BF16
{
using DeviceOp = ::DeviceOp<ck::bhalf_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct Bilinear_INT8
{
using DeviceOp = ::DeviceOp<int8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
using TestTypes =
::testing::Types<Bilinear_F32, Bilinear_F32_TF32, Bilinear_F16, Bilinear_BF16, Bilinear_INT8>;
TYPED_TEST_SUITE(CkFactoryTestBilinearFwd, TestTypes);
TYPED_TEST(CkFactoryTestBilinearFwd, TestInstances)
{
auto actual = TestFixture::get_actual_instances();
auto expected = TestFixture::get_expected_instances();
EXPECT_THAT(actual, InstancesMatch(expected));
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,246 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp>
#include <ck/library/tensor_operation_instance/device_operation_instance_factory.hpp>
#include "testing_utils.hpp"
using ck_tile::test::InstanceSet;
using ck_tile::test::InstancesMatch;
namespace {
constexpr static auto NumDimSpatial = 3;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
using ck::tensor_operation::device::instance::CombConvScale;
using ck::tensor_operation::device::instance::CombConvScaleRelu;
using ck::tensor_operation::element_wise::ConvInvscale;
using ck::tensor_operation::element_wise::ConvScale;
using ck::tensor_operation::element_wise::ConvScaleAdd;
using ck::tensor_operation::element_wise::ConvScaleRelu;
using ck::tensor_operation::element_wise::PassThrough;
template <typename DsLayout,
typename DsDataType,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename Act,
typename AComputeType,
typename BComputeType>
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType, // InDataType
WeiDataType, // WeiDataType
DsDataType,
OutDataType, // OutDataType
PassThrough,
PassThrough,
Act,
AComputeType,
BComputeType>;
} // namespace
template <typename Case>
struct CkFactoryTestConvFwd : public testing::Test
{
static auto get_actual_instances()
{
return InstanceSet::from_factory<typename Case::DeviceOp>();
}
static auto get_expected_instances() { return InstanceSet(Case::expected); }
};
struct F8_ConvScale
{
using DeviceOp = ::DeviceOp<ck::Tuple<>,
ck::Tuple<>,
ck::f8_t,
ck::f8_t,
ck::f8_t,
ConvScale,
ck::f8_t,
ck::f8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F8_BF8_comb1_ConvScale
{
using DeviceOp = ::DeviceOp<ck::Tuple<>,
ck::Tuple<>,
ck::bf8_t,
ck::bf8_t,
ck::f8_t,
ConvScale,
ck::bf8_t,
ck::bf8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F8_BF8_comb2_ConvScale
{
using DeviceOp = ::DeviceOp<ck::Tuple<>,
ck::Tuple<>,
ck::f8_t,
ck::bf8_t,
ck::f8_t,
ConvScale,
ck::f8_t,
ck::bf8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F8_BF8_comb3_ConvScale
{
using DeviceOp = ::DeviceOp<ck::Tuple<>,
ck::Tuple<>,
ck::bf8_t,
ck::f8_t,
ck::f8_t,
ConvScale,
ck::bf8_t,
ck::f8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F8_float_CombConvScale
{
using DeviceOp = ::DeviceOp<ck::Tuple<>,
ck::Tuple<>,
ck::f8_t,
ck::f8_t,
float,
CombConvScale,
ck::f8_t,
ck::f8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F8_ConvScaleRelu
{
using DeviceOp = ::DeviceOp<ck::Tuple<>,
ck::Tuple<>,
ck::f8_t,
ck::f8_t,
ck::f8_t,
ConvScaleRelu,
ck::f8_t,
ck::f8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F8_CombConvScaleRelu
{
using DeviceOp = ::DeviceOp<ck::Tuple<>,
ck::Tuple<>,
ck::f8_t,
ck::f8_t,
float,
CombConvScaleRelu,
ck::f8_t,
ck::f8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F8_ConvScaleAdd
{
using DeviceOp = ::DeviceOp<ck::Tuple<OutLayout>,
ck::Tuple<float>,
ck::f8_t,
ck::f8_t,
ck::f8_t,
ConvScaleAdd,
ck::f8_t,
ck::f8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F8_ConvInvscale
{
using DeviceOp = ::DeviceOp<ck::Tuple<>,
ck::Tuple<>,
ck::f8_t,
ck::f8_t,
ck::f8_t,
ConvInvscale,
ck::f8_t,
ck::f8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
using TestTypes = ::testing::Types<F8_ConvScale,
F8_BF8_comb1_ConvScale,
F8_BF8_comb2_ConvScale,
F8_BF8_comb3_ConvScale,
F8_float_CombConvScale,
F8_ConvScaleRelu,
F8_CombConvScaleRelu,
F8_ConvScaleAdd,
F8_ConvInvscale>;
TYPED_TEST_SUITE(CkFactoryTestConvFwd, TestTypes);
TYPED_TEST(CkFactoryTestConvFwd, TestInstances)
{
auto actual = TestFixture::get_actual_instances();
auto expected = TestFixture::get_expected_instances();
EXPECT_THAT(actual, InstancesMatch(expected));
}

View File

@@ -0,0 +1,187 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp>
#include "ck/utility/data_type.hpp"
#include "testing_utils.hpp"
using ck_tile::test::InstanceSet;
using ck_tile::test::InstancesMatch;
namespace {
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
using ck::tensor_operation::element_wise::DynamicUnaryOp;
using ck::tensor_operation::element_wise::PassThrough;
template <ck::index_t NumDimSpatial, typename T>
struct DeviceOpHelper;
template <typename T>
struct DeviceOpHelper<2, T>
{
using InLayout = ck::tensor_layout::convolution::NHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using OutLayout = ck::tensor_layout::convolution::NHWGK;
using Type = DeviceGroupedConvFwdMultipleABD<2,
InLayout,
WeiLayout,
ck::Tuple<>, // DsLayout
OutLayout,
T, // InDataType
T, // WeiDataType
ck::Tuple<>, // DsDataType
T, // OutDataType
PassThrough,
PassThrough,
DynamicUnaryOp>;
};
template <typename T>
struct DeviceOpHelper<3, T>
{
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using Type = DeviceGroupedConvFwdMultipleABD<3,
InLayout,
WeiLayout,
ck::Tuple<>, // DsLayout
OutLayout,
T, // InDataType
T, // WeiDataType
ck::Tuple<>, // DsDataType
T, // OutDataType
PassThrough,
PassThrough,
DynamicUnaryOp>;
};
template <ck::index_t NumDimSpatial, typename T>
using DeviceOp = DeviceOpHelper<NumDimSpatial, T>::Type;
} // namespace
template <typename Case>
struct CkFactoryTestBilinearFwd : public testing::Test
{
static auto get_actual_instances()
{
return InstanceSet::from_factory<typename Case::DeviceOp>();
}
static auto get_expected_instances() { return InstanceSet(Case::expected); }
};
struct DyOp_F32_2
{
using DeviceOp = ::DeviceOp<2, float>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct DyOp_F32_3
{
using DeviceOp = ::DeviceOp<3, float>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct DyOp_F16_2
{
using DeviceOp = ::DeviceOp<2, ck::half_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct DyOp_F16_3
{
using DeviceOp = ::DeviceOp<3, ck::half_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct DyOp_BF16_2
{
using DeviceOp = ::DeviceOp<2, ck::bhalf_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct DyOp_BF16_3
{
using DeviceOp = ::DeviceOp<3, ck::bhalf_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct DyOp_INT8_2
{
using DeviceOp = ::DeviceOp<2, int8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct DyOp_INT8_3
{
using DeviceOp = ::DeviceOp<3, int8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
using TestTypes = ::testing::Types<DyOp_F32_2,
DyOp_F32_3,
DyOp_F16_2,
DyOp_F16_3,
DyOp_BF16_2,
DyOp_BF16_3,
DyOp_INT8_2,
DyOp_INT8_3>;
TYPED_TEST_SUITE(CkFactoryTestBilinearFwd, TestTypes);
TYPED_TEST(CkFactoryTestBilinearFwd, TestInstances)
{
auto actual = TestFixture::get_actual_instances();
auto expected = TestFixture::get_expected_instances();
EXPECT_THAT(actual, InstancesMatch(expected));
}

View File

@@ -0,0 +1,115 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp>
#include "testing_utils.hpp"
using ck_tile::test::InstanceSet;
using ck_tile::test::InstancesMatch;
namespace {
constexpr static auto NumDimSpatial = 3;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
using ck::tensor_operation::element_wise::PassThrough;
using ck::tensor_operation::element_wise::Scale;
template <typename T, typename ComputeType = T>
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>, // DsLayout
OutLayout,
T, // InDataType
T, // WeiDataType
ck::Tuple<>, // DsDataType
T, // OutDataType
PassThrough,
PassThrough,
Scale,
ComputeType>;
} // namespace
template <typename Case>
struct CkFactoryTestConvFwd : public testing::Test
{
static auto get_actual_instances()
{
return InstanceSet::from_factory<typename Case::DeviceOp>();
}
static auto get_expected_instances() { return InstanceSet(Case::expected); }
};
struct F32
{
using DeviceOp = ::DeviceOp<float>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F32_TF32
{
using DeviceOp = ::DeviceOp<float, ck::tf32_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F16
{
using DeviceOp = ::DeviceOp<ck::half_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct BF16
{
using DeviceOp = ::DeviceOp<ck::bhalf_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct S8
{
using DeviceOp = ::DeviceOp<int8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
using TestTypes = ::testing::Types<F32, F32_TF32, F16, BF16, S8>;
TYPED_TEST_SUITE(CkFactoryTestConvFwd, TestTypes);
TYPED_TEST(CkFactoryTestConvFwd, TestInstances)
{
auto actual = TestFixture::get_actual_instances();
auto expected = TestFixture::get_expected_instances();
EXPECT_THAT(actual, InstancesMatch(expected));
}

View File

@@ -0,0 +1,104 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp>
#include "testing_utils.hpp"
using ck_tile::test::InstanceSet;
using ck_tile::test::InstancesMatch;
namespace {
constexpr static auto NumDimSpatial = 3;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
using ck::tensor_operation::element_wise::PassThrough;
using ck::tensor_operation::element_wise::ScaleAdd;
template <typename T>
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>, // DsLayout
OutLayout,
ck::Tuple<T, T>, // InDataType
ck::Tuple<T, T>, // WeiDataType
ck::Tuple<>, // DsDataType
T, // OutDataType
ScaleAdd,
ScaleAdd,
PassThrough,
T>; // ComputeType
} // namespace
template <typename Case>
struct CkFactoryTestConvFwd : public testing::Test
{
static auto get_actual_instances()
{
return InstanceSet::from_factory<typename Case::DeviceOp>();
}
static auto get_expected_instances() { return InstanceSet(Case::expected); }
};
struct F32
{
using DeviceOp = ::DeviceOp<float>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F16
{
using DeviceOp = ::DeviceOp<ck::half_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct BF16
{
using DeviceOp = ::DeviceOp<ck::bhalf_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct S8
{
using DeviceOp = ::DeviceOp<int8_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
using TestTypes = ::testing::Types<F32, F16, BF16, S8>;
TYPED_TEST_SUITE(CkFactoryTestConvFwd, TestTypes);
TYPED_TEST(CkFactoryTestConvFwd, TestInstances)
{
auto actual = TestFixture::get_actual_instances();
auto expected = TestFixture::get_expected_instances();
EXPECT_THAT(actual, InstancesMatch(expected));
}

View File

@@ -0,0 +1,105 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp>
#include "testing_utils.hpp"
using ck_tile::test::InstanceSet;
using ck_tile::test::InstancesMatch;
namespace {
constexpr static auto NumDimSpatial = 3;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using ck::tensor_layout::convolution::G_K;
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
using ck::tensor_operation::element_wise::PassThrough;
using ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
template <typename T, typename U = T>
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<OutLayout, G_K>, // DsLayout
OutLayout,
T, // InDataType
T, // WeiDataType
ck::Tuple<U, U>, // DsDataType
T, // OutDataType
PassThrough,
PassThrough,
ScaleAddScaleAddRelu,
T>; // ComputeType
} // namespace
template <typename Case>
struct CkFactoryTestConvFwd : public testing::Test
{
static auto get_actual_instances()
{
return InstanceSet::from_factory<typename Case::DeviceOp>();
}
static auto get_expected_instances() { return InstanceSet(Case::expected); }
};
struct F32
{
using DeviceOp = ::DeviceOp<float>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct F16
{
using DeviceOp = ::DeviceOp<ck::half_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct BF16
{
using DeviceOp = ::DeviceOp<ck::bhalf_t>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
struct S8
{
using DeviceOp = ::DeviceOp<int8_t, float>;
constexpr static auto expected = {
// clang-format off
""
// clang-format on
};
};
using TestTypes = ::testing::Types<F32, F16, BF16, S8>;
TYPED_TEST_SUITE(CkFactoryTestConvFwd, TestTypes);
TYPED_TEST(CkFactoryTestConvFwd, TestInstances)
{
auto actual = TestFixture::get_actual_instances();
auto expected = TestFixture::get_expected_instances();
EXPECT_THAT(actual, InstancesMatch(expected));
}

View File

@@ -0,0 +1,263 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
namespace ck_tile::reflect::detail {
namespace {
using ::testing::ElementsAre;
using ::testing::IsEmpty;
TEST(InstanceTraitsUtil, SequenceToArrayReturnsEmptyArrayForEmptySequence)
{
EXPECT_THAT(SequenceToArray<ck::Sequence<>>::value, IsEmpty());
}
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithSingleElement)
{
EXPECT_THAT(SequenceToArray<ck::Sequence<42>>::value, ElementsAre(42));
}
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithMultipleElements)
{
EXPECT_THAT((SequenceToArray<ck::Sequence<1, 2, 3, 4, 5>>::value), ElementsAre(1, 2, 3, 4, 5));
}
TEST(InstanceTraitsUtil, TypeNameReturnsCorrectStrings)
{
EXPECT_THAT((std::vector<std::string_view>{type_name<ck::half_t>(),
type_name<float>(),
type_name<double>(),
type_name<int8_t>(),
type_name<int32_t>(),
type_name<ck::bhalf_t>(),
type_name<ck::f8_t>(),
type_name<ck::bf8_t>()}),
ElementsAre("fp16", "fp32", "fp64", "s8", "s32", "bf16", "fp8", "bf8"));
}
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForGemmLayouts)
{
namespace gemm = ck::tensor_layout::gemm;
EXPECT_THAT((std::vector<std::string_view>{layout_name<gemm::RowMajor>(),
layout_name<gemm::ColumnMajor>(),
layout_name<gemm::MFMA>()}),
ElementsAre("RowMajor", "ColumnMajor", "MFMA"));
}
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForConvLayouts)
{
namespace conv = ck::tensor_layout::convolution;
EXPECT_THAT((std::vector<std::string_view>{
// Input tensor layouts
// TODO(deprecated): Remove non-grouped layouts once instances are removed.
layout_name<conv::NCHW>(),
layout_name<conv::NHWC>(),
layout_name<conv::NCDHW>(),
layout_name<conv::NDHWC>(),
// Grouped input layouts
layout_name<conv::GNCHW>(),
layout_name<conv::GNHWC>(),
// Weight tensor layouts
layout_name<conv::KCYX>(),
layout_name<conv::KYXC>(),
layout_name<conv::GKCYX>(),
layout_name<conv::GKYXC>(),
// Output tensor layouts
layout_name<conv::NKHW>(),
layout_name<conv::NHWK>(),
layout_name<conv::GNKHW>(),
layout_name<conv::GNHWK>(),
// Strided layouts
// TODO(deprecated): Remove strided layouts once instances are removed.
layout_name<conv::G_NHW_C>(),
layout_name<conv::G_K_YX_C>(),
layout_name<conv::G_NHW_K>(),
// Bias layouts
layout_name<conv::G_C>(),
layout_name<conv::G_K>()}),
ElementsAre("NCHW",
"NHWC",
"NCDHW",
"NDHWC",
"GNCHW",
"GNHWC",
"KCYX",
"KYXC",
"GKCYX",
"GKYXC",
"NKHW",
"NHWK",
"GNKHW",
"GNHWK",
"G_NHW_C",
"G_K_YX_C",
"G_NHW_K",
"G_C",
"G_K"));
}
TEST(InstanceTraitsUtil, ElementwiseOpNameReturnsCorrectStrings)
{
namespace element_wise = ck::tensor_operation::element_wise;
EXPECT_THAT((std::vector<std::string_view>{
elementwise_op_name<element_wise::PassThrough>(),
elementwise_op_name<element_wise::Scale>(),
elementwise_op_name<element_wise::Bilinear>(),
elementwise_op_name<element_wise::Add>(),
elementwise_op_name<element_wise::AddRelu>(),
elementwise_op_name<element_wise::Relu>(),
elementwise_op_name<element_wise::BiasNormalizeInInferClamp>(),
elementwise_op_name<element_wise::Clamp>(),
elementwise_op_name<element_wise::AddClamp>()}),
ElementsAre("PassThrough",
"Scale",
"Bilinear",
"Add",
"AddRelu",
"Relu",
"BiasNormalizeInInferClamp",
"Clamp",
"AddClamp"));
}
TEST(InstanceTraitsUtil, ConvFwdSpecNameReturnsCorrectStrings)
{
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
EXPECT_THAT(
(std::vector<std::string_view>{conv_fwd_spec_name(Default),
conv_fwd_spec_name(Filter1x1Stride1Pad0),
conv_fwd_spec_name(Filter1x1Pad0),
conv_fwd_spec_name(Filter3x3),
conv_fwd_spec_name(OddC)}),
ElementsAre("Default", "Filter1x1Stride1Pad0", "Filter1x1Pad0", "Filter3x3", "OddC"));
}
TEST(InstanceTraitsUtil, GemmSpecNameReturnsCorrectStrings)
{
using enum ck::tensor_operation::device::GemmSpecialization;
EXPECT_THAT((std::vector<std::string_view>{gemm_spec_name(Default),
gemm_spec_name(MPadding),
gemm_spec_name(NPadding),
gemm_spec_name(KPadding),
gemm_spec_name(MNPadding),
gemm_spec_name(MKPadding),
gemm_spec_name(NKPadding),
gemm_spec_name(MNKPadding),
gemm_spec_name(OPadding),
gemm_spec_name(MOPadding),
gemm_spec_name(NOPadding),
gemm_spec_name(KOPadding),
gemm_spec_name(MNOPadding),
gemm_spec_name(MKOPadding),
gemm_spec_name(NKOPadding),
gemm_spec_name(MNKOPadding)}),
ElementsAre("Default",
"MPadding",
"NPadding",
"KPadding",
"MNPadding",
"MKPadding",
"NKPadding",
"MNKPadding",
"OPadding",
"MOPadding",
"NOPadding",
"KOPadding",
"MNOPadding",
"MKOPadding",
"NKOPadding",
"MNKOPadding"));
}
TEST(InstanceTraitsUtil, PipelineSchedulerNameReturnsCorrectStrings)
{
using enum ck::BlockGemmPipelineScheduler;
EXPECT_THAT((std::vector<std::string_view>{pipeline_scheduler_name(Intrawave),
pipeline_scheduler_name(Interwave)}),
ElementsAre("Intrawave", "Interwave"));
}
TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
{
using enum ck::BlockGemmPipelineVersion;
EXPECT_THAT((std::vector<std::string_view>{pipeline_version_name(v1),
pipeline_version_name(v2),
pipeline_version_name(v3),
pipeline_version_name(v4),
pipeline_version_name(v5)}),
ElementsAre("v1", "v2", "v3", "v4", "v5"));
}
TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
{
EXPECT_EQ(tuple_name<ck::Tuple<>>(), "EmptyTuple");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleLayout)
{
EXPECT_EQ(tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW>>(), "Tuple(NCHW)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoLayouts)
{
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
ck::tensor_layout::convolution::NHWC>>()),
"Tuple(NCHW,NHWC)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeLayouts)
{
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NKHW>>()),
"Tuple(NCHW,NHWC,NKHW)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleDataType)
{
EXPECT_EQ(tuple_name<ck::Tuple<ck::half_t>>(), "Tuple(fp16)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoDataTypes)
{
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float>>()), "Tuple(fp16,fp32)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeDataTypes)
{
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float, double>>()), "Tuple(fp16,fp32,fp64)");
}
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForEmptySequence)
{
EXPECT_EQ(sequence_name<ck::Sequence<>>(), "Seq()");
}
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForSingleValueSequence)
{
EXPECT_EQ(sequence_name<ck::Sequence<42>>(), "Seq(42)");
}
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForTwoValueSequence)
{
EXPECT_EQ((sequence_name<ck::Sequence<1, 2>>()), "Seq(1,2)");
}
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence)
{
EXPECT_EQ((sequence_name<ck::Sequence<256, 128, 64, 32, 16>>()), "Seq(256,128,64,32,16)");
}
} // namespace
} // namespace ck_tile::reflect::detail

View File

@@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp>
#include "testing_utils.hpp"
using ck_tile::test::InstanceMatcher;
using ck_tile::test::InstanceSet;
using ck_tile::test::StringEqWithDiff;
TEST(InstanceSet, FromFactory)
{
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
2, // NDimSpatial
ck::tensor_operation::device::instance::NHWGC, // InLayout
ck::tensor_operation::device::instance::GKYXC, // WeiLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::NHWGK, // OutLayout
ck::half_t, // ADataType
ck::half_t, // BDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::half_t, // AComputeType
ck::half_t>; // BComputeType
const auto instances = InstanceSet::from_factory<DeviceOp>();
EXPECT_THAT(instances.instances.size(), testing::Gt(0));
const auto* el =
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,"
"fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,OddC,MNKPadding,64,16,16,64,"
"8,8,16,16,1,1,Seq(8,8,1),Seq(1,0,2),Seq(1,0,2),2,8,8,0,Seq(8,8,1),Seq(1,0,2),Seq(1,0,2),2,"
"8,8,0,1,1,Seq(1,16,1,4),4,Intrawave,v2,fp16,fp16>";
EXPECT_THAT(instances.instances, testing::Contains(el));
}
TEST(InstanceMatcher, Basic)
{
auto a = InstanceSet{
"python",
"cobra",
"boa",
};
auto b = InstanceSet{
"cobra",
"boa",
"python",
};
auto c = InstanceSet{
"adder",
"boa",
"cobra",
};
auto d = InstanceSet{
"boa",
"python",
};
EXPECT_THAT(a, InstancesMatch(b));
EXPECT_THAT(c, Not(InstancesMatch(b)));
EXPECT_THAT(a, Not(InstancesMatch(d)));
EXPECT_THAT(d, Not(InstancesMatch(b)));
}
TEST(InstanceMatcher, ExplainMatchResult)
{
auto actual = InstanceSet{
"python",
"cobra",
"boa",
};
auto expected = InstanceSet{
"adder",
"boa",
"cobra",
"rattlesnake",
};
testing::StringMatchResultListener listener;
EXPECT_TRUE(!ExplainMatchResult(InstancesMatch(expected), actual, &listener));
EXPECT_THAT(listener.str(),
StringEqWithDiff("\n"
" Missing: 2\n"
"- adder\n"
"- rattlesnake\n"
"Unexpected: 1\n"
"- python\n"));
}

View File

@@ -1,21 +1,18 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <string>
#include <sstream>
#include <vector>
#include <algorithm>
#include <unistd.h>
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include "testing_utils.hpp"
#include <unistd.h>
#include <string>
#include <sstream>
#include <ostream>
#include <vector>
#include <algorithm>
namespace ck_tile::test {
namespace {
} // namespace
// Wagner-Fischer Algorithm for Computing Edit Distance and Inline Diff
//
// OUTPUT FORMAT: [expected|actual] for differences, plain text for matches
@@ -216,4 +213,88 @@ void StringEqWithDiffMatcher::DescribeNegationTo(std::ostream* os) const
return ::testing::MakeMatcher(new StringEqWithDiffMatcher(expected));
}
std::ostream& operator<<(std::ostream& os, const InstanceSet& set)
{
// These sets can grow very large, and so its not very nice or useful to print them
// in the event of a mismatch. Just print a brief description here, and use
// InstancesMatcher to print a more useful message.
return (os << "(set of " << set.instances.size() << " instances)");
}
InstanceMatcher::InstanceMatcher(const InstanceSet& expected) : expected_(expected) {}
::testing::Matcher<InstanceSet> InstancesMatch(const InstanceSet& expected)
{
return ::testing::MakeMatcher(new InstanceMatcher(expected));
}
bool InstanceMatcher::MatchAndExplain(InstanceSet actual,
::testing::MatchResultListener* listener) const
{
if(actual.instances == expected_.instances)
{
return true;
}
if(listener->IsInterested())
{
std::vector<std::string> instances;
std::set_difference(expected_.instances.begin(),
expected_.instances.end(),
actual.instances.begin(),
actual.instances.end(),
std::back_inserter(instances));
*listener << "\n";
if(instances.size() > 0)
{
*listener << " Missing: " << instances.size() << "\n";
for(const auto& instance : instances)
{
if(instance == "")
{
*listener << "- (empty string)\n";
}
else
{
*listener << "- " << instance << "\n";
}
}
}
instances.clear();
std::set_difference(actual.instances.begin(),
actual.instances.end(),
expected_.instances.begin(),
expected_.instances.end(),
std::back_inserter(instances));
if(instances.size() > 0)
{
*listener << "Unexpected: " << instances.size() << "\n";
for(const auto& instance : instances)
{
if(instance == "")
{
*listener << "- (empty string)\n";
}
else
{
*listener << "- " << instance << "\n";
}
}
}
}
return false;
}
void InstanceMatcher::DescribeTo(std::ostream* os) const { *os << expected_; }
void InstanceMatcher::DescribeNegationTo(std::ostream* os) const
{
*os << "is not equal to " << expected_;
}
} // namespace ck_tile::test

View File

@@ -1,10 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/device_operation_instance_factory.hpp>
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <string>
#include <sstream>
#include <iosfwd>
#include <set>
#include <vector>
#include <array>
namespace ck_tile::test {
@@ -40,4 +45,68 @@ class StringEqWithDiffMatcher : public ::testing::MatcherInterface<std::string>
// Factory function for the StringEqWithDiff matcher
::testing::Matcher<std::string> StringEqWithDiff(const std::string& expected);
using ck::tensor_operation::device::instance::DeviceOperationInstanceFactory;
// This utility concept checks whether a type is a valid "Device Operation" -
// that is, there is a valid specialization of `DeviceOperationInstanceFactory`
// for it available.
template <typename DeviceOp>
concept HasCkFactory = requires {
{
DeviceOperationInstanceFactory<DeviceOp>::GetInstances()
} -> std::convertible_to<std::vector<std::unique_ptr<DeviceOp>>>;
};
// This structure represents a (unique) set of instances, either a statically
// defined one (for testing) or one obtained from DeviceOperationInstanceFactory.
// The idea is that we use this structure as a utility to compare a set of
// instances. Instances are stored in a set so that they can be lexicographically
// compared, this helps generating readable error messages which just contain
// the differenses between sets.
struct InstanceSet
{
explicit InstanceSet() {}
explicit InstanceSet(std::initializer_list<const char*> items)
: instances(items.begin(), items.end())
{
}
template <HasCkFactory DeviceOp>
static InstanceSet from_factory()
{
auto set = InstanceSet();
const auto ops = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
for(const auto& op : ops)
{
set.instances.insert(op->GetInstanceString());
}
return set;
}
std::set<std::string> instances;
};
std::ostream& operator<<(std::ostream& os, const InstanceSet& set);
// This is a custom Google Test matcher which can be used to compare two sets
// of instance names, with utility functions that print a helpful error
// message about the difference between the checked sets. Use `InstancesMatch`
// to obtain an instance of this type.
struct InstanceMatcher : public ::testing::MatcherInterface<InstanceSet>
{
explicit InstanceMatcher(const InstanceSet& expected);
bool MatchAndExplain(InstanceSet actual,
::testing::MatchResultListener* listener) const override;
void DescribeTo(std::ostream* os) const override;
void DescribeNegationTo(std::ostream* os) const override;
InstanceSet expected_;
};
::testing::Matcher<InstanceSet> InstancesMatch(const InstanceSet& expected);
} // namespace ck_tile::test

View File

@@ -0,0 +1,103 @@
#pragma once
#include <gtest/gtest.h>
#include "impl/conv_algorithm_types.hpp"
#include "impl/conv_signature_types.hpp"
#include "ck_tile/builder/conv_builder.hpp"
namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
// Common test base class
class FwdConvBuilderTestBase : public ::testing::Test
{
};
// Common test implementation
template <auto FwdConvSignature,
ThreadBlock FwdThreadBlock,
BlockGemmPipelineVersion FwdPipelineVersion,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test()
{
constexpr GridwiseGemm FwdGemmParams{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 32,
.n_per_xdl = 32,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4};
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.epilogue_c = {.m_xdl_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr ConvAlgorithm FwdConvAlgorithm{.thread_block = FwdThreadBlock,
.gridwise_gemm = FwdGemmParams,
.block_transfer = FwdBlockTransfer,
.pipeline_version = FwdPipelineVersion,
.fwd_specialization = FwdConvSpecialization};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
// Verify pipeline version is correct
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
// Verify specialization is correct
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
}
// Common thread block configurations
constexpr ThreadBlock DefaultThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr ThreadBlock SmallThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
} // namespace ck_tile::builder::test_utils

View File

@@ -12,6 +12,8 @@ namespace element_wise {
struct Add
{
static constexpr const char* name = "Add";
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
@@ -94,6 +96,8 @@ struct Add
struct Max
{
static constexpr const char* name = "Max";
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
@@ -105,6 +109,8 @@ struct Max
struct Min
{
static constexpr const char* name = "Min";
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
@@ -116,6 +122,8 @@ struct Min
struct Multiply
{
static constexpr const char* name = "Multiply";
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
@@ -208,6 +216,8 @@ struct Multiply
struct ScaleAdd
{
static constexpr const char* name = "ScaleAdd";
__host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X0, typename X1>
@@ -235,6 +245,8 @@ struct ScaleAdd
struct Subtract
{
static constexpr const char* name = "Subtract";
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
@@ -279,6 +291,8 @@ struct Subtract
struct Bilinear
{
static constexpr const char* name = "Bilinear";
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename Y, typename X0, typename X1>
@@ -353,6 +367,8 @@ struct Bilinear
struct AddClamp
{
static constexpr const char* name = "AddClamp";
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
: floor_(floor), ceil_(ceil){};
@@ -442,6 +458,8 @@ struct AddClamp
struct AddRelu
{
static constexpr const char* name = "AddRelu";
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
@@ -523,6 +541,8 @@ struct AddRelu
struct AddHardswish
{
static constexpr const char* name = "AddHardswish";
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
@@ -560,6 +580,8 @@ struct AddHardswish
// E = FastGelu(C + D)
struct AddFastGelu
{
static constexpr const char* name = "AddFastGelu";
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
@@ -625,6 +647,8 @@ struct AddFastGelu
// E = MultiplyFastGelu(C + D)
struct MultiplyFastGelu
{
static constexpr const char* name = "MultiplyFastGelu";
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
@@ -690,6 +714,8 @@ struct MultiplyFastGelu
// E = Silu(C + D)
struct AddSilu
{
static constexpr const char* name = "AddSilu";
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
@@ -740,6 +766,8 @@ struct AddSilu
struct ConvScaleAdd
{
static constexpr const char* name = "ConvScaleAdd";
__host__ __device__ ConvScaleAdd(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)

View File

@@ -13,6 +13,8 @@ namespace element_wise {
template <typename... UnaryOpsSet>
struct UnaryCombinedOp
{
static constexpr const char* name = "UnaryCombinedOp";
__host__ __device__ UnaryCombinedOp() : unary_ops_() {}
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
@@ -33,6 +35,8 @@ struct UnaryCombinedOp
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
static constexpr const char* name = "BinaryWithUnaryCombinedOp";
__host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {}
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
@@ -66,6 +70,8 @@ template <typename BinaryOp0,
typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
static constexpr const char* name = "TrinaryWithUnaryCombinedOp";
__host__ __device__ TrinaryWithUnaryCombinedOp()
: binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_()
{

View File

@@ -33,6 +33,8 @@ namespace element_wise {
struct AddReluAdd
{
static constexpr const char* name = "AddReluAdd";
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
@@ -102,6 +104,8 @@ struct AddReluAdd
struct AddHardswishAdd
{
static constexpr const char* name = "AddHardswishAdd";
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
@@ -134,6 +138,8 @@ struct AddHardswishAdd
// E = C + D0 + D1
struct AddAdd
{
static constexpr const char* name = "AddAdd";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
{
@@ -163,6 +169,8 @@ struct AddAdd
// E = (C + D0) x D1
struct AddMultiply
{
static constexpr const char* name = "AddMultiply";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -199,6 +207,8 @@ struct AddMultiply
// E = C x D0 + D1
struct MultiplyAdd
{
static constexpr const char* name = "MultiplyAdd";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -251,6 +261,8 @@ struct MultiplyAdd
struct MultiplyMultiply
{
static constexpr const char* name = "MultiplyMultiply";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -306,6 +318,8 @@ struct MultiplyMultiply
struct MultiplyAddFastGelu
{
static constexpr const char* name = "MultiplyAddFastGelu";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -327,6 +341,8 @@ struct MultiplyAddFastGelu
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
static constexpr const char* name = "AddAddFastGelu";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -398,6 +414,7 @@ struct AddAddFastGelu
// E = Relu(alpha1 * C + alpha2 * D0 + D1)
struct ScaleAddScaleAddRelu
{
static constexpr const char* name = "ScaleAddScaleAddRelu";
ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
: alpha1_(alpha1), alpha2_(alpha2)
@@ -462,6 +479,8 @@ struct ScaleAddScaleAddRelu
struct Normalize
{
static constexpr const char* name = "Normalize";
// FIXME: is double absolutely necessary?
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
@@ -533,6 +552,8 @@ struct Normalize
// The data type of mean and variance is used as AccDataType
struct NormalizeInInfer
{
static constexpr const char* name = "NormalizeInInfer";
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2, typename T3, typename T4>
@@ -565,6 +586,8 @@ struct NormalizeInInfer
// used by Conv+Bias+BatchNorm+Clamp inference
struct BiasNormalizeInInferClamp
{
static constexpr const char* name = "BiasNormalizeInInferClamp";
BiasNormalizeInInferClamp(float floor = 0.f,
float ceil = NumericLimits<float>::Max(),
float epsilon = 1e-4)
@@ -620,6 +643,8 @@ struct UnaryTypeConvert;
template <>
struct UnaryTypeConvert<float, ck::bhalf_t>
{
static constexpr const char* name = "UnaryTypeConvert";
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
{
y = ck::type_convert<float, ck::bhalf_t>(x);
@@ -629,6 +654,8 @@ struct UnaryTypeConvert<float, ck::bhalf_t>
template <>
struct UnaryTypeConvert<ck::bhalf_t, float>
{
static constexpr const char* name = "UnaryTypeConvert";
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
{
y = ck::type_convert<ck::bhalf_t, float>(x);

View File

@@ -24,6 +24,8 @@ namespace element_wise {
template <typename Activation>
struct Activation_Mul_Clamp
{
static constexpr const char* name = "Activation_Mul_Clamp";
// Convolution + Activation (piecewise linear function)
// If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
// Z = Activation(Y) = Activation(W @ X)
@@ -71,6 +73,8 @@ struct Activation_Mul_Clamp
template <typename Activation>
struct Mul_Activation_Mul_Clamp
{
static constexpr const char* name = "Mul_Activation_Mul_Clamp";
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
@@ -101,6 +105,8 @@ struct Mul_Activation_Mul_Clamp
template <typename Activation>
struct Activation_Mul2_Clamp
{
static constexpr const char* name = "Activation_Mul2_Clamp";
Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
@@ -131,6 +137,8 @@ struct Activation_Mul2_Clamp
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
static constexpr const char* name = "Add_Activation_Mul_Clamp";
// Convolution + bias
// Let Bias = B = Sw * Sx * Qb
// Where Qb is int32
@@ -175,6 +183,8 @@ struct Add_Activation_Mul_Clamp
template <typename Activation>
struct Add_Activation_Mul2_Clamp
{
static constexpr const char* name = "Add_Activation_Mul2_Clamp";
Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
@@ -206,6 +216,8 @@ struct Add_Activation_Mul2_Clamp
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
static constexpr const char* name = "Add_Mul_Activation_Mul_Clamp";
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X + B)
// Sz * Qz = Activation(Sy * Qy)
@@ -250,6 +262,8 @@ struct Add_Mul_Activation_Mul_Clamp
template <typename Activation>
struct Add_Mul2_Activation_Mul_Clamp
{
static constexpr const char* name = "Add_Mul2_Activation_Mul_Clamp";
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
: scale_z_inv_(scale_z_inv), activationOp_(activationOp)
{

View File

@@ -157,6 +157,8 @@ namespace element_wise {
struct PassThroughPack8
{
static constexpr const char* name = "PassThroughPack8";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
@@ -265,6 +267,8 @@ struct PassThroughPack8
struct DequantPack8
{
static constexpr const char* name = "DequantPack8";
template <typename Y, typename X, typename Z>
__host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
@@ -301,6 +305,8 @@ struct DequantPack8
struct PassThroughPack2
{
static constexpr const char* name = "PassThroughPack2";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
@@ -332,6 +338,8 @@ struct PassThroughPack2
struct PassThrough
{
static constexpr const char* name = "PassThrough";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
@@ -556,6 +564,8 @@ struct PassThrough
struct UnaryConvert
{
static constexpr const char* name = "UnaryConvert";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
@@ -565,6 +575,8 @@ struct UnaryConvert
struct ConvertBF16RTN
{
static constexpr const char* name = "ConvertBF16RTN";
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
@@ -582,6 +594,8 @@ struct ConvertBF16RTN
struct ConvertF8SR
{
static constexpr const char* name = "ConvertF8SR";
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
@@ -600,6 +614,8 @@ struct ConvertF8SR
struct ConvertF8RNE
{
static constexpr const char* name = "ConvertF8RNE";
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
@@ -618,6 +634,8 @@ struct ConvertF8RNE
struct Scale
{
static constexpr const char* name = "Scale";
__host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
@@ -663,6 +681,8 @@ struct Scale
struct ScaleAndResetNaNToMinusInfinity
{
static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity";
__host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
template <typename Y, typename X>
@@ -679,6 +699,8 @@ struct ScaleAndResetNaNToMinusInfinity
struct UnaryDivide
{
static constexpr const char* name = "UnaryDivide";
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
template <typename T>
@@ -723,6 +745,8 @@ struct UnaryDivide
struct UnarySquare
{
static constexpr const char* name = "UnarySquare";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -739,6 +763,8 @@ struct UnarySquare
struct UnaryAbs
{
static constexpr const char* name = "UnaryAbs";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -769,6 +795,8 @@ struct UnaryAbs
struct UnarySqrt
{
static constexpr const char* name = "UnarySqrt";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -781,6 +809,8 @@ struct UnarySqrt
struct Clamp
{
static constexpr const char* name = "Clamp";
Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
: floor_(floor), ceil_(ceil){};
@@ -854,6 +884,8 @@ struct Clamp
struct Relu
{
static constexpr const char* name = "Relu";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -890,6 +922,8 @@ struct Relu
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
struct FastGelu
{
static constexpr const char* name = "FastGelu";
template <typename Y, typename X>
__host__ void operator()(Y& y, const X& x) const;
@@ -1005,6 +1039,8 @@ struct FastGelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
static constexpr const char* name = "Gelu";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
@@ -1023,6 +1059,8 @@ struct Gelu
struct Sigmoid
{
static constexpr const char* name = "Sigmoid";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1047,6 +1085,8 @@ struct Sigmoid
struct Silu
{
static constexpr const char* name = "SiLU";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1060,6 +1100,8 @@ struct Silu
struct TanH
{
static constexpr const char* name = "TanH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1083,6 +1125,8 @@ struct TanH
struct ACos
{
static constexpr const char* name = "ACos";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1097,6 +1141,8 @@ struct ACos
struct Neg
{
static constexpr const char* name = "Neg";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1111,6 +1157,8 @@ struct Neg
struct ATan
{
static constexpr const char* name = "ATan";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1125,6 +1173,8 @@ struct ATan
struct Sin
{
static constexpr const char* name = "Sin";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1139,6 +1189,8 @@ struct Sin
struct ASinH
{
static constexpr const char* name = "ASinH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1153,6 +1205,8 @@ struct ASinH
struct Cos
{
static constexpr const char* name = "Cos";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1167,6 +1221,8 @@ struct Cos
struct ACosH
{
static constexpr const char* name = "ACosH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1181,6 +1237,8 @@ struct ACosH
struct Tan
{
static constexpr const char* name = "Tan";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1195,6 +1253,8 @@ struct Tan
struct ATanH
{
static constexpr const char* name = "ATanH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1209,6 +1269,8 @@ struct ATanH
struct SinH
{
static constexpr const char* name = "SinH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1223,6 +1285,8 @@ struct SinH
struct Ceil
{
static constexpr const char* name = "Ceil";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1237,6 +1301,8 @@ struct Ceil
struct Exp
{
static constexpr const char* name = "Exp";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1251,6 +1317,8 @@ struct Exp
struct CosH
{
static constexpr const char* name = "CosH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1265,6 +1333,8 @@ struct CosH
struct Floor
{
static constexpr const char* name = "Floor";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1279,6 +1349,8 @@ struct Floor
struct Log
{
static constexpr const char* name = "Log";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1293,6 +1365,8 @@ struct Log
struct ASin
{
static constexpr const char* name = "ASin";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1307,6 +1381,8 @@ struct ASin
struct Rcp
{
static constexpr const char* name = "Rcp";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1321,6 +1397,8 @@ struct Rcp
struct Swish
{
static constexpr const char* name = "Swish";
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename Y, typename X>
@@ -1350,6 +1428,8 @@ struct Swish
struct SoftRelu
{
static constexpr const char* name = "SoftRelu";
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
@@ -1378,6 +1458,8 @@ struct SoftRelu
struct Power
{
static constexpr const char* name = "Power";
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
@@ -1412,6 +1494,8 @@ struct Power
struct ClippedRelu
{
static constexpr const char* name = "ClippedRelu";
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
@@ -1441,6 +1525,8 @@ struct ClippedRelu
struct LeakyRelu
{
static constexpr const char* name = "LeakyRelu";
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
@@ -1468,6 +1554,8 @@ struct LeakyRelu
struct Elu
{
static constexpr const char* name = "Elu";
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
@@ -1495,6 +1583,8 @@ struct Elu
struct Logistic
{
static constexpr const char* name = "Logistic";
Logistic(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
@@ -1523,6 +1613,8 @@ struct Logistic
struct ConvInvscale
{
static constexpr const char* name = "ConvInvscale";
__host__ __device__ ConvInvscale(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)
@@ -1546,6 +1638,8 @@ struct ConvInvscale
struct ConvScale
{
static constexpr const char* name = "ConvScale";
__host__ __device__ ConvScale(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)
@@ -1569,6 +1663,8 @@ struct ConvScale
struct ConvScaleRelu
{
static constexpr const char* name = "ConvScaleRelu";
__host__ __device__ ConvScaleRelu(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)

View File

@@ -16,10 +16,17 @@ __device__ void llvm_amdgcn_s_wait_dscnt(short cnt) __asm("llvm.amdgcn.s.wait.ds
__device__ void block_sync_lds()
{
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#ifdef __gfx12__
#if defined(__gfx12__)
llvm_amdgcn_s_wait_dscnt(0);
asm volatile("s_barrier_signal -1\n\t"
"s_barrier_wait -1");
#elif defined(__gfx11__)
// asm volatile("\
// s_waitcnt lgkmcnt(0) \n \
// s_barrier \
// " ::);
__builtin_amdgcn_s_waitcnt(0xfc07);
__builtin_amdgcn_s_barrier();
#else
// asm volatile("\
// s_waitcnt lgkmcnt(0) \n \

View File

@@ -46,7 +46,7 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_transpose.hpp"
#include "ck_tile/host/rotating_buffers.hpp"
#include "ck_tile/host/shuffle_utils.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
#include "ck_tile/host/timer.hpp"

View File

@@ -7,6 +7,7 @@
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
@@ -14,14 +15,18 @@ namespace ck_tile {
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
typename OutDataType,
typename Elfunc = ck_tile::element_wise::PassThrough,
typename Tuple = ck_tile::tuple<>>
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
const HostTensor<WeiDataType>& weight,
HostTensor<OutDataType>& output,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
std::vector<ck_tile::long_index_t>)
std::vector<ck_tile::long_index_t>,
Elfunc elfunc = Elfunc{},
Tuple ds = {})
{
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
weight.get_num_of_dimension() == NDimSpatial + 3 &&
@@ -52,8 +57,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, wo) = v_acc_converted;
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,
@@ -95,8 +104,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, ho, wo) = v_acc_converted;
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, ho, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, ho, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,
@@ -145,8 +158,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, d_o, ho, wo) = v_acc_converted;
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, d_o, ho, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, d_o, ho, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,

View File

@@ -1540,6 +1540,23 @@ struct Logistic
const float alpha_;
};
struct Clamp
{
CK_TILE_HOST_DEVICE Clamp(float lower = std::numeric_limits<float>::lowest(),
float upper = std::numeric_limits<float>::max())
: lower_(lower), upper_(upper) {};
template <typename T>
CK_TILE_HOST_DEVICE constexpr void operator()(T& y, const T& x) const
{
T lower = ck_tile::type_convert<T>(lower_);
T upper = ck_tile::type_convert<T>(upper_);
y = ck_tile::clamp(x, lower, upper);
}
float lower_, upper_;
};
struct ConvInvscale
{
static constexpr const char* name = "ConvInvscale";
@@ -1629,6 +1646,55 @@ struct Cast
};
};
/**
* @brief Compose two unary element-wise functions into one.
*
*
* @note The Ds tensor can be used by at most one of the composed functions.
* This holds even if compositions are chained:
* In `Compose<FA, Compose<FB, FC>>`, only one of `FA`, `FB`, or `FC` can use
* the Ds tensor.
*
* @tparam FuncA The first function to be applied.
* @tparam FuncB The second function to be applied.
* @tparam FuncADs Whether `FuncA` uses the Ds tensor.
* @tparam FuncBDs Whether `FuncB` uses the Ds tensor.
*/
template <typename FuncA, typename FuncB, bool FuncADs = false, bool FuncBDs = false>
struct Compose
{
static_assert(!(FuncADs && FuncBDs), "Only one composed function may use the Ds tensor.");
CK_TILE_HOST_DEVICE Compose(FuncA func_a_ = FuncA{}, FuncB func_b_ = FuncB{})
: func_a(func_a_), func_b(func_b_)
{
}
template <typename AIn, typename BOut, typename AOut = AIn, typename... ADs>
CK_TILE_HOST_DEVICE constexpr void operator()(BOut& y, const AIn& x, const ADs&... ds) const
{
AOut tmp;
if constexpr(FuncADs)
{
func_a(tmp, x, ds...);
func_b(y, tmp);
}
else if constexpr(FuncBDs)
{
func_a(tmp, x);
func_b(y, tmp, ds...);
}
else
{
func_a(tmp, x);
func_b(y, tmp);
}
}
const FuncA func_a;
const FuncB func_b;
};
// support fastconvert of int8 to fp16
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>

View File

@@ -8,7 +8,7 @@
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include <optional>
#include <type_traits>
namespace ck_tile {
@@ -117,6 +117,10 @@ struct CShuffleEpilogue
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
CDElementwise elfunc_;
CK_TILE_DEVICE CShuffleEpilogue(CDElementwise elfunc = CDElementwise{}) : elfunc_(elfunc) {};
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
/**
@@ -385,7 +389,7 @@ struct CShuffleEpilogue
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
tile_elementwise_inout_unpack(elfunc_, c_ds_tiles);
}
template <typename OutDramWindow, typename COutTensor>
@@ -450,7 +454,7 @@ struct CShuffleEpilogue
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* /*p_smem*/,
void* /* p_smem */,
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{

View File

@@ -185,14 +185,6 @@ struct GemmKernelMultiABD
{
return false;
}
// Currently MultiABD kernel doesn't support F8 data type
if(ck_tile::get_device_name() == "gfx950" &&
(std::is_same<ck_tile::fp8_t, ADataType>::value ||
std::is_same<ck_tile::fp8_t, BDataType>::value ||
std::is_same<ck_tile::fp8_t, DDataType>::value))
{
return false;
}
return UniversalGemmKernel::IsSupportedArgument(kargs);
}

View File

@@ -8,6 +8,478 @@
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
namespace reboot {
/// @brief The Stream K GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel
/// arguments object. It contains all necessary information required to build proper kernel
/// arguments and launch the kernel on GPU. This structure defines the GEMM problem
/// configuration by stating all required information like M,N,K sizes and respective strides.
struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
{
CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
StreamKReductionStrategy reduction_strategy_)
: UniversalGemmHostArgs<>({a_ptr_},
{b_ptr_},
{/*ds_ptr*/},
c_ptr_,
/*k_batch_ =*/1,
M_,
N_,
K_,
{stride_A_},
{stride_B_},
{/*stride_Ds_*/},
stride_C_),
reduction_strategy{reduction_strategy_}
{
}
ck_tile::StreamKReductionStrategy reduction_strategy;
};
/// @brief The Stream K GEMM kernel class.
///
/// @par Overview
/// This class is responsible for the Stream-K kernel, making use of UniversalGemm.
// The main kernel functions are the operator() functions. There is one for Persistent
// and one for Non-Persistent data parallel sections of the Stream-K algorithm.
//
// Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
// `StreamKGemm()`. `BaseGemm()` computes offsets into the A,B,C tensors, then calls
// `RunGemm()` which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
// main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct StreamKKernel
{
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
static constexpr bool PersistentDP = UniversalGemmKernel::PersistentKernel;
using TilePartitioner = TilePartitioner_;
using GemmPipeline = GemmPipeline_;
using EpiloguePipeline = EpiloguePipeline_;
static_assert(
TilePartitioner::PERSISTENT == PersistentDP,
"Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
/// @brief Specify the layout configurations for A, B, and C
using ALayout = typename GemmPipeline::ALayout;
using BLayout = typename GemmPipeline::BLayout;
using CLayout = typename GemmPipeline::CLayout;
/// @brief Specify the data type configurations for A, B, and C
using ADataType = typename GemmPipeline::ADataType;
using BDataType = typename GemmPipeline::BDataType;
using CDataType = typename EpiloguePipeline::ODataType;
template <typename T>
static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
"ALayout and ADataType must be scalars.");
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
"BLayout and BDataType must be scalars.");
/// @brief CLayout and CDataType are expected to be scalars, not a tuple.
static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
"CLayout and CDataType must be scalars.");
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
: UniversalGemmKernelArgs{host_args.as_ptr,
host_args.bs_ptr,
host_args.ds_ptr,
host_args.e_ptr,
host_args.M,
host_args.N,
host_args.K,
host_args.stride_As,
host_args.stride_Bs,
host_args.stride_Ds,
host_args.stride_E,
host_args.k_batch},
reduction_strategy{host_args.reduction_strategy},
// The workspace pointer is set to nullptr because we must first
// instantiate the TilePartitioner to get the necessary size
workspace_ptr{nullptr},
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
{
}
/// @brief The strategy used by work groups to compute final results in C tensor.
StreamKReductionStrategy reduction_strategy;
/// @brief A pointer to a buffer in device memory for accumulating partial via reduction
/// strategy.
void* workspace_ptr;
/// @brief An instance of the TilePartioner class for assisting with mapping workgroups to
/// the C tensor.
TilePartitioner tile_partitioner;
};
using KernelArgs = StreamKKernelArgs;
using Kernel = StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
using WarpTile = typename P_::BlockGemmShape::WarpTile;
return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
/// @brief Compute the grid size for the Stream K kernel using the tile_partitioner.
/// @return The grid size.
CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
{
return tile_partitioner.grid_size();
}
/// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
/// @return The maximum occupancy grid size.
/// @note This function queries the maximum occupancy of the kernel using
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
return UniversalGemmKernel::MaxOccupancyGridSize(s);
}
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
{
return UniversalGemmKernel::BlockSize();
}
/// @brief Constructs kernel arguments for the Stream-K kernel.
/// @param host_args Stream-K host arguments.
/// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
/// The caller may select their own to assist with test reproducibility, etc.
/// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
/// select their own to assist with test reproducibility, etc.
/// @return The kernel arguments for Stream-K.
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
int num_cu = NumCU(),
int occupancy = Occupancy())
{
const index_t grid = num_cu * occupancy;
return StreamKKernelArgs{host_args, grid};
}
template <bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void
RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const typename UniversalGemmKernel::KernelArgs& kargs,
const index_t num_loop,
const index_t block_idx_m,
const index_t block_idx_n,
const index_t k_size)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
// case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
// tail_num.
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
bs_block_window[UniversalGemmKernel::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
{
if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
}
return false;
}
return UniversalGemmKernel::IsSupportedArgument(kargs);
}
/// @brief Computes the buffer size needed to store accumulation results for Stream K.
/// @return The buffer size needed.
CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
{
return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
}
/// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
/// @note Assumes that the given workspace_ptr points to allocated device memory.
CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
{
kargs.workspace_ptr = workspace_ptr;
}
/// @brief Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
/// @param kargs Stream-K kernel arguments.
/// @param tile_idx The 1D tile index in the C tensor for this workgroup.
/// @param num_loop The number of iterations (at the macro tile level) in the K dimension this
/// workgroup will perform in the C tile.
/// @param i_k_a The K offset in the A tensor.
/// @param i_k_b The K offset in the B tensor.
/// @param k_size The portion of the K dimension this workgroup processes in the assigned
/// `tile_idx`.
/// @param smem_ptr_0 Pointer to LDS.
CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
index_t tile_idx,
index_t num_loop,
index_t i_k_a,
index_t i_k_b,
index_t k_size,
void* smem_ptr_0) const
{
const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// Run the GEMM pipeline and Epilogue.
RunGemm(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
}
/// @brief Runs the main Stream-K algorithm.
/// @param kargs Stream-K kernel arguments.
/// @param cta_idx The current Stream-K workgroup's index.
/// @param smem_ptr_0 Pointer to LDS.
/// @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
/// non-persistent data-parallel (DP) section is used, then a Stream-K workgroup's `cta_idx`
/// should be something like `blockIdx.x` minus number of DP workgroups.
CK_TILE_DEVICE void
StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
{
index_t iter_start, iter_end;
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
while(iter_start < iter_end)
{
// Get the 1D tile index in the C tensor that this workgroup will work in for this
// iteration of the loop.
index_t tile_idx =
amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
// Get the start and end boundaries for the current tile.
index_t tile_iter_start, tile_iter_end;
kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
// Get the start and end iteration within the current tile for the workgroup.
index_t local_iter_start = amd_wave_read_first_lane(
kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
index_t local_iter_end =
amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
tile_iter_start, iter_end, tile_iter_end));
// Get the iteration length.
index_t num_loop_sk = local_iter_end - local_iter_start;
// Determine the total size along the K dimension the workgroup is using in this
// iteration (used to construct tensor views).
index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
// Get the K offsets for the A and B tensors
auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
{
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
}
else
{
// TODO: Apply reduction logic.
}
// Prepare for next Stream-K loop iteration.
iter_start = tile_iter_end;
block_sync_lds();
}
}
/// @brief Entry point for the Stream-K Kernel with non-persistent DP.
///
/// @par Overview
/// For the Non-Persistent kernel, each data parallel workgroup will
/// compute the results for their assigned macro-tile by calling `BaseGemm()`.
/// The Stream-K workgroups will do their assigned work by calling
/// `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop.
template <bool U = PersistentDP>
CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
{
// Allocate LDS
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
index_t block_idx = ck_tile::get_block_1d_id();
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
// Check if at the data parallel section
if(is_dp_ctas)
{
BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
}
else
{
// Stream-K
StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
}
}
/// @brief Entry point for the Stream-K Kernel with persistent DP.
///
/// @par Overview
/// For the Persistent kernel, each workgroup will first compute their
/// assigned data-parallel tiles. Each data parallel tile will be computed
/// by calling `BaseGemm()`. Then the workgroups will proceed with the
/// Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()`
/// in the Stream-K loop.
template <bool U = PersistentDP>
CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
{
// Allocate LDS
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
index_t block_idx = ck_tile::get_block_1d_id();
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
// Data-parallel section
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
tile_idx += kargs.tile_partitioner.get_grid())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
}
// Stream-K section
StreamKGemm(kargs, block_idx, smem_ptr_0);
}
private:
/// @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
/// the starting macro tile index in the K dimension for the workgroup.
/// @return A tuple containing the offsets into the A and B tensors accounting for the layouts
/// of A and B.
/// @note The default case is that A is assumed to be row major and B is assumed to be column
/// major.
template <typename ALayout, typename BLayout>
CK_TILE_DEVICE static tuple<index_t, index_t>
GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
{
index_t stride_offset_a;
index_t stride_offset_b;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
stride_offset_a = stride_a;
}
else
{
stride_offset_a = 1;
}
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_offset_b = stride_b;
}
else
{
stride_offset_b = 1;
}
index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
}
CK_TILE_HOST static int NumCU()
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
int num_cu = dev_prop.multiProcessorCount;
return num_cu;
}
/// @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
/// @return The occupancy
/// @note This function queries the maximum occupancy of the kernel using
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
CK_TILE_HOST static int Occupancy()
{
int occupancy;
// Since occupancy of 1 is valid for stream k, we set min_num_block_per_cu to 1
constexpr int min_block_per_cu = 1;
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
return occupancy;
}
};
} // namespace reboot
/// @brief The Stream K GEMM kernel host arguments.
///

View File

@@ -186,6 +186,11 @@ struct StreamKTilePartitionerBase
*/
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
/**
* @brief Returns an estimate of the number of workgroups writing to the same macro tile in C.
*/
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept;
protected:
index_t num_tiles_;
index_t grid_;
@@ -246,6 +251,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true
ck_tile::index_t grid);
public:
static constexpr bool PERSISTENT = true;
/**
* @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent
* case, no extra workgroups are allocated for the data parallel section, making the grid
@@ -292,6 +298,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, fals
ck_tile::index_t grid);
public:
static constexpr bool PERSISTENT = false;
/**
* @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent
* case, extra workgroups are allocated for the data parallel section, making the grid

View File

@@ -214,6 +214,27 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() c
return n_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
const noexcept
{
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
// writing final results to a given macro tile in C.
int num_wgs_per_tile = 1;
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
// Estimate the number of workgroups per macro tile.
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
}
return std::max(num_wgs_per_tile, 1);
}
template <typename BlockGemmShapeType,
StreamKReductionStrategy ReductionStrategyType,
bool Persistent>

View File

@@ -309,6 +309,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as corresponding "
"C block tensor data type!");
constexpr auto warp_size = get_warp_size();
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -352,7 +353,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
else
{
// Many N warps/iters share the same scale, index from full
// [NQPerBlock=1, BQPerBlock] matrix
// [KQPerBlock, NQPerBlock=1] matrix
static_assert(Traits::NQPerBlock == 1);
return kQScale;
}
@@ -366,11 +367,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
});
});
});

View File

@@ -687,8 +687,8 @@ struct QuantGemmKernel
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(1, kargs.stride_BQ),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
@@ -908,9 +908,9 @@ struct QuantGemmKernel
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
else
{

View File

@@ -375,30 +375,48 @@ struct QuantGroupedGemmKernel
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I4);
if constexpr(kQuantType == QuantType::RowColQuant)
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
auto& c_block_window = gemm_tile_windows.at(Base::I4);
// Run Epilogue Pipeline
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::TensorQuant)
else
{
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I4);
if constexpr(kQuantType == QuantType::RowColQuant)
{
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
}
else if constexpr(kQuantType == QuantType::TensorQuant)
{
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
}
}
}

View File

@@ -55,8 +55,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlockBQ,
KPerBlockBQ,
NPerBlockBQ,
Problem::QuantGroupSize::kN,
VecLoadSize>;

View File

@@ -259,8 +259,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Bq block window has incorrect lengths for defined BqLayout!");
static_assert(is_a_col_major
@@ -318,7 +318,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BQDramTileWindowStep bq_dram_tile_window_step =
is_bq_col_major ? make_array(0, KPerBlockBQ) : make_array(KPerBlockBQ, 0);
is_bq_col_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
@@ -363,6 +363,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
if constexpr(HasHotLoop)
{
constexpr index_t tail_count =
((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2;
index_t i = 0;
do
{
@@ -408,7 +410,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
} while(i < (num_loop - tail_count));
}
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
@@ -475,6 +477,49 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
num_loop,
p_smem);
}
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
///
/// This operator is used by grouped GEMM kernels where pipeline parameters
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
/// at runtime, not on the host side during compilation. This is necessary
/// because different GEMM problems in the group may have different K dimensions,
/// requiring different pipeline configurations that cannot be determined at
/// compile time.
///
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
/// @param bq_dram_block_window_tmp Block window for BQ (quantization scale) tensor in DRAM
/// @param num_loop Number of main loop iterations (calculated on device)
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
/// @param tail_number Type of tail handling required (calculated on device)
/// @param p_smem Pointer to shared memory
/// @return Accumulated result tile in registers
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename BQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem) const
{
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
constexpr bool hot_loop = has_hot_loop_.value;
constexpr auto tail_num = tail_number_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
bq_dram_block_window_tmp,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
} // namespace ck_tile

View File

@@ -192,51 +192,51 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
{
if constexpr(YPerQ == 1)
{
// YPerQ == 1 implementation - each row of B has independent scale
constexpr index_t X = XPerTile;
constexpr index_t XR = 2;
constexpr index_t Y0 = NIterPerWarp;
constexpr index_t Y1 = NWarps;
constexpr index_t Y2 = WarpGemm::kN;
// each row of B has independent scale
constexpr index_t Y = XPerTile;
constexpr index_t YR = 1;
constexpr index_t X0 = NIterPerWarp;
constexpr index_t X1 = NWarps;
constexpr index_t X2 = WarpGemm::kN;
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
tile_distribution_encoding<sequence<MWarps, YR>,
tuple<sequence<Y>, sequence<X0, X1, X2>>,
tuple<sequence<0, 2>, sequence<0, 2>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else if constexpr(YPerTile >= NIterPerWarp * NWarps)
else if constexpr(XPerTile >= NIterPerWarp * NWarps)
{
// small NQ block size case: split NQ axis by iters and Nwarps
constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp * NWarps);
constexpr index_t XR = get_warp_size() / NQPerIter;
static_assert(YPerTile == NQPerIter * NWarps * NIterPerWarp);
constexpr index_t NQPerIter = integer_divide_ceil(XPerTile, NIterPerWarp * NWarps);
constexpr index_t YR = get_warp_size() / NQPerIter;
static_assert(XPerTile == NQPerIter * NWarps * NIterPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<MWarps, XR>,
tuple<sequence<NIterPerWarp, NWarps, NQPerIter>, sequence<XPerTile>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
sequence<MWarps, YR>,
tuple<sequence<XPerTile>, sequence<NIterPerWarp, NWarps, NQPerIter>>,
tuple<sequence<0, 2>, sequence<0, 2>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else if constexpr(YPerTile >= NIterPerWarp)
else if constexpr(XPerTile >= NIterPerWarp)
{
// now all NWarps have the same scale -> replicate
constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp);
constexpr index_t XR = get_warp_size() / NQPerIter;
constexpr index_t YR = get_warp_size() / NQPerIter;
static_assert(YPerTile == NQPerIter * NIterPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<MWarps, NWarps, XR>,
tuple<sequence<NIterPerWarp, NQPerIter>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0, 1>>,
sequence<MWarps, NWarps, YR>,
tuple<sequence<XPerTile>, sequence<NIterPerWarp, NQPerIter>>,
tuple<sequence<0, 2>, sequence<0, 2>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else
@@ -245,18 +245,13 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
// threads
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<XPerTile>>,
tuple<sequence<XPerTile>, sequence<YPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{});
}
}
CK_TILE_HOST_DEVICE static constexpr bool is_fully_replicated()
{
return YPerQ > 1 && YPerTile < NIterPerWarp * NWarps;
}
};
template <typename GroupSizes>

View File

@@ -237,7 +237,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// BQ DRAM window for load
auto bq_copy_dram_window =
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<KPerBlockBQ>{}),
make_tuple(number<KPerBlockBQ>{}, number<kNPerBlock>{}),
bq_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeBQDramTileDistribution<Problem>());
@@ -270,7 +270,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
BQBlockTile bq_block_tile, bq_block_tile_2;
bq_block_tile = load_tile(bq_copy_dram_window);
// move BQ to tile 1
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
// Prefill A0
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
@@ -319,7 +319,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile_2 = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
@@ -361,7 +361,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
// Prefill A(2i+2)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);

View File

@@ -7,10 +7,12 @@
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
@@ -28,6 +30,7 @@ struct GroupedConvFwdKernelArgs
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC,
true>; // Split N enabled
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
template <
@@ -38,7 +41,8 @@ struct GroupedConvFwdKernelArgs
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
: elfunc(args.elfunc)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
@@ -121,7 +125,8 @@ struct GroupedConvFwdKernelArgs
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
: elfunc(args.elfunc)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
@@ -213,7 +218,8 @@ struct GroupedConvFwdKernelArgs
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
: elfunc(args.elfunc)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
@@ -335,6 +341,7 @@ struct GroupedConvFwdKernelArgs
const void* in_ptr;
const void* wei_ptr;
std::array<const void*, NumDTensor> ds_ptr;
const CDElementwise elfunc;
void* out_ptr;
AGridDescMK a_grid_desc_m_k;
@@ -423,6 +430,8 @@ struct GroupedConvolutionForwardKernel
// Below type is actually accumulation data type - the output of block GEMM.
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using CDElementwise = typename EpiloguePipeline::CDElementwise;
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;
// TODO: Enable this
@@ -458,7 +467,7 @@ struct GroupedConvolutionForwardKernel
}
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)
MakeKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& hostArgs)
{
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
}
@@ -636,7 +645,7 @@ struct GroupedConvolutionForwardKernel
"Not supported!");
return make_tensor_view<address_space_enum::global>(
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
},
number<NumDTensor>{});
@@ -765,8 +774,9 @@ struct GroupedConvolutionForwardKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{kargs.elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
/**

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
@@ -14,7 +15,7 @@ namespace ck_tile {
/// This structure is passed to Grouped Convolution Kernels when creating kernel
/// arguments object. It contain all necessary information required to
/// build proper kernel argument and launch kernel on GPU.
template <typename InPtr, typename WeiPtr, typename OutPtr>
template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
struct GroupedConvHostArgs : public conv::ConvParam
{
CK_TILE_HOST GroupedConvHostArgs() = delete;
@@ -23,13 +24,15 @@ struct GroupedConvHostArgs : public conv::ConvParam
WeiPtr wei_ptr_,
const std::vector<const void*> ds_ptr_,
OutPtr out_ptr_,
index_t k_batch_)
index_t k_batch_,
CDElementwise elfunc_ = CDElementwise{})
: conv::ConvParam(conv_param),
in_ptr(in_ptr_),
wei_ptr(wei_ptr_),
ds_ptr(ds_ptr_),
out_ptr(out_ptr_),
k_batch(k_batch_)
k_batch(k_batch_),
elfunc(elfunc_)
{
}
@@ -38,11 +41,17 @@ struct GroupedConvHostArgs : public conv::ConvParam
const std::vector<const void*> ds_ptr;
OutPtr out_ptr;
index_t k_batch;
const CDElementwise elfunc;
};
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs<const void*, void*, const void*>;
using GroupedConvBwdDataHostArgs = GroupedConvHostArgs<void*, const void*, const void*>;
using PassThrough = ck_tile::element_wise::PassThrough;
template <typename CDElementwise = PassThrough>
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*, CDElementwise>;
using GroupedConvBwdWeightHostArgs =
GroupedConvHostArgs<const void*, void*, const void*, PassThrough>;
using GroupedConvBwdDataHostArgs =
GroupedConvHostArgs<void*, const void*, const void*, PassThrough>;
template <index_t NDimSpatial_,
ConvolutionSpecialization ConvSpecialization_,
@@ -50,9 +59,10 @@ template <index_t NDimSpatial_,
typename WeiLayout_,
typename DsLayout_,
typename OutLayout_,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1>
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
typename CDElementwise_ = PassThrough>
struct GroupedConvTraits
{
private:
@@ -70,6 +80,7 @@ struct GroupedConvTraits
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
using CDElementwise = CDElementwise_;
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true,
true,

View File

@@ -300,7 +300,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
}
// layout NGCHW/GKYXC/NGKHW
// layout NGCHW/GKCYX/NGKHW
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
is_same_v<WeiLayout, GKCYX> && is_same_v<OutLayout, NGKHW>)
{

View File

@@ -7,7 +7,7 @@
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dynamic.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -161,7 +161,7 @@ template <ck::index_t NumDimSpatial,
typename DDataTypes,
typename OutDataType,
typename AComputeType,
typename BComputeType = AComputeType>
typename BComputeType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial,
InLayout,

View File

@@ -69,7 +69,7 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
Scale>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -46,6 +46,7 @@ set(REGRESSION_TESTS
test_ck_tile_fmha_fwd_bf16
test_ck_tile_fmha_fwd_fp16
test_ck_tile_fmha_fwd_fp8
test_ck_tile_streamk_reboot_extended
)
function(add_test_executable TEST_NAME)

View File

@@ -33,6 +33,14 @@ struct elementwise_op_traits<ck_tile::element_wise::Relu>
static constexpr int num_inputs = 1;
};
using NegRelu =
ck_tile::element_wise::Compose<ck_tile::element_wise::Relu, ck_tile::element_wise::Neg>;
template <>
struct elementwise_op_traits<NegRelu>
{
static constexpr int num_inputs = 1;
};
template <std::size_t D, typename F>
auto make_uniform_array_with_factory(F&& factory)
{
@@ -194,7 +202,11 @@ using TestConfig_F16_Add = std::tuple<ck_tile::half_t,
Shape1_BlockTile,
Shape1_WarpTile>;
using TestTypes = ::testing::Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add>;
using TestConfig_F32_Neg_Relu =
std::tuple<float, float, float, NegRelu, Shape1_BlockWarps, Shape1_BlockTile, Shape1_WarpTile>;
using TestTypes = ::testing::
Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add, TestConfig_F32_Neg_Relu>;
TYPED_TEST_SUITE(TestCkTileElementwise, TestTypes);

View File

@@ -5,7 +5,7 @@
#include "test_gemm_quant_base.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/shuffle_utils.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
struct GemmConfigBase
{

View File

@@ -20,20 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using KernelTypes = ::testing::Types<
// Has cshuffle epilogue enabled
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>
// Currently MultiABD kernel doesn't support F8 data type
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>
>;
// clang-format on

View File

@@ -20,19 +20,17 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using KernelTypes = ::testing::Types<
// Has cshuffle epilogue disabled
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
// Currently MultiABD kernel doesn't support F8 data type
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
>;
// clang-format on

View File

@@ -1,5 +1,95 @@
#pragma once
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256)
{
constexpr int M = 256;
constexpr int N = 512;
constexpr int K = 256;
constexpr int kBatch = 1;
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256)
{
constexpr int M = 512;
constexpr int N = 256;
constexpr int K = 256;
constexpr int kBatch = 1;
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256)
{
constexpr int M = 512;
constexpr int N = 512;
constexpr int K = 256;
constexpr int kBatch = 1;
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256)
{
constexpr int M = 256;
constexpr int N = 256;
constexpr int K = 256;
constexpr int kBatch = 1;
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256)
{
constexpr int M = 512;
constexpr int N = 768;
constexpr int K = 256;
constexpr int kBatch = 1;
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256)
{
constexpr int M = 512;
constexpr int N = 1280;
constexpr int K = 256;
constexpr int kBatch = 1;
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256)
{
constexpr int M = 256;
constexpr int N = 1280;
constexpr int K = 256;
constexpr int kBatch = 1;
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256)
{
constexpr int M = 768;
constexpr int N = 512;
constexpr int K = 256;
constexpr int kBatch = 1;
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256)
{
constexpr int M = 1280;
constexpr int N = 512;
constexpr int K = 256;
constexpr int kBatch = 1;
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512)
{
constexpr int M = 512;

View File

@@ -13,40 +13,9 @@
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
struct AddScale
{
template <typename E, typename A0, typename A1>
CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const A0& a0, const A1& a1) const
{
a = scale * (ck_tile::type_convert<float>(a0) + ck_tile::type_convert<float>(a1));
}
float scale = 1.0;
};
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
{
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
ck_tile::type_convert<float>(d1);
e = ck_tile::type_convert<E>(x0_f);
}
};
struct ElementWiseAddAdd
{
template <typename E, typename C, typename D0, typename D1>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
{
const float x0_f = ck_tile::type_convert<float>(c) + ck_tile::type_convert<float>(d0) +
ck_tile::type_convert<float>(d1);
e = ck_tile::type_convert<E>(x0_f);
}
};
using AddScale = ck_tile::element_wise::AddScale;
using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd;
using MultiplyMultiply = ck_tile::element_wise::MultiDMultiply;
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)

View File

@@ -1,3 +1,17 @@
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
-mllvm
-enable-noalias-to-md-conversion=0
)
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
# Currently test_ck_tile_streamk is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
@@ -6,23 +20,33 @@ if(GPU_TARGETS MATCHES "gfx9")
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
)
# TODO: enable extended tests after tolerances for atomic reductions are addressed.
# add_gtest_executable(test_ck_tile_streamk_extended
@@ -117,6 +141,19 @@ if(GPU_TARGETS MATCHES "gfx9")
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
# )
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
add_gtest_executable(test_ck_tile_streamk_reboot_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
test_gemm_streamk_reboot_util.cpp)
add_gtest_executable(test_ck_tile_streamk_reboot_extended
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
test_gemm_streamk_reboot_util.cpp)
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
endif()

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
#include "test_gemm_streamk_reboot_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent);
#include "test_gemm_streamk_reboot_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
#include "test_gemm_streamk_reboot_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent);
#include "test_gemm_streamk_reboot_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
#define TEST_SUITE_PARAMS BF8_CCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
#include "test_gemm_streamk_cases.inc"

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
#define TEST_SUITE_PARAMS BF8_CRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
#include "test_gemm_streamk_cases.inc"

Some files were not shown because too many files have changed in this diff Show More