mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
Merge remote-tracking branch 'origin/develop' into samremes/bmatrix_2d_blockscale
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -36,6 +36,9 @@ tags
|
||||
# Editors
|
||||
.vscode
|
||||
|
||||
# Cline
|
||||
.cline*
|
||||
|
||||
# build-in-source directory (see exceptions below)
|
||||
build*
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
ARG BASE_DOCKER="rocm/composable_kernel-private:ck_aiter_base"
|
||||
ARG BASE_DOCKER="rocm/pytorch:latest"
|
||||
FROM $BASE_DOCKER
|
||||
ARG AITER_BRANCH="main"
|
||||
ARG CK_AITER_BRANCH="develop"
|
||||
RUN groupadd irc && \
|
||||
pip install pandas zmq einops && \
|
||||
RUN pip install pandas zmq einops ninja && \
|
||||
pip install numpy==1.26.2 && \
|
||||
sudo mkdir /home/jenkins && \
|
||||
sudo mkdir /home/jenkins/workspace && \
|
||||
@@ -14,6 +13,8 @@ RUN groupadd irc && \
|
||||
rm -rf 3rdparty/composable_kernel/ && \
|
||||
git clone -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \
|
||||
python3 setup.py develop && \
|
||||
chown -R jenkins:jenkins /home/jenkins/workspace && \
|
||||
chmod -R a+rwx /home/jenkins/workspace && \
|
||||
groupadd -g 1001 jenkins && \
|
||||
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \
|
||||
chown -R jenkins:jenkins /home/jenkins && \
|
||||
chmod -R a+rwx /home/jenkins && \
|
||||
sudo usermod -aG irc jenkins
|
||||
|
||||
101
Jenkinsfile
vendored
101
Jenkinsfile
vendored
@@ -194,6 +194,33 @@ def check_arch(){
|
||||
return arch_type
|
||||
}
|
||||
|
||||
def check_arch_name(){
|
||||
def arch_name = ""
|
||||
sh 'rocminfo | tee rocminfo.log'
|
||||
if ( runShell('grep -n "gfx90a" rocminfo.log') ){
|
||||
arch_name = "gfx90a"
|
||||
}
|
||||
else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
|
||||
arch_name = "gfx942"
|
||||
}
|
||||
else if ( runShell('grep -n "gfx10" rocminfo.log') ) {
|
||||
arch_name = "gfx10"
|
||||
}
|
||||
else if ( runShell('grep -n "gfx11" rocminfo.log') ) {
|
||||
arch_name = "gfx11"
|
||||
}
|
||||
else if ( runShell('grep -n "gfx12" rocminfo.log') ) {
|
||||
arch_name = "gfx12"
|
||||
}
|
||||
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
|
||||
arch_name = "gfx908"
|
||||
}
|
||||
else if ( runShell('grep -n "gfx950" rocminfo.log') ) {
|
||||
arch_name = "gfx950"
|
||||
}
|
||||
return arch_name
|
||||
}
|
||||
|
||||
def getDockerImage(Map conf=[:]){
|
||||
env.DOCKER_BUILDKIT=1
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
@@ -302,12 +329,6 @@ def cmake_build(Map conf=[:]){
|
||||
//cmake_env can overwrite default CXX variables.
|
||||
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
|
||||
|
||||
def package_build = (conf.get("package_build","") == "true")
|
||||
|
||||
if (package_build == true) {
|
||||
config_targets = "package"
|
||||
}
|
||||
|
||||
if(conf.get("build_install","") == "true")
|
||||
{
|
||||
config_targets = 'install ' + config_targets
|
||||
@@ -455,15 +476,20 @@ def cmake_build(Map conf=[:]){
|
||||
else{
|
||||
sh "ninja check"
|
||||
}
|
||||
if(params.BUILD_PACKAGES){
|
||||
echo "Build ckProfiler packages"
|
||||
sh 'ninja -j64 package'
|
||||
def arch_name = check_arch_name()
|
||||
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
|
||||
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
|
||||
}
|
||||
}
|
||||
if(params.BUILD_INSTANCES_ONLY){
|
||||
// build deb packages
|
||||
echo "Build packages"
|
||||
echo "Build library package"
|
||||
sh 'ninja -j64 package'
|
||||
archiveArtifacts artifacts: 'composablekernel-dev*.deb'
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64.deb'
|
||||
stash includes: "composablekernel-**.deb", name: "packages"
|
||||
stash includes: "composablekernel-dev**.deb", name: "lib_package"
|
||||
}
|
||||
}
|
||||
else{
|
||||
@@ -475,15 +501,18 @@ def cmake_build(Map conf=[:]){
|
||||
else{
|
||||
sh "ninja check"
|
||||
}
|
||||
if(params.BUILD_PACKAGES){
|
||||
echo "Build ckProfiler packages"
|
||||
sh 'ninja -j64 package'
|
||||
def arch_name = check_arch_name()
|
||||
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
|
||||
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only archive from develop
|
||||
if (package_build == true && env.BRANCH_NAME == "develop") {
|
||||
archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
|
||||
}
|
||||
//check the node gpu architecture
|
||||
def arch = check_arch()
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
@@ -823,9 +852,42 @@ def process_results(Map conf=[:]){
|
||||
}
|
||||
if (params.BUILD_INSTANCES_ONLY){
|
||||
// unstash deb packages
|
||||
unstash "packages"
|
||||
try{
|
||||
unstash "lib_package"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate lib_package."
|
||||
}
|
||||
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
|
||||
}
|
||||
if (params.BUILD_PACKAGES){
|
||||
// unstash deb packages
|
||||
try{
|
||||
unstash "profiler_package_gfx90a"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate profiler_package_gfx90a."
|
||||
}
|
||||
try{
|
||||
unstash "profiler_package_gfx942"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate profiler_package_gfx942."
|
||||
}
|
||||
try{
|
||||
unstash "profiler_package_gfx950"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate profiler_package_gfx950."
|
||||
}
|
||||
try{
|
||||
unstash "profiler_package_gfx12"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate profiler_package_gfx12."
|
||||
}
|
||||
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-ckprofiler*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
|
||||
}
|
||||
else{
|
||||
// unstash perf files to master
|
||||
try{
|
||||
@@ -993,7 +1055,7 @@ def run_pytorch_tests(Map conf=[:]){
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true
|
||||
0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
|
||||
@@ -1085,6 +1147,10 @@ pipeline {
|
||||
name: "BUILD_INSTANCES_ONLY",
|
||||
defaultValue: false,
|
||||
description: "Test building instances for various architectures simultaneously (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_PACKAGES",
|
||||
defaultValue: false,
|
||||
description: "Build packages for the libraries and/or ckProfiler (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX908",
|
||||
defaultValue: false,
|
||||
@@ -1574,7 +1640,6 @@ pipeline {
|
||||
-D GPU_TARGETS="gfx1201" \
|
||||
-D GEMM_DATATYPE="fp16" \
|
||||
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-DGEMM_CONFIG_FILE=gfx120x_config.json \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j64 benchmark_gemm_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
@@ -1830,7 +1895,7 @@ pipeline {
|
||||
stage("Process results"){
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()|| params.BUILD_PACKAGES.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent { label 'mici' }
|
||||
steps{
|
||||
|
||||
@@ -52,7 +52,7 @@ struct kernel
|
||||
template <class... Ts>
|
||||
auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
|
||||
{
|
||||
return [=](auto&&... xs) {
|
||||
return [=, this](auto&&... xs) {
|
||||
launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -59,4 +59,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
return !run_grouped_gemm_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -278,19 +278,20 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default cases
|
||||
}
|
||||
else if(argc == 4 || argc == 6)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.async_hargs = std::stoi(argv[4]);
|
||||
problem_size.group_count = std::stoi(argv[5]);
|
||||
if(argc == 6)
|
||||
{
|
||||
config.async_hargs = std::stoi(argv[4]);
|
||||
problem_size.group_count = std::stoi(argv[5]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -299,18 +300,33 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=n0, 1=yes)\n");
|
||||
printf("arg5: group count (default=16)");
|
||||
exit(0);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// Lambda to get stride based on layout
|
||||
auto get_stride = [](auto layout, auto row_dim, auto col_dim) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col_dim;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row_dim;
|
||||
}
|
||||
};
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(256 + 256 * i);
|
||||
problem_size.Ns.push_back(128 + 128 * i);
|
||||
problem_size.Ks.push_back(128 + 64 * i);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
problem_size.stride_As.push_back(
|
||||
get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i]));
|
||||
problem_size.stride_Bs.push_back(
|
||||
get_stride(BLayout{}, problem_size.Ks[i], problem_size.Ns[i]));
|
||||
problem_size.stride_Cs.push_back(
|
||||
get_stride(ELayout{}, problem_size.Ms[i], problem_size.Ns[i]));
|
||||
}
|
||||
|
||||
return run_grouped_gemm(problem_size, config);
|
||||
|
||||
@@ -82,37 +82,29 @@ int main(int argc, char* argv[])
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
|
||||
ck::index_t M = 48 * 256;
|
||||
ck::index_t N = 1024;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
}
|
||||
else if(argc == 3)
|
||||
else if(argc == 3 || argc == 5)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
if(argc == 5)
|
||||
{
|
||||
M = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t M = 48 * 256;
|
||||
ck::index_t N = 1024;
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 3)
|
||||
{
|
||||
M = std::stoi(argv[1]);
|
||||
N = std::stoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1 to 2: M, N" << std::endl;
|
||||
return 1;
|
||||
printf("arg3-4: M, N\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
ck::index_t Stride = N;
|
||||
|
||||
@@ -182,6 +182,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
|
||||
@@ -167,6 +167,113 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
GemmConfig::Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(splitk)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
template <typename GemmConfig, typename PrecType>
|
||||
|
||||
@@ -29,7 +29,7 @@ template <typename GemmConfig,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
ck_tile::QuantType QuantMode>
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
@@ -48,8 +48,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false,
|
||||
false,
|
||||
false, // PreshuffleQuant
|
||||
false, // PreshuffleB
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -67,18 +67,29 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>;
|
||||
using QuantGemmProblem = typename std::conditional<
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
128>, // QuantGroupSize
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>>::type;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;
|
||||
using GemmPipeline =
|
||||
typename std::conditional<QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -41,6 +42,14 @@ struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
@@ -77,24 +86,11 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
@@ -122,8 +118,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("quant_mode", "tensor", "Choose tensor (default), or rowcol");
|
||||
;
|
||||
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -43,8 +43,8 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
@@ -159,11 +159,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
const ck_tile::index_t QuantGroupSize = 128;
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
@@ -172,9 +173,11 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
validate = false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> AQs; // dimension of AQ tensor is calculated from A tensor
|
||||
std::vector<ck_tile::index_t> BQs; // dimension of BQ tensor is calculated from B tensor
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
|
||||
@@ -252,6 +255,15 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
|
||||
if(K % QuantGroupSize != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
@@ -289,6 +301,13 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
@@ -394,6 +413,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors[i],
|
||||
c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(
|
||||
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
@@ -441,42 +471,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Col{}, Col{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
@@ -513,6 +507,41 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
}
|
||||
}
|
||||
if(data_type == "bf8")
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
|
||||
@@ -70,23 +70,13 @@ float invoke_gemm(int n_warmup,
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GemmConfig::Preshuffle)
|
||||
{
|
||||
// not supported yet
|
||||
throw std::runtime_error(
|
||||
"Persistent grouped gemm with preshuffle is not supported yet");
|
||||
}
|
||||
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse
|
||||
// commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5,
|
||||
// 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve
|
||||
// commentCode has comments. Press enter to view. the gemm problems known on the host.
|
||||
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
|
||||
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have the gemm
|
||||
// problems known on the host. Instead, we can just pass the pointer to the kernel and let
|
||||
// the workgroups figure out which tiles to work on. This is useful when the gemm problems
|
||||
// are generated dynamically. In this example however, we generate the `kargs` using the
|
||||
// known gemm_descs, and copy the gemm descriptions to the device memory. The contents of
|
||||
// the memory pointed to by `kargs_ptr` pointer could be written by e.g. another kernel from
|
||||
// earlier stage.
|
||||
|
||||
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
|
||||
@@ -4,6 +4,9 @@ list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion
|
||||
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_fwd_bias_clamp EXCLUDE_FROM_ALL grouped_convolution_forward_bias_clamp.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_fwd_bias_clamp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
#include "grouped_convolution_forward_invoker.hpp"
|
||||
#include "run_grouped_convolution_fwd_bias_clamp_example.inc"
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
|
||||
{
|
||||
using Invoker = GroupedConvolutionForwardInvoker;
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::bf16_t>,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_fwd_bias_clamp_example<GemmConfigComputeV3_WMMA>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_fwd_bias_clamp_example<GemmConfigComputeV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
@@ -15,10 +15,10 @@ struct GroupedConvolutionForwardInvoker
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDElementWise>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
@@ -49,7 +49,8 @@ struct GroupedConvolutionForwardInvoker
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
VectorSizeC,
|
||||
CDElementWise>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
@@ -128,7 +129,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
|
||||
@@ -0,0 +1,301 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
using BiasAndClamp = ck_tile::element_wise::
|
||||
Compose<ck_tile::element_wise::MultiDAdd, ck_tile::element_wise::Clamp, true>;
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_fwd_bias_clamp(const ck_tile::GroupedConvFwdHostArgs<BiasAndClamp>& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = Invoker::template grouped_conv_fwd<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
ck_tile::tuple<OutDataType>,
|
||||
ck_tile::tuple<OutLayout>,
|
||||
BiasAndClamp>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
int run_grouped_conv_fwd_bias_clamp_example_with_layouts(
|
||||
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
std::vector<ck_tile::index_t> filter_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> image_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> strides;
|
||||
std::vector<ck_tile::index_t> dilations;
|
||||
std::vector<ck_tile::index_t> lpads;
|
||||
std::vector<ck_tile::index_t> rpads;
|
||||
|
||||
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads,
|
||||
arg_parser);
|
||||
|
||||
ck_tile::conv::ConvParam conv_param{num_dim_sp,
|
||||
arg_parser.get_int("g"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("c"),
|
||||
filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
const float floor = -100.f;
|
||||
const float ceil = 100.f;
|
||||
|
||||
const ck_tile::element_wise::MultiDAdd bias_op{};
|
||||
const ck_tile::element_wise::Clamp clamp_op{floor, ceil};
|
||||
const BiasAndClamp bias_clamp_op{bias_op, clamp_op};
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
ck_tile::HostTensor<OutDataType> bias(out_g_n_k_wos_desc);
|
||||
|
||||
std::string bias_str = "";
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(input);
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{-5.f, 5.f}(weight);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{-5.f, 5.f}(bias);
|
||||
bias_str = " (Uniform(-5,5))";
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<InDataType>{}(input);
|
||||
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
|
||||
ck_tile::FillMonotonicSeq<OutDataType>{}(bias);
|
||||
bias_str = " (Monotonic)";
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(bias);
|
||||
bias_str = " (Constant 1)";
|
||||
}
|
||||
else
|
||||
{
|
||||
input.SetZero();
|
||||
weight.SetZero();
|
||||
bias.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_dev_buf(bias.get_element_space_size_in_bytes());
|
||||
|
||||
input_dev_buf.ToDevice(input.data());
|
||||
weight_dev_buf.ToDevice(weight.data());
|
||||
output_dev_buf.SetZero();
|
||||
bias_dev_buf.ToDevice(bias.data());
|
||||
|
||||
ck_tile::GroupedConvFwdHostArgs<BiasAndClamp> args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{bias_dev_buf.GetDeviceBuffer()},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
bias_clamp_op);
|
||||
|
||||
std::cout << "Run Grouped Conv Fwd kernel with bias" << bias_str << " and clamp (" << floor
|
||||
<< ", " << ceil << ")." << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_fwd_bias_clamp<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
|
||||
output_dev_buf.FromDevice(output.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
// FIXME: Address this issue
|
||||
if(arg_parser.get_int("g") > 1 && init_method == 0)
|
||||
std::cerr << "Adding different bias to different groups yield incorrect results"
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
|
||||
output_host_ref.SetZero();
|
||||
|
||||
auto bias_clamp_host = [floor,
|
||||
ceil](float& y, const float& x, const OutDataType& element_bias) {
|
||||
float x_float = ck_tile::type_convert<float>(x);
|
||||
x_float += ck_tile::type_convert<float>(element_bias);
|
||||
if(x_float < floor)
|
||||
x_float = floor;
|
||||
else if(x_float > ceil)
|
||||
x_float = ceil;
|
||||
y = x_float;
|
||||
};
|
||||
auto bias_tuple = ck_tile::make_tuple(bias);
|
||||
ck_tile::reference_grouped_conv_fwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
decltype(bias_clamp_host)>(
|
||||
input,
|
||||
weight,
|
||||
output_host_ref,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
bias_clamp_host,
|
||||
bias_tuple);
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
|
||||
GemmK, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(output,
|
||||
output_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
throw std::runtime_error("Unsupported gpu verification !!!");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename Invoker,
|
||||
typename GemmWarpConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_fwd_bias_clamp_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
// using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
// using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
// using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
// FIXME: Fix crash in 1D convolution when using Ds tensor.
|
||||
throw std::runtime_error("1D Convolution does not support bias.");
|
||||
// return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
|
||||
// GemmWarpConfig,
|
||||
// Invoker,
|
||||
// InPrecType,
|
||||
// WeiPrecType,
|
||||
// OutPrecType>(
|
||||
// argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout!");
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,7 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
@@ -128,12 +128,12 @@ int run_grouped_conv_fwd_example_with_layouts(
|
||||
weight_dev_buf.ToDevice(weight.data());
|
||||
output_dev_buf.SetZero();
|
||||
|
||||
ck_tile::GroupedConvFwdHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
ck_tile::GroupedConvFwdHostArgs<> args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
|
||||
std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/shuffle_utils.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(tile_example_streamk_gemm_basic EXCLUDE_FROM_ALL streamk_gemm_basic.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
target_compile_options(tile_example_streamk_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile streamk gemm tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -28,10 +28,10 @@ args:
|
||||
-stride_b tensor B stride (default:0)
|
||||
-stride_c tensor C stride (default:0)
|
||||
-v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
|
||||
-prec data type. fp16/bf16 (default:fp16)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmarking the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-flush_cache flush the cache before running the kernel (default:true)
|
||||
```
|
||||
```
|
||||
|
||||
@@ -75,6 +75,18 @@ struct DataTypeTraits<ck_tile::bf16_t>
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -94,7 +106,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -56,7 +56,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
@@ -187,6 +187,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
|
||||
@@ -28,3 +28,4 @@ add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(39_copy)
|
||||
add_subdirectory(40_streamk_gemm)
|
||||
add_subdirectory(41_batched_contraction)
|
||||
|
||||
|
||||
@@ -23,9 +23,18 @@ This project is a prototype for a more general builder pattern for all of compos
|
||||
|
||||
To enable the experimental builder, configure your build with:
|
||||
|
||||
```sh
|
||||
cmake -DCK_EXPERIMENTAL_BUILDER=ON -DCMAKE_CXX_STANDARD=20 ...
|
||||
```bash
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942;gfx950" \
|
||||
-D CK_EXPERIMENTAL_BUILDER=ON \
|
||||
-D CMAKE_CXX_STANDARD=20 \
|
||||
-G Ninja \
|
||||
..
|
||||
```
|
||||
|
||||
## Building and testing
|
||||
|
||||
During development, build and test from the CK build directory with
|
||||
|
||||
143
experimental/builder/include/ck_tile/builder/builder_utils.hpp
Normal file
143
experimental/builder/include/ck_tile/builder/builder_utils.hpp
Normal file
@@ -0,0 +1,143 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Convert a static array to a sequence
|
||||
// Usage example:
|
||||
// static constexpr std::array arr{1, 2, 3};
|
||||
// using seq = to_sequence_v<arr>; // seq is ck::Sequence<1, 2, 3>
|
||||
template <typename T, const T& Arr>
|
||||
struct to_sequence_t
|
||||
{
|
||||
private:
|
||||
template <std::size_t... Is>
|
||||
static auto get_sequence_type(std::index_sequence<Is...>) -> ck::Sequence<Arr[Is]...>;
|
||||
|
||||
// Helper method to handle the unusual .Size() method name in ck::Array.
|
||||
static constexpr auto get_size(const auto& arr)
|
||||
{
|
||||
if constexpr(requires { arr.size(); })
|
||||
{
|
||||
return arr.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
return arr.Size();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
using value = decltype(get_sequence_type(std::make_index_sequence<get_size(Arr)>{}));
|
||||
};
|
||||
|
||||
template <auto& Arr>
|
||||
using to_sequence_v = typename to_sequence_t<std::remove_cvref_t<decltype(Arr)>, Arr>::value;
|
||||
|
||||
// Wrapper struct that makes constexpr strings a structural type usable as an NTTP.
|
||||
template <size_t N>
|
||||
struct StringLiteral
|
||||
{
|
||||
char data[N];
|
||||
constexpr StringLiteral(const char (&str)[N])
|
||||
{
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
data[i] = str[i];
|
||||
}
|
||||
|
||||
constexpr bool operator==(const StringLiteral<N>& other) const
|
||||
{
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(data[i] != other.data[i])
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// This is a C++17 deduction guide. It allows the compiler to automatically
|
||||
// deduce the template argument `N` for `StringLiteral` from a string literal
|
||||
// constructor argument. For example, you can write `StringLiteral s{"foo"};`
|
||||
// instead of `StringLiteral<4> s{"foo"};`.
|
||||
template <size_t N>
|
||||
StringLiteral(const char (&)[N]) -> StringLiteral<N>;
|
||||
|
||||
// Helper to provide a readable error for unsupported enum values.
|
||||
// The compiler will print the name of this struct in the error message, so
|
||||
// the name of the enum value will appear instead of just its integer value.
|
||||
template <auto T>
|
||||
struct UnsupportedEnumValue
|
||||
{
|
||||
};
|
||||
|
||||
// Helper functions to convert enums to strings
|
||||
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
|
||||
{
|
||||
switch(dir)
|
||||
{
|
||||
case ConvDirection::FORWARD: return "Forward";
|
||||
case ConvDirection::BACKWARD_DATA: return "Backward Data";
|
||||
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view DataTypeToString(DataType dt)
|
||||
{
|
||||
switch(dt)
|
||||
{
|
||||
case DataType::FP16: return "FP16";
|
||||
case DataType::FP32: return "FP32";
|
||||
case DataType::BF16: return "BF16";
|
||||
case DataType::FP8: return "FP8";
|
||||
case DataType::I8: return "I8";
|
||||
case DataType::U8: return "U8";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
|
||||
case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
|
||||
case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
|
||||
case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
|
||||
case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
|
||||
case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
|
||||
case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
|
||||
case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
|
||||
case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
|
||||
case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,141 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
#include <array>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
/********************************************************************/
|
||||
/* Descriptors for individual elements of the algorithm description */
|
||||
/********************************************************************/
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem.
|
||||
template <typename T>
|
||||
concept ThreadBlockDescriptor = requires(T t) {
|
||||
{ t.block_size } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for parameters that describe a gridwise GEMM problem.
|
||||
template <typename T>
|
||||
concept GridwiseGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> std::convertible_to<size_t>;
|
||||
{ t.bk1 } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for vectorized data transfer for convolution input tensors.
|
||||
template <typename T>
|
||||
concept BlockTransferDescriptor = requires(T t) {
|
||||
{ t.k0 } -> std::convertible_to<size_t>;
|
||||
{ t.m_n } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for thread cluster dimensions for GEMM output tensor.
|
||||
template <typename T>
|
||||
concept ThreadClusterDescriptor = requires(T t) {
|
||||
{ t.m_block } -> std::convertible_to<size_t>;
|
||||
{ t.m_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_block } -> std::convertible_to<size_t>;
|
||||
{ t.n_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for the LDS transfer for the convolution input tensors.
|
||||
template <typename T>
|
||||
concept LdsTransferDescriptor = requires(T t) {
|
||||
{ t.src_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.src_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.is_direct_load } -> std::convertible_to<bool>;
|
||||
{ t.lds_padding } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
// Concept for the convolution output tensor epilogue (copy from registers to global memory via
|
||||
// LDS).
|
||||
template <typename T>
|
||||
concept EpilogueDescriptor = requires(T t) {
|
||||
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.n_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for the thread cluster access order
|
||||
template <typename T>
|
||||
concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
};
|
||||
|
||||
// No requirements yet for a ConvAlgorithm concept beyond being a class type.
|
||||
template <typename T>
|
||||
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
|
||||
|
||||
/******************************************** */
|
||||
/* Requirements for the algorithm description */
|
||||
/******************************************** */
|
||||
|
||||
// Concept to check if struct specifies thread block info.
|
||||
template <typename T>
|
||||
concept SpecifiesThreadBlock = requires {
|
||||
{ T::thread_block } -> ThreadBlockDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::block_transfer.block_transfer_a } -> BlockTransferDescriptor;
|
||||
{ T::block_transfer.block_transfer_b } -> BlockTransferDescriptor;
|
||||
{ T::block_transfer.thread_cluster_dims_c } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
|
||||
template <typename T>
|
||||
concept SpecifiesLdsTransfer = requires(T t) {
|
||||
{ T::block_transfer.lds_transfer_a } -> LdsTransferDescriptor;
|
||||
{ T::block_transfer.lds_transfer_b } -> LdsTransferDescriptor;
|
||||
{ T::block_transfer.epilogue_c } -> EpilogueDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies thread cluster access order info.
|
||||
template <typename T>
|
||||
concept SpecifiesThreadClusterAccessOrder = requires(T t) {
|
||||
{ T::block_transfer.block_transfer_access_order_a } -> AccessOrderDescriptor;
|
||||
{ T::block_transfer.block_transfer_access_order_b } -> AccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies source access order info.
|
||||
template <typename T>
|
||||
concept SpecifiesSourceAccessOrder = requires(T t) {
|
||||
{ T::block_transfer.src_access_order_a } -> AccessOrderDescriptor;
|
||||
{ T::block_transfer.src_access_order_b } -> AccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block_gemm_pipeline_version.
|
||||
template <typename T>
|
||||
concept SpecifiesGemmPipelineVersion = requires {
|
||||
{ T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConcSpecialization = requires {
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Limits for input vector transfer.
|
||||
template <auto Value>
|
||||
concept InputVectorTransferLimits = requires {
|
||||
requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
|
||||
Value.lds_dst_scalar_per_vector > 0;
|
||||
};
|
||||
|
||||
// Limits for output vector transfer.
|
||||
template <auto Value>
|
||||
concept OutputVectorTransferLimits = requires {
|
||||
requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
|
||||
Value.n_xdl_per_wave_per_shuffle > 0;
|
||||
};
|
||||
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2}.
|
||||
template <auto Value>
|
||||
concept AccessOrderLimits = requires {
|
||||
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
|
||||
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
|
||||
(Value[2] >= 0 && Value[2] < 3));
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/conv_factory.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
/**
|
||||
* @brief Top-level builder for creating convolution kernel instances.
|
||||
*
|
||||
* This struct serves as the main entry point for generating a convolution kernel.
|
||||
* It uses a factory pattern based on the provided signature, algorithm, and version
|
||||
* to construct the appropriate kernel instance.
|
||||
*
|
||||
* @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of
|
||||
* the algorithm (e.g., data types, layouts, direction).
|
||||
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
|
||||
* @tparam VERSION The version of the builder implementation.
|
||||
*/
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION = LATEST_API_VERSION>
|
||||
requires SupportedVersion<VERSION> && ValidConvSignature<SIGNATURE>
|
||||
struct ConvBuilder
|
||||
{
|
||||
static constexpr auto kVersion = VERSION;
|
||||
using Factory = ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
|
||||
// Output: The kernel class.
|
||||
using Instance = Factory::Instance;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
539
experimental/builder/include/ck_tile/builder/conv_factory.hpp
Normal file
539
experimental/builder/include/ck_tile/builder/conv_factory.hpp
Normal file
@@ -0,0 +1,539 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// A factory for instantiating CK convolution kernels.
|
||||
//
|
||||
// This file translates a semantic description of a convolution operation
|
||||
// (`ConvSignatureDescriptor` and `ConvAlgorithmDescriptor`) into specific,
|
||||
// low-level template arguments required by the underlying CK device-level
|
||||
// kernel implementations. This abstraction enables more complex build
|
||||
// time logic and simplifies the kernel specification.
|
||||
//
|
||||
// Key Components:
|
||||
//
|
||||
// Template Metaprogram:
|
||||
// - ConvFactory: The main factory, with specializations for different
|
||||
// convolution directions (currently only forward).
|
||||
//
|
||||
// Template Metaprogram Helpers:
|
||||
// - ConvTensorLayouts: Maps layout enums to CK layout types for different
|
||||
// spatial dimensions (1D/2D/3D) and directions.
|
||||
// - ConvTensorTypes: Maps data type enums (FP16, BF16, FP32) to C++ types used by CK.
|
||||
// - ElementwiseOps: Maps element-wise operation enums to CK element-wise op types.
|
||||
// - ConvSpec: Encapsulates convolution and GEMM specialization enums.
|
||||
//
|
||||
// `constexpr` Helper Functions:
|
||||
// - SetThreadBlockInfo: Determines thread block dimensions and tile sizes.
|
||||
// - SetGridwiseGemmInfo: Sets XDL and AK1/BK1 tuning parameters.
|
||||
// - SetFwdConvABlockTransfer: Configures A tensor block transfer parameters.
|
||||
// - SetFwdConvBBlockTransfer: Configures B tensor block transfer parameters.
|
||||
// - SetCBlockTransfer: Configures C tensor block transfer parameters.
|
||||
// - SetBlockGemmPipelineVersion: Maps pipeline version enum to CK types.
// - SetFwdConvSpecialization: Maps forward convolution specialization enum to CK types.
|
||||
//
|
||||
// The primary entry point is the `ConvFactory` struct, which is currently
|
||||
// specialized for forward convolutions and produces instances of
|
||||
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory_internal {
|
||||
|
||||
// Type mappings from the builder GroupConvLayout1D/2D/3D enum classes to the CK tensor layout types.
|
||||
template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
// This will trigger if a specialization for the given layout is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
using Layout = decltype(LayoutValue);
|
||||
static_assert(sizeof(Layout) == 0,
|
||||
"Internal error. Unsupported layout for convolution factory.");
|
||||
};
|
||||
|
||||
// 1D Forward Convolution Layout Specializations
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKXC_NGKW, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKCX_NGKW, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKYXC_NGKHW, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NHWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNHWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNHWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKCYX_NGKHW, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCYX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCDHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCZYX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKDHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNDHWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNDHWK;
|
||||
};
|
||||
|
||||
// Type mappings from builder convolution data type to CK tensor types.
|
||||
template <DataType T>
|
||||
struct ConvTensorTypes
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported data type for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP16>
|
||||
{
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::BF16>
|
||||
{
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::bhalf_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP32>
|
||||
{
|
||||
using ADataType = float;
|
||||
using BDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = float;
|
||||
};
|
||||
|
||||
template <ElementwiseOperation T>
|
||||
struct ElementwiseOps
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported elementwise operation for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
|
||||
{
|
||||
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
// The algorithm specializations for the convolution and GEMM.
|
||||
template <typename CONV_ENUM>
|
||||
requires(
|
||||
std::is_same_v<CONV_ENUM, ck::tensor_operation::device::ConvolutionForwardSpecialization>)
|
||||
struct ConvSpec
|
||||
{
|
||||
CONV_ENUM conv_spec;
|
||||
ck::tensor_operation::device::GemmSpecialization gemm_spec;
|
||||
};
|
||||
|
||||
// Deduction guide for ConvSpec to simplify brace initialization.
|
||||
template <typename CONV_ENUM, typename GEMM_ENUM>
|
||||
ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec<CONV_ENUM>;
|
||||
|
||||
// Block info for a convolution.
|
||||
// Triple of M/N/K extents; every member defaults to zero.
struct MNK
{
    size_t m = 0;
    size_t n = 0;
    size_t k = 0;
};
|
||||
struct ConvBlock
|
||||
{
|
||||
size_t block_size = 0;
|
||||
MNK per_block = {};
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr ConvBlock SetThreadBlockInfo()
|
||||
{
|
||||
constexpr auto& TB = ALGORITHM.thread_block;
|
||||
return ConvBlock{.block_size = TB.block_size,
|
||||
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}};
|
||||
}
|
||||
|
||||
// Convolution tuning parameters.
|
||||
// Gridwise GEMM tuning values; every member defaults to zero.
struct GridwiseGemm
{
    size_t ak1{};
    size_t bk1{};
    size_t m_per_xdl{};
    size_t n_per_xdl{};
    size_t m_xdl_per_wave{};
    size_t n_xdl_per_wave{};
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr GridwiseGemm SetGridwiseGemmInfo()
|
||||
{
|
||||
constexpr auto& TP = ALGORITHM.gridwise_gemm;
|
||||
return GridwiseGemm{
|
||||
.ak1 = TP.ak1,
|
||||
.bk1 = TP.bk1,
|
||||
.m_per_xdl = TP.m_per_xdl,
|
||||
.n_per_xdl = TP.n_per_xdl,
|
||||
.m_xdl_per_wave = TP.m_xdl_per_wave,
|
||||
.n_xdl_per_wave = TP.n_xdl_per_wave,
|
||||
};
|
||||
}
|
||||
|
||||
// Block transfer parameters for A or B tensor.
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
|
||||
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool is_direct_load = false;
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetFwdConvABlockTransfer()
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_a;
|
||||
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_a;
|
||||
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_a;
|
||||
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_a;
|
||||
|
||||
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
|
||||
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
|
||||
.src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
|
||||
.src_vector_dim = LDS.src_vector_dim,
|
||||
.src_scalar_per_vector = LDS.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = LDS.is_direct_load,
|
||||
.lds_padding = LDS.lds_padding};
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetFwdConvBBlockTransfer()
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_b;
|
||||
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_b;
|
||||
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_b;
|
||||
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_b;
|
||||
|
||||
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
|
||||
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
|
||||
.src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
|
||||
.src_vector_dim = LDS.src_vector_dim,
|
||||
.src_scalar_per_vector = LDS.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = LDS.is_direct_load,
|
||||
.lds_padding = LDS.lds_padding};
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
// Block transfer parameters for C tensor.
|
||||
struct CBlockTransfer
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle = 0;
|
||||
size_t n_xdl_per_wave_per_shuffle = 0;
|
||||
ck::Array<size_t, 4> thread_cluster_dims = {0, 0, 0, 0};
|
||||
size_t scalar_per_vector = 0;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr CBlockTransfer SetCBlockTransfer()
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c;
|
||||
constexpr auto& EPC = ALGORITHM.block_transfer.epilogue_c;
|
||||
CBlockTransfer block_transfer{.m_xdl_per_wave_per_shuffle = EPC.m_xdl_per_wave_per_shuffle,
|
||||
.n_xdl_per_wave_per_shuffle = EPC.n_xdl_per_wave_per_shuffle,
|
||||
.thread_cluster_dims =
|
||||
{
|
||||
TCL.m_block,
|
||||
TCL.m_wave_per_xdl,
|
||||
TCL.n_block,
|
||||
TCL.n_wave_per_xdl,
|
||||
},
|
||||
.scalar_per_vector = EPC.scalar_per_vector};
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
|
||||
if constexpr(version == BlockGemmPipelineVersion::V1)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V4)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V5)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineVersion");
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.fwd_specialization;
|
||||
|
||||
if constexpr(specialization == ConvFwdSpecialization::DEFAULT)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown ConvFwdSpecialization");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory_internal
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Primary template for the convolution factory.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
auto VERSION>
|
||||
struct ConvFactory;
|
||||
|
||||
// Factory specialization for an instance of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts =
|
||||
factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
// Check preconditions for the algorithm description.
|
||||
static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
|
||||
"Only 2D and 3D convolutions are supported in this factory.");
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesGemmPipelineVersion<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block gemm pipeline version.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr factory_internal::ConvSpec SPECIALIZATION{
|
||||
.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
};
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM =
|
||||
factory_internal::SetGridwiseGemmInfo<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION =
|
||||
factory_internal::SetBlockGemmPipelineVersion<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
PIPELINE_SCHEDULER,
|
||||
PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This file defines the compile-time "signature" for grouped convolution operations.
|
||||
// A signature is a collection of properties that fully describe a convolution kernel's
|
||||
// mathematical characteristics. It uses C++20 concepts and enums to specify these
|
||||
// properties, enabling compile-time validation and specialization.
|
||||
//
|
||||
// The core components of a signature are:
|
||||
// - Spatial dimensionality (1D, 2D, 3D)
|
||||
// - Operational direction (Forward, Backward Data, Backward Weight)
|
||||
// - Tensor memory layout (Channels First/Last)
|
||||
// - Data type (FP32, FP16, BF16, FP8, I8, U8)
|
||||
// - Fused element-wise operation (e.g., Bias, Clamp)
|
||||
//
|
||||
// The file also provides predicate concepts to query the properties of a given
|
||||
// signature at compile time.
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Constrains convolution to 1D, 2D, or 3D spatial dimensions.
|
||||
template <auto N>
|
||||
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);
|
||||
|
||||
// Constrains the layout enum type to match the spatial dimension (1D, 2D, or 3D).
|
||||
template <auto LayoutValue, size_t SpatialDim>
|
||||
concept ValidConvLayoutForSpatialDim =
|
||||
(SpatialDim == 1 && std::same_as<decltype(LayoutValue), GroupConvLayout1D>) ||
|
||||
(SpatialDim == 2 && std::same_as<decltype(LayoutValue), GroupConvLayout2D>) ||
|
||||
(SpatialDim == 3 && std::same_as<decltype(LayoutValue), GroupConvLayout3D>);
|
||||
|
||||
// Constrains convolution data types to the accepted floating-point and 8-bit integer types.
|
||||
template <DataType T>
|
||||
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
|
||||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
|
||||
|
||||
// Concept for a type that defines a convolution's operational signature.
|
||||
template <typename T>
|
||||
concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
|
||||
{ t.direction } -> std::convertible_to<ConvDirection>;
|
||||
requires std::convertible_to<decltype(t.layout), GroupConvLayout1D> ||
|
||||
std::convertible_to<decltype(t.layout), GroupConvLayout2D> ||
|
||||
std::convertible_to<decltype(t.layout), GroupConvLayout3D>;
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
|
||||
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
template <auto Sig>
|
||||
concept ValidConvSignature = requires {
|
||||
requires ConvSpatialDim<Sig.spatial_dim>;
|
||||
requires ConvDataType<Sig.data_type>;
|
||||
};
|
||||
|
||||
// Predicate for forward convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
// Predicate for backward data convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
// Predicate for backward weight convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -16,25 +16,19 @@
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck_tile/ops/common/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
|
||||
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
|
||||
template <typename Seq>
|
||||
struct SequenceToArray;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
struct SequenceToArray<ck::Sequence<Is...>>
|
||||
{
|
||||
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
|
||||
};
|
||||
|
||||
// Convert data types to string names
|
||||
// Implementation detail for type name mapping
|
||||
// This is the single source of truth for supported data types that
|
||||
// returns an empty string to indicate an unsupported type.
|
||||
namespace impl {
|
||||
template <typename T>
|
||||
consteval std::string_view type_name()
|
||||
consteval std::string_view type_name_impl()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck::half_t>)
|
||||
return "fp16";
|
||||
@@ -55,20 +49,38 @@ consteval std::string_view type_name()
|
||||
else if constexpr(std::is_same_v<T, ck::bf8_t>)
|
||||
return "bf8";
|
||||
else
|
||||
static_assert(false, "unknown_type");
|
||||
return std::string_view{}; // Return empty for unsupported types
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
// Convert data types to string names
|
||||
// Fails at compile time for unsupported types
|
||||
template <typename T>
|
||||
consteval std::string_view type_name()
|
||||
{
|
||||
constexpr auto name = impl::type_name_impl<T>();
|
||||
static_assert(!name.empty(), "Unsupported data type");
|
||||
return name;
|
||||
}
|
||||
|
||||
// Convert layout types to string names
|
||||
// Concept that checks if a type is a valid data type
|
||||
// Uses the impl directly to avoid triggering static_assert during concept evaluation
|
||||
template <typename T>
|
||||
concept IsDataType = !impl::type_name_impl<T>().empty();
|
||||
|
||||
// Concept that checks valid layout types
|
||||
template <typename T>
|
||||
concept IsLayoutType = (std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
|
||||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
|
||||
requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
};
|
||||
|
||||
// Convert layout types to string names
|
||||
template <IsLayoutType T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false,
|
||||
"Layout type must derive from BaseTensorLayout and have name attribute");
|
||||
return T::name;
|
||||
}
|
||||
|
||||
// Convert element-wise operation types to string names
|
||||
@@ -87,64 +99,64 @@ constexpr std::string_view elementwise_op_name()
|
||||
constexpr std::string_view
|
||||
conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case ConvolutionForwardSpecialization::Default: return "Default";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
|
||||
case ConvolutionForwardSpecialization::OddC: return "OddC";
|
||||
case Default: return "Default";
|
||||
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case Filter3x3: return "Filter3x3";
|
||||
case OddC: return "OddC";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert GemmSpecialization enum to string
|
||||
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case GemmSpecialization::Default: return "Default";
|
||||
case GemmSpecialization::MPadding: return "MPadding";
|
||||
case GemmSpecialization::NPadding: return "NPadding";
|
||||
case GemmSpecialization::KPadding: return "KPadding";
|
||||
case GemmSpecialization::MNPadding: return "MNPadding";
|
||||
case GemmSpecialization::MKPadding: return "MKPadding";
|
||||
case GemmSpecialization::NKPadding: return "NKPadding";
|
||||
case GemmSpecialization::MNKPadding: return "MNKPadding";
|
||||
case GemmSpecialization::OPadding: return "OPadding";
|
||||
case GemmSpecialization::MOPadding: return "MOPadding";
|
||||
case GemmSpecialization::NOPadding: return "NOPadding";
|
||||
case GemmSpecialization::KOPadding: return "KOPadding";
|
||||
case GemmSpecialization::MNOPadding: return "MNOPadding";
|
||||
case GemmSpecialization::MKOPadding: return "MKOPadding";
|
||||
case GemmSpecialization::NKOPadding: return "NKOPadding";
|
||||
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
|
||||
case Default: return "Default";
|
||||
case MPadding: return "MPadding";
|
||||
case NPadding: return "NPadding";
|
||||
case KPadding: return "KPadding";
|
||||
case MNPadding: return "MNPadding";
|
||||
case MKPadding: return "MKPadding";
|
||||
case NKPadding: return "NKPadding";
|
||||
case MNKPadding: return "MNKPadding";
|
||||
case OPadding: return "OPadding";
|
||||
case MOPadding: return "MOPadding";
|
||||
case NOPadding: return "NOPadding";
|
||||
case KOPadding: return "KOPadding";
|
||||
case MNOPadding: return "MNOPadding";
|
||||
case MKOPadding: return "MKOPadding";
|
||||
case NKOPadding: return "NKOPadding";
|
||||
case MNKOPadding: return "MNKOPadding";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineScheduler enum to string
|
||||
constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched)
|
||||
{
|
||||
using ck::BlockGemmPipelineScheduler;
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case BlockGemmPipelineScheduler::Intrawave: return "Intrawave";
|
||||
case BlockGemmPipelineScheduler::Interwave: return "Interwave";
|
||||
case Intrawave: return "Intrawave";
|
||||
case Interwave: return "Interwave";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineVersion enum to string
|
||||
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
|
||||
{
|
||||
using ck::BlockGemmPipelineVersion;
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case BlockGemmPipelineVersion::v1: return "v1";
|
||||
case BlockGemmPipelineVersion::v2: return "v2";
|
||||
case BlockGemmPipelineVersion::v3: return "v3";
|
||||
case BlockGemmPipelineVersion::v4: return "v4";
|
||||
case BlockGemmPipelineVersion::v5: return "v5";
|
||||
case v1: return "v1";
|
||||
case v2: return "v2";
|
||||
case v3: return "v3";
|
||||
case v4: return "v4";
|
||||
case v5: return "v5";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,12 +176,138 @@ inline std::string array_to_string(const std::array<T, N>& arr)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Handle ck::Tuple (empty tuple for DsLayout/DsDataType)
|
||||
template <typename T>
|
||||
constexpr std::string_view tuple_name()
|
||||
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
|
||||
template <typename Seq>
|
||||
struct SequenceToArray;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
struct SequenceToArray<ck::Sequence<Is...>>
|
||||
{
|
||||
// For now, just check if it's an empty tuple
|
||||
return "EmptyTuple";
|
||||
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// Generic helper to build list-like strings (Tuple, Seq, etc.)
|
||||
//
|
||||
// Example output: "Seq(1,2,3)"
|
||||
//
|
||||
// prefix: The list-like container name (e.g. "Tuple" or "Seq")
|
||||
// converter_fn: A callable that converts each element to a string representation
|
||||
// For types: converter_fn should be a template lambda like []<typename U>() { return
|
||||
// type_name<U>(); } For values: converter_fn should be a regular lambda like [](auto value) {
|
||||
// return std::to_string(value); }
|
||||
template <typename ConverterFn, typename... Elements>
|
||||
constexpr std::string build_list_string(std::string_view prefix, const ConverterFn& converter_fn)
|
||||
{
|
||||
if constexpr(sizeof...(Elements) == 0)
|
||||
{
|
||||
return std::string(prefix) + "()";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string result = std::string(prefix) + "(";
|
||||
std::size_t index = 0;
|
||||
((result +=
|
||||
(index++ > 0 ? "," : "") + std::string(converter_fn.template operator()<Elements>())),
|
||||
...);
|
||||
result += ")";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Overload for value-based lists (sequences)
|
||||
template <typename ConverterFn, auto... Values>
|
||||
constexpr std::string build_list_string_values(std::string_view prefix,
|
||||
const ConverterFn& converter_fn)
|
||||
{
|
||||
if constexpr(sizeof...(Values) == 0)
|
||||
{
|
||||
return std::string(prefix) + "()";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string result = std::string(prefix) + "(";
|
||||
std::size_t index = 0;
|
||||
((result += (index++ > 0 ? "," : "") + converter_fn(Values)), ...);
|
||||
result += ")";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Convert ck::Sequence to string representation
|
||||
// Converts a ck::Sequence type to a string in the format "Seq(v1,v2,...,vn)"
|
||||
// where each value is converted using std::to_string.
|
||||
//
|
||||
// Template parameter:
|
||||
// T: Must be a ck::Sequence<...> type
|
||||
//
|
||||
// Constraints:
|
||||
// - Sequence elements must support std::to_string (typically ck::index_t)
|
||||
//
|
||||
// Examples:
|
||||
// sequence_name<ck::Sequence<>>() returns "Seq()"
|
||||
// sequence_name<ck::Sequence<42>>() returns "Seq(42)"
|
||||
// sequence_name<ck::Sequence<1,2,3>>() returns "Seq(1,2,3)"
|
||||
// sequence_name<ck::Sequence<256,128,64>>() returns "Seq(256,128,64)"
|
||||
template <typename T>
|
||||
requires requires { []<ck::index_t... Is>(ck::Sequence<Is...>*) {}(static_cast<T*>(nullptr)); }
|
||||
constexpr std::string sequence_name()
|
||||
{
|
||||
return []<ck::index_t... Is>(ck::Sequence<Is...>*) constexpr {
|
||||
auto to_string_fn = [](auto value) { return std::to_string(value); };
|
||||
return detail::build_list_string_values<decltype(to_string_fn), Is...>("Seq", to_string_fn);
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
// Convert ck::Tuple to string representation
|
||||
// Converts a ck::Tuple type to a string in the format "Tuple(e1,e2,...,en)"
|
||||
// where each element is converted based on its type (layout names or data type names).
|
||||
//
|
||||
// Template parameter:
|
||||
// T: Must be a ck::Tuple<...> type
|
||||
//
|
||||
// Constraints:
|
||||
// - Empty tuples are supported and return "EmptyTuple"
|
||||
// - All tuple elements must be homogeneous: either all layouts (IsLayoutType) or all data types
|
||||
// (IsDataType)
|
||||
// - Mixed layouts and data types in the same tuple will cause a compile-time error
|
||||
//
|
||||
// Examples:
|
||||
// tuple_name<ck::Tuple<>>() returns "EmptyTuple"
|
||||
// tuple_name<ck::Tuple<ck::tensor_layout::gemm::RowMajor>>() returns "Tuple(RowMajor)"
|
||||
// tuple_name<ck::Tuple<NCHW,NHWC>>() returns "Tuple(NCHW,NHWC)"
|
||||
// tuple_name<ck::Tuple<ck::half_t>>() returns "Tuple(fp16)"
|
||||
// tuple_name<ck::Tuple<ck::half_t,float,double>>() returns "Tuple(fp16,fp32,fp64)"
|
||||
template <typename T>
|
||||
requires requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
|
||||
constexpr std::string tuple_name()
|
||||
{
|
||||
return []<typename... Ts>(ck::Tuple<Ts...>*) constexpr {
|
||||
if constexpr(sizeof...(Ts) == 0)
|
||||
{
|
||||
return std::string("EmptyTuple");
|
||||
}
|
||||
else if constexpr((IsLayoutType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for layout_name
|
||||
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
|
||||
return detail::build_list_string<decltype(layout_name_fn), Ts...>("Tuple",
|
||||
layout_name_fn);
|
||||
}
|
||||
else if constexpr((IsDataType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for type_name
|
||||
auto type_name_fn = []<typename U>() { return type_name<U>(); };
|
||||
return detail::build_list_string<decltype(type_name_fn), Ts...>("Tuple", type_name_fn);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
|
||||
"Tuple elements must be all layouts or all data types, not mixed");
|
||||
return std::string{}; // unreachable
|
||||
}
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::detail
|
||||
|
||||
90
experimental/builder/include/ck_tile/builder/types.hpp
Normal file
90
experimental/builder/include/ck_tile/builder/types.hpp
Normal file
@@ -0,0 +1,90 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
{
|
||||
FP32,
|
||||
FP16,
|
||||
BF16,
|
||||
FP8,
|
||||
I8,
|
||||
U8
|
||||
};
|
||||
|
||||
// Memory layouts for 1D convolution tensors.
|
||||
// G: Group, N: Batch, K: Output Channel, C: Input Channel, W: Width
|
||||
// Enum defines Input, Weight, and Output tensor layouts respectively.
|
||||
enum class GroupConvLayout1D
|
||||
{
|
||||
GNWC_GKXC_GNWK,
|
||||
NWGC_GKXC_NWGK,
|
||||
NGCW_GKXC_NGKW,
|
||||
NGCW_GKCX_NGKW
|
||||
};
|
||||
|
||||
// Memory layouts for 2D convolution tensors.
|
||||
// G: Group, N: Batch, K: Output Channel, C: Input Channel, Y: Height, X: Width, H: Height, W: Width
|
||||
// Enum defines Input, Weight, and Output tensor layouts respectively.
|
||||
enum class GroupConvLayout2D
|
||||
{
|
||||
GNHWC_GKYXC_GNHWK,
|
||||
NHWGC_GKYXC_NHWGK,
|
||||
NGCHW_GKYXC_NGKHW,
|
||||
NGCHW_GKCYX_NGKHW
|
||||
};
|
||||
|
||||
// Memory layouts for 3D convolution tensors.
|
||||
// G: Group, N: Batch, K: Output Channel, C: Input Channel, Z: Depth, Y: Height, X: Width, D: Depth,
|
||||
// H: Height, W: Width. Enum defines Input, Weight, and Output tensor layouts respectively.
|
||||
enum class GroupConvLayout3D
|
||||
{
|
||||
GNDHWC_GKZYXC_GNDHWK,
|
||||
NDHWGC_GKZYXC_NDHWGK,
|
||||
NGCDHW_GKZYXC_NGKDHW,
|
||||
NGCDHW_GKCZYX_NGKDHW,
|
||||
};
|
||||
|
||||
// Direction of the convolution operation.
|
||||
enum class ConvDirection
|
||||
{
|
||||
FORWARD,
|
||||
BACKWARD_DATA,
|
||||
BACKWARD_WEIGHT
|
||||
};
|
||||
|
||||
// Fused element-wise operations.
|
||||
enum class ElementwiseOperation
|
||||
{
|
||||
BIAS,
|
||||
BIAS_CLAMP,
|
||||
BIAS_BNORM_CLAMP,
|
||||
BILINEAR,
|
||||
CLAMP,
|
||||
SCALE,
|
||||
PASS_THROUGH
|
||||
};
|
||||
|
||||
// Enums for the current block GEMM pipeline versions.
|
||||
enum class BlockGemmPipelineVersion
|
||||
{
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
V4,
|
||||
V5
|
||||
};
|
||||
|
||||
// Enums for the forward convolution specialization.
|
||||
enum class ConvFwdSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_PAD0,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
FILTER_3x3
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
18
experimental/builder/include/ck_tile/builder/versions.hpp
Normal file
18
experimental/builder/include/ck_tile/builder/versions.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
static constexpr StringLiteral V0_0_0 = "0.0.0";
|
||||
static constexpr StringLiteral V0_1_0 = "0.1.0";
|
||||
|
||||
static constexpr StringLiteral LATEST_API_VERSION = V0_1_0;
|
||||
|
||||
template <StringLiteral V>
|
||||
concept SupportedVersion = (V == V0_0_0) || (V == V0_1_0);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -2,11 +2,12 @@ include(gtest)
|
||||
|
||||
# Helper function to create a gtest executable with common properties
|
||||
function(add_ck_builder_test test_name)
|
||||
add_executable(${test_name} ${ARGN})
|
||||
add_executable(${test_name} ${ARGN} testing_utils.cpp)
|
||||
target_compile_features(${test_name} PRIVATE cxx_std_20)
|
||||
target_include_directories(${test_name} PRIVATE
|
||||
"${PROJECT_SOURCE_DIR}/experimental/builder/include"
|
||||
"${PROJECT_SOURCE_DIR}/include"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
)
|
||||
target_compile_options(${test_name} PRIVATE
|
||||
-Wno-global-constructors
|
||||
@@ -15,12 +16,32 @@ function(add_ck_builder_test test_name)
|
||||
target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock)
|
||||
endfunction()
|
||||
|
||||
# The test_conv_builder target has all the unit tests (each test should run < 10 ms)
|
||||
add_ck_builder_test(test_conv_builder
|
||||
test_conv_builder.cpp
|
||||
test_instance_traits.cpp
|
||||
testing_utils.cpp)
|
||||
test_instance_traits_util.cpp)
|
||||
|
||||
# Testing the virtual GetInstanceString methods requires kernel compilation.
|
||||
add_ck_builder_test(test_get_instance_string
|
||||
test_get_instance_string.cpp)
|
||||
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp)
|
||||
|
||||
function(add_ck_factory_test test_name)
|
||||
add_ck_builder_test(${test_name} ${ARGN})
|
||||
target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations)
|
||||
endfunction()
|
||||
|
||||
add_ck_factory_test(test_testing_utils test_testing_utils.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward test_ck_factory_grouped_convolution_forward.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_clamp test_ck_factory_grouped_convolution_forward_clamp.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_convscale test_ck_factory_grouped_convolution_forward_convscale.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_bilinear test_ck_factory_grouped_convolution_forward_bilinear.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_scale test_ck_factory_grouped_convolution_forward_scale.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_scaleadd_ab test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_bias_clamp test_ck_factory_grouped_convolution_forward_bias_clamp.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_bias_bnorm_clamp test_ck_factory_grouped_convolution_forward_bias_bnorm_clamp.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
|
||||
add_ck_factory_test(test_ck_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
|
||||
|
||||
|
||||
47
experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp
Normal file
47
experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp
Normal file
@@ -0,0 +1,47 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V5,
|
||||
ConvFwdSpecialization::FILTER_3x3>();
|
||||
}
|
||||
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp
Normal file
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(FwdConv2DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp
Normal file
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(FwdConv2DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
|
||||
TEST_F(FwdConv3DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
119
experimental/builder/test/impl/conv_algorithm_types.hpp
Normal file
119
experimental/builder/test/impl/conv_algorithm_types.hpp
Normal file
@@ -0,0 +1,119 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
template <typename T>
|
||||
struct MNK
|
||||
{
|
||||
T m{};
|
||||
T n{};
|
||||
T k{};
|
||||
};
|
||||
|
||||
// Specify thread block dimensions for a GEMM.
|
||||
struct ThreadBlock
|
||||
{
|
||||
// Thread block size.
|
||||
size_t block_size;
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<size_t> tile_size;
|
||||
};
|
||||
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Describe gridwise GEMM parameters.
|
||||
struct GridwiseGemm
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
size_t m_per_xdl = 0;
|
||||
size_t n_per_xdl = 0;
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ckb::GridwiseGemmDescriptor<GridwiseGemm>);
|
||||
|
||||
// Describe A and B block transfer thread cluster lengths.
|
||||
struct BlockTransfer
|
||||
{
|
||||
size_t k0;
|
||||
size_t m_n;
|
||||
size_t k1;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct ThreadCluster
|
||||
{
|
||||
size_t m_block;
|
||||
size_t m_wave_per_xdl;
|
||||
size_t n_block;
|
||||
size_t n_wave_per_xdl;
|
||||
};
|
||||
static_assert(ThreadClusterDescriptor<ThreadCluster>);
|
||||
|
||||
struct LdsTransfer
|
||||
{
|
||||
size_t src_vector_dim;
|
||||
size_t src_scalar_per_vector;
|
||||
size_t lds_dst_scalar_per_vector;
|
||||
bool is_direct_load;
|
||||
bool lds_padding;
|
||||
};
|
||||
static_assert(LdsTransferDescriptor<LdsTransfer>);
|
||||
|
||||
struct Epilogue
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle;
|
||||
size_t n_xdl_per_wave_per_shuffle;
|
||||
size_t scalar_per_vector;
|
||||
};
|
||||
static_assert(EpilogueDescriptor<Epilogue>);
|
||||
|
||||
struct AccessOrder
|
||||
{
|
||||
std::array<size_t, 3> order;
|
||||
};
|
||||
static_assert(AccessOrderDescriptor<AccessOrder>);
|
||||
|
||||
struct BlockTransferABC
|
||||
{
|
||||
BlockTransfer block_transfer_a;
|
||||
BlockTransfer block_transfer_b;
|
||||
ThreadCluster thread_cluster_dims_c;
|
||||
LdsTransfer lds_transfer_a;
|
||||
LdsTransfer lds_transfer_b;
|
||||
Epilogue epilogue_c;
|
||||
AccessOrder block_transfer_access_order_a;
|
||||
AccessOrder block_transfer_access_order_b;
|
||||
AccessOrder src_access_order_a;
|
||||
AccessOrder src_access_order_b;
|
||||
};
|
||||
|
||||
struct ConvAlgorithm
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
BlockGemmPipelineVersion pipeline_version;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGridwiseGemm<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesSourceAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
23
experimental/builder/test/impl/conv_signature_types.hpp
Normal file
23
experimental/builder/test/impl/conv_signature_types.hpp
Normal file
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
template <typename GroupConvLayout>
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim;
|
||||
ConvDirection direction;
|
||||
GroupConvLayout layout;
|
||||
DataType data_type;
|
||||
ElementwiseOperation elementwise_operation;
|
||||
};
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout1D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout2D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout3D>>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,118 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp>
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
using ck_tile::test::InstanceSet;
|
||||
using ck_tile::test::InstancesMatch;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr static auto NumDimSpatial = 3;
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
using DsLayout = ck::Tuple<ck::tensor_layout::convolution::NDHWGK>;
|
||||
|
||||
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
|
||||
using ck::tensor_operation::element_wise::Bilinear;
|
||||
using ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <typename type, typename computeType = type>
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
type, // InDataType
|
||||
type, // WeiDataType
|
||||
ck::Tuple<type>,
|
||||
type, // OutDataType
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
computeType,
|
||||
computeType>;
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Case>
|
||||
struct CkFactoryTestBilinearFwd : public testing::Test
|
||||
{
|
||||
static auto get_actual_instances()
|
||||
{
|
||||
return InstanceSet::from_factory<typename Case::DeviceOp>();
|
||||
}
|
||||
|
||||
static auto get_expected_instances() { return InstanceSet(Case::expected); }
|
||||
};
|
||||
|
||||
struct Bilinear_F32
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<float>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct Bilinear_F32_TF32
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<float, ck::tf32_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct Bilinear_F16
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::half_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct Bilinear_BF16
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::bhalf_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct Bilinear_INT8
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<int8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
using TestTypes =
|
||||
::testing::Types<Bilinear_F32, Bilinear_F32_TF32, Bilinear_F16, Bilinear_BF16, Bilinear_INT8>;
|
||||
|
||||
TYPED_TEST_SUITE(CkFactoryTestBilinearFwd, TestTypes);
|
||||
|
||||
// For each case type, the instance set reported by the factory must match the expected set.
TYPED_TEST(CkFactoryTestBilinearFwd, TestInstances)
{
    auto from_factory = TestFixture::get_actual_instances();
    auto reference    = TestFixture::get_expected_instances();

    EXPECT_THAT(from_factory, InstancesMatch(reference));
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,246 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp>
|
||||
#include <ck/library/tensor_operation_instance/device_operation_instance_factory.hpp>
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
using ck_tile::test::InstanceSet;
|
||||
using ck_tile::test::InstancesMatch;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr static auto NumDimSpatial = 3;
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
|
||||
using ck::tensor_operation::device::instance::CombConvScale;
|
||||
using ck::tensor_operation::device::instance::CombConvScaleRelu;
|
||||
using ck::tensor_operation::element_wise::ConvInvscale;
|
||||
using ck::tensor_operation::element_wise::ConvScale;
|
||||
using ck::tensor_operation::element_wise::ConvScaleAdd;
|
||||
using ck::tensor_operation::element_wise::ConvScaleRelu;
|
||||
using ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename Act,
|
||||
typename AComputeType,
|
||||
typename BComputeType>
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
DsDataType,
|
||||
OutDataType, // OutDataType
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Act,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Case>
|
||||
struct CkFactoryTestConvFwd : public testing::Test
|
||||
{
|
||||
static auto get_actual_instances()
|
||||
{
|
||||
return InstanceSet::from_factory<typename Case::DeviceOp>();
|
||||
}
|
||||
|
||||
static auto get_expected_instances() { return InstanceSet(Case::expected); }
|
||||
};
|
||||
|
||||
struct F8_ConvScale
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
ConvScale,
|
||||
ck::f8_t,
|
||||
ck::f8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F8_BF8_comb1_ConvScale
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
ck::bf8_t,
|
||||
ck::bf8_t,
|
||||
ck::f8_t,
|
||||
ConvScale,
|
||||
ck::bf8_t,
|
||||
ck::bf8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F8_BF8_comb2_ConvScale
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
ck::f8_t,
|
||||
ck::bf8_t,
|
||||
ck::f8_t,
|
||||
ConvScale,
|
||||
ck::f8_t,
|
||||
ck::bf8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F8_BF8_comb3_ConvScale
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
ck::bf8_t,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
ConvScale,
|
||||
ck::bf8_t,
|
||||
ck::f8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F8_float_CombConvScale
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
float,
|
||||
CombConvScale,
|
||||
ck::f8_t,
|
||||
ck::f8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F8_ConvScaleRelu
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
ConvScaleRelu,
|
||||
ck::f8_t,
|
||||
ck::f8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F8_CombConvScaleRelu
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
float,
|
||||
CombConvScaleRelu,
|
||||
ck::f8_t,
|
||||
ck::f8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F8_ConvScaleAdd
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::Tuple<OutLayout>,
|
||||
ck::Tuple<float>,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
ConvScaleAdd,
|
||||
ck::f8_t,
|
||||
ck::f8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F8_ConvInvscale
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
ck::f8_t,
|
||||
ConvInvscale,
|
||||
ck::f8_t,
|
||||
ck::f8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
using TestTypes = ::testing::Types<F8_ConvScale,
|
||||
F8_BF8_comb1_ConvScale,
|
||||
F8_BF8_comb2_ConvScale,
|
||||
F8_BF8_comb3_ConvScale,
|
||||
F8_float_CombConvScale,
|
||||
F8_ConvScaleRelu,
|
||||
F8_CombConvScaleRelu,
|
||||
F8_ConvScaleAdd,
|
||||
F8_ConvInvscale>;
|
||||
|
||||
TYPED_TEST_SUITE(CkFactoryTestConvFwd, TestTypes);
|
||||
|
||||
// For each case type, the instance set reported by the factory must match the expected set.
TYPED_TEST(CkFactoryTestConvFwd, TestInstances)
{
    auto from_factory = TestFixture::get_actual_instances();
    auto reference    = TestFixture::get_expected_instances();

    EXPECT_THAT(from_factory, InstancesMatch(reference));
}
|
||||
@@ -0,0 +1,187 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp>
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
using ck_tile::test::InstanceSet;
|
||||
using ck_tile::test::InstancesMatch;
|
||||
|
||||
namespace {
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
|
||||
using ck::tensor_operation::element_wise::DynamicUnaryOp;
|
||||
using ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <ck::index_t NumDimSpatial, typename T>
|
||||
struct DeviceOpHelper;
|
||||
|
||||
template <typename T>
|
||||
struct DeviceOpHelper<2, T>
|
||||
{
|
||||
using InLayout = ck::tensor_layout::convolution::NHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NHWGK;
|
||||
|
||||
using Type = DeviceGroupedConvFwdMultipleABD<2,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>, // DsLayout
|
||||
OutLayout,
|
||||
T, // InDataType
|
||||
T, // WeiDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
T, // OutDataType
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DynamicUnaryOp>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DeviceOpHelper<3, T>
|
||||
{
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using Type = DeviceGroupedConvFwdMultipleABD<3,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>, // DsLayout
|
||||
OutLayout,
|
||||
T, // InDataType
|
||||
T, // WeiDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
T, // OutDataType
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DynamicUnaryOp>;
|
||||
};
|
||||
|
||||
template <ck::index_t NumDimSpatial, typename T>
|
||||
using DeviceOp = DeviceOpHelper<NumDimSpatial, T>::Type;
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Case>
|
||||
struct CkFactoryTestBilinearFwd : public testing::Test
|
||||
{
|
||||
static auto get_actual_instances()
|
||||
{
|
||||
return InstanceSet::from_factory<typename Case::DeviceOp>();
|
||||
}
|
||||
|
||||
static auto get_expected_instances() { return InstanceSet(Case::expected); }
|
||||
};
|
||||
|
||||
struct DyOp_F32_2
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<2, float>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct DyOp_F32_3
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<3, float>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct DyOp_F16_2
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<2, ck::half_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct DyOp_F16_3
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<3, ck::half_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct DyOp_BF16_2
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<2, ck::bhalf_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct DyOp_BF16_3
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<3, ck::bhalf_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct DyOp_INT8_2
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<2, int8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct DyOp_INT8_3
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<3, int8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
using TestTypes = ::testing::Types<DyOp_F32_2,
|
||||
DyOp_F32_3,
|
||||
DyOp_F16_2,
|
||||
DyOp_F16_3,
|
||||
DyOp_BF16_2,
|
||||
DyOp_BF16_3,
|
||||
DyOp_INT8_2,
|
||||
DyOp_INT8_3>;
|
||||
|
||||
TYPED_TEST_SUITE(CkFactoryTestBilinearFwd, TestTypes);
|
||||
|
||||
TYPED_TEST(CkFactoryTestBilinearFwd, TestInstances)
|
||||
{
|
||||
auto actual = TestFixture::get_actual_instances();
|
||||
auto expected = TestFixture::get_expected_instances();
|
||||
|
||||
EXPECT_THAT(actual, InstancesMatch(expected));
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp>
|
||||
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
using ck_tile::test::InstanceSet;
|
||||
using ck_tile::test::InstancesMatch;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr static auto NumDimSpatial = 3;
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
|
||||
using ck::tensor_operation::element_wise::PassThrough;
|
||||
using ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
template <typename T, typename ComputeType = T>
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>, // DsLayout
|
||||
OutLayout,
|
||||
T, // InDataType
|
||||
T, // WeiDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
T, // OutDataType
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale,
|
||||
ComputeType>;
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Case>
|
||||
struct CkFactoryTestConvFwd : public testing::Test
|
||||
{
|
||||
static auto get_actual_instances()
|
||||
{
|
||||
return InstanceSet::from_factory<typename Case::DeviceOp>();
|
||||
}
|
||||
|
||||
static auto get_expected_instances() { return InstanceSet(Case::expected); }
|
||||
};
|
||||
|
||||
struct F32
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<float>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F32_TF32
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<float, ck::tf32_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F16
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::half_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct BF16
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::bhalf_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct S8
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<int8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
using TestTypes = ::testing::Types<F32, F32_TF32, F16, BF16, S8>;
|
||||
|
||||
TYPED_TEST_SUITE(CkFactoryTestConvFwd, TestTypes);
|
||||
|
||||
// For each case type, the instance set reported by the factory must match the expected set.
TYPED_TEST(CkFactoryTestConvFwd, TestInstances)
{
    auto from_factory = TestFixture::get_actual_instances();
    auto reference    = TestFixture::get_expected_instances();

    EXPECT_THAT(from_factory, InstancesMatch(reference));
}
|
||||
@@ -0,0 +1,104 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp>
|
||||
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
using ck_tile::test::InstanceSet;
|
||||
using ck_tile::test::InstancesMatch;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr static auto NumDimSpatial = 3;
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
|
||||
using ck::tensor_operation::element_wise::PassThrough;
|
||||
using ck::tensor_operation::element_wise::ScaleAdd;
|
||||
|
||||
template <typename T>
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>, // DsLayout
|
||||
OutLayout,
|
||||
ck::Tuple<T, T>, // InDataType
|
||||
ck::Tuple<T, T>, // WeiDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
T, // OutDataType
|
||||
ScaleAdd,
|
||||
ScaleAdd,
|
||||
PassThrough,
|
||||
T>; // ComputeType
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Case>
|
||||
struct CkFactoryTestConvFwd : public testing::Test
|
||||
{
|
||||
static auto get_actual_instances()
|
||||
{
|
||||
return InstanceSet::from_factory<typename Case::DeviceOp>();
|
||||
}
|
||||
|
||||
static auto get_expected_instances() { return InstanceSet(Case::expected); }
|
||||
};
|
||||
|
||||
struct F32
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<float>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F16
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::half_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct BF16
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::bhalf_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct S8
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<int8_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
using TestTypes = ::testing::Types<F32, F16, BF16, S8>;
|
||||
|
||||
TYPED_TEST_SUITE(CkFactoryTestConvFwd, TestTypes);
|
||||
|
||||
// For each case type, the instance set reported by the factory must match the expected set.
TYPED_TEST(CkFactoryTestConvFwd, TestInstances)
{
    auto from_factory = TestFixture::get_actual_instances();
    auto reference    = TestFixture::get_expected_instances();

    EXPECT_THAT(from_factory, InstancesMatch(reference));
}
|
||||
@@ -0,0 +1,105 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp>
|
||||
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
using ck_tile::test::InstanceSet;
|
||||
using ck_tile::test::InstancesMatch;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr static auto NumDimSpatial = 3;
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using ck::tensor_layout::convolution::G_K;
|
||||
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD;
|
||||
using ck::tensor_operation::element_wise::PassThrough;
|
||||
using ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
|
||||
|
||||
template <typename T, typename U = T>
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<OutLayout, G_K>, // DsLayout
|
||||
OutLayout,
|
||||
T, // InDataType
|
||||
T, // WeiDataType
|
||||
ck::Tuple<U, U>, // DsDataType
|
||||
T, // OutDataType
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ScaleAddScaleAddRelu,
|
||||
T>; // ComputeType
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Case>
|
||||
struct CkFactoryTestConvFwd : public testing::Test
|
||||
{
|
||||
static auto get_actual_instances()
|
||||
{
|
||||
return InstanceSet::from_factory<typename Case::DeviceOp>();
|
||||
}
|
||||
|
||||
static auto get_expected_instances() { return InstanceSet(Case::expected); }
|
||||
};
|
||||
|
||||
struct F32
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<float>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct F16
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::half_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct BF16
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<ck::bhalf_t>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
struct S8
|
||||
{
|
||||
using DeviceOp = ::DeviceOp<int8_t, float>;
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
""
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
using TestTypes = ::testing::Types<F32, F16, BF16, S8>;
|
||||
|
||||
TYPED_TEST_SUITE(CkFactoryTestConvFwd, TestTypes);
|
||||
|
||||
// For each case type, the instance set reported by the factory must match the expected set.
TYPED_TEST(CkFactoryTestConvFwd, TestInstances)
{
    auto from_factory = TestFixture::get_actual_instances();
    auto reference    = TestFixture::get_expected_instances();

    EXPECT_THAT(from_factory, InstancesMatch(reference));
}
|
||||
263
experimental/builder/test/test_instance_traits_util.cpp
Normal file
263
experimental/builder/test/test_instance_traits_util.cpp
Normal file
@@ -0,0 +1,263 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::IsEmpty;
|
||||
|
||||
// SequenceToArray<Seq>::value must expose the sequence's values in order.

// Empty sequence -> empty array.
TEST(InstanceTraitsUtil, SequenceToArrayReturnsEmptyArrayForEmptySequence)
{
    using Seq = ck::Sequence<>;
    EXPECT_THAT(SequenceToArray<Seq>::value, IsEmpty());
}

// One value -> one element.
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithSingleElement)
{
    using Seq = ck::Sequence<42>;
    EXPECT_THAT(SequenceToArray<Seq>::value, ElementsAre(42));
}

// Several values -> same values, same order.
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithMultipleElements)
{
    using Seq = ck::Sequence<1, 2, 3, 4, 5>;
    EXPECT_THAT(SequenceToArray<Seq>::value, ElementsAre(1, 2, 3, 4, 5));
}
|
||||
|
||||
// type_name<T>() must return the short name for each supported data type.
TEST(InstanceTraitsUtil, TypeNameReturnsCorrectStrings)
{
    const std::vector<std::string_view> names{
        type_name<ck::half_t>(),
        type_name<float>(),
        type_name<double>(),
        type_name<int8_t>(),
        type_name<int32_t>(),
        type_name<ck::bhalf_t>(),
        type_name<ck::f8_t>(),
        type_name<ck::bf8_t>(),
    };

    EXPECT_THAT(names, ElementsAre("fp16", "fp32", "fp64", "s8", "s32", "bf16", "fp8", "bf8"));
}
|
||||
|
||||
// layout_name<L>() must return the layout's own identifier for the GEMM layouts.
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForGemmLayouts)
{
    using Row  = ck::tensor_layout::gemm::RowMajor;
    using Col  = ck::tensor_layout::gemm::ColumnMajor;
    using Mfma = ck::tensor_layout::gemm::MFMA;

    const std::vector<std::string_view> names{
        layout_name<Row>(),
        layout_name<Col>(),
        layout_name<Mfma>(),
    };

    EXPECT_THAT(names, ElementsAre("RowMajor", "ColumnMajor", "MFMA"));
}
|
||||
|
||||
// layout_name<L>() must return the layout's own identifier for the convolution layouts.
// The expected strings are listed in the same order as the queried layouts.
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForConvLayouts)
{
    namespace conv = ck::tensor_layout::convolution;

    const std::vector<std::string_view> names{
        // Input tensor layouts
        // TODO(deprecated): Remove non-grouped layouts once instances are removed.
        layout_name<conv::NCHW>(),
        layout_name<conv::NHWC>(),
        layout_name<conv::NCDHW>(),
        layout_name<conv::NDHWC>(),
        // Grouped input layouts
        layout_name<conv::GNCHW>(),
        layout_name<conv::GNHWC>(),
        // Weight tensor layouts
        layout_name<conv::KCYX>(),
        layout_name<conv::KYXC>(),
        layout_name<conv::GKCYX>(),
        layout_name<conv::GKYXC>(),
        // Output tensor layouts
        layout_name<conv::NKHW>(),
        layout_name<conv::NHWK>(),
        layout_name<conv::GNKHW>(),
        layout_name<conv::GNHWK>(),
        // Strided layouts
        // TODO(deprecated): Remove strided layouts once instances are removed.
        layout_name<conv::G_NHW_C>(),
        layout_name<conv::G_K_YX_C>(),
        layout_name<conv::G_NHW_K>(),
        // Bias layouts
        layout_name<conv::G_C>(),
        layout_name<conv::G_K>(),
    };

    EXPECT_THAT(names,
                ElementsAre("NCHW", "NHWC", "NCDHW", "NDHWC",   // input
                            "GNCHW", "GNHWC",                   // grouped input
                            "KCYX", "KYXC", "GKCYX", "GKYXC",   // weight
                            "NKHW", "NHWK", "GNKHW", "GNHWK",   // output
                            "G_NHW_C", "G_K_YX_C", "G_NHW_K",   // strided
                            "G_C", "G_K"));                     // bias
}
|
||||
|
||||
// elementwise_op_name<Op>() must return the operation's own identifier.
TEST(InstanceTraitsUtil, ElementwiseOpNameReturnsCorrectStrings)
{
    namespace ew = ck::tensor_operation::element_wise;

    const std::vector<std::string_view> names{
        elementwise_op_name<ew::PassThrough>(),
        elementwise_op_name<ew::Scale>(),
        elementwise_op_name<ew::Bilinear>(),
        elementwise_op_name<ew::Add>(),
        elementwise_op_name<ew::AddRelu>(),
        elementwise_op_name<ew::Relu>(),
        elementwise_op_name<ew::BiasNormalizeInInferClamp>(),
        elementwise_op_name<ew::Clamp>(),
        elementwise_op_name<ew::AddClamp>(),
    };

    EXPECT_THAT(names,
                ElementsAre("PassThrough",
                            "Scale",
                            "Bilinear",
                            "Add",
                            "AddRelu",
                            "Relu",
                            "BiasNormalizeInInferClamp",
                            "Clamp",
                            "AddClamp"));
}
|
||||
|
||||
// conv_fwd_spec_name() must return the enumerator's own identifier.
TEST(InstanceTraitsUtil, ConvFwdSpecNameReturnsCorrectStrings)
{
    using Spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;

    const std::vector<std::string_view> names{
        conv_fwd_spec_name(Spec::Default),
        conv_fwd_spec_name(Spec::Filter1x1Stride1Pad0),
        conv_fwd_spec_name(Spec::Filter1x1Pad0),
        conv_fwd_spec_name(Spec::Filter3x3),
        conv_fwd_spec_name(Spec::OddC),
    };

    EXPECT_THAT(
        names,
        ElementsAre("Default", "Filter1x1Stride1Pad0", "Filter1x1Pad0", "Filter3x3", "OddC"));
}
|
||||
|
||||
// gemm_spec_name() must return the enumerator's own identifier for each listed value.
TEST(InstanceTraitsUtil, GemmSpecNameReturnsCorrectStrings)
{
    using Spec = ck::tensor_operation::device::GemmSpecialization;

    const std::vector<std::string_view> names{
        gemm_spec_name(Spec::Default),
        gemm_spec_name(Spec::MPadding),
        gemm_spec_name(Spec::NPadding),
        gemm_spec_name(Spec::KPadding),
        gemm_spec_name(Spec::MNPadding),
        gemm_spec_name(Spec::MKPadding),
        gemm_spec_name(Spec::NKPadding),
        gemm_spec_name(Spec::MNKPadding),
        gemm_spec_name(Spec::OPadding),
        gemm_spec_name(Spec::MOPadding),
        gemm_spec_name(Spec::NOPadding),
        gemm_spec_name(Spec::KOPadding),
        gemm_spec_name(Spec::MNOPadding),
        gemm_spec_name(Spec::MKOPadding),
        gemm_spec_name(Spec::NKOPadding),
        gemm_spec_name(Spec::MNKOPadding),
    };

    EXPECT_THAT(names,
                ElementsAre("Default",
                            "MPadding",
                            "NPadding",
                            "KPadding",
                            "MNPadding",
                            "MKPadding",
                            "NKPadding",
                            "MNKPadding",
                            "OPadding",
                            "MOPadding",
                            "NOPadding",
                            "KOPadding",
                            "MNOPadding",
                            "MKOPadding",
                            "NKOPadding",
                            "MNKOPadding"));
}
|
||||
|
||||
// pipeline_scheduler_name() must return the enumerator's own identifier.
TEST(InstanceTraitsUtil, PipelineSchedulerNameReturnsCorrectStrings)
{
    using Scheduler = ck::BlockGemmPipelineScheduler;

    const std::vector<std::string_view> names{
        pipeline_scheduler_name(Scheduler::Intrawave),
        pipeline_scheduler_name(Scheduler::Interwave),
    };

    EXPECT_THAT(names, ElementsAre("Intrawave", "Interwave"));
}
|
||||
|
||||
// pipeline_version_name() must return the enumerator's own identifier for v1 through v5.
TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
{
    using Version = ck::BlockGemmPipelineVersion;

    const std::vector<std::string_view> names{
        pipeline_version_name(Version::v1),
        pipeline_version_name(Version::v2),
        pipeline_version_name(Version::v3),
        pipeline_version_name(Version::v4),
        pipeline_version_name(Version::v5),
    };

    EXPECT_THAT(names, ElementsAre("v1", "v2", "v3", "v4", "v5"));
}
|
||||
|
||||
// tuple_name<ck::Tuple<...>>() formatting: "EmptyTuple" for no elements, otherwise
// "Tuple(" + comma-separated element names (no spaces) + ")".

TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
{
    using Empty = ck::Tuple<>;
    EXPECT_EQ(tuple_name<Empty>(), "EmptyTuple");
}

// Layout tuples.

TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleLayout)
{
    namespace conv = ck::tensor_layout::convolution;
    using Layouts  = ck::Tuple<conv::NCHW>;
    EXPECT_EQ(tuple_name<Layouts>(), "Tuple(NCHW)");
}

TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoLayouts)
{
    namespace conv = ck::tensor_layout::convolution;
    using Layouts  = ck::Tuple<conv::NCHW, conv::NHWC>;
    EXPECT_EQ(tuple_name<Layouts>(), "Tuple(NCHW,NHWC)");
}

TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeLayouts)
{
    namespace conv = ck::tensor_layout::convolution;
    using Layouts  = ck::Tuple<conv::NCHW, conv::NHWC, conv::NKHW>;
    EXPECT_EQ(tuple_name<Layouts>(), "Tuple(NCHW,NHWC,NKHW)");
}

// Data-type tuples.

TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleDataType)
{
    using Types = ck::Tuple<ck::half_t>;
    EXPECT_EQ(tuple_name<Types>(), "Tuple(fp16)");
}

TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoDataTypes)
{
    using Types = ck::Tuple<ck::half_t, float>;
    EXPECT_EQ(tuple_name<Types>(), "Tuple(fp16,fp32)");
}

TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeDataTypes)
{
    using Types = ck::Tuple<ck::half_t, float, double>;
    EXPECT_EQ(tuple_name<Types>(), "Tuple(fp16,fp32,fp64)");
}
|
||||
|
||||
// sequence_name<ck::Sequence<...>>() formatting: "Seq(" + comma-separated values
// (no spaces) + ")", including the empty case "Seq()".

TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForEmptySequence)
{
    using Seq = ck::Sequence<>;
    EXPECT_EQ(sequence_name<Seq>(), "Seq()");
}

TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForSingleValueSequence)
{
    using Seq = ck::Sequence<42>;
    EXPECT_EQ(sequence_name<Seq>(), "Seq(42)");
}

TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForTwoValueSequence)
{
    using Seq = ck::Sequence<1, 2>;
    EXPECT_EQ(sequence_name<Seq>(), "Seq(1,2)");
}

TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence)
{
    using Seq = ck::Sequence<256, 128, 64, 32, 16>;
    EXPECT_EQ(sequence_name<Seq>(), "Seq(256,128,64,32,16)");
}
|
||||
|
||||
} // namespace
|
||||
} // namespace ck_tile::reflect::detail
|
||||
98
experimental/builder/test/test_testing_utils.cpp
Normal file
98
experimental/builder/test/test_testing_utils.cpp
Normal file
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp>
|
||||
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
using ck_tile::test::InstanceMatcher;
|
||||
using ck_tile::test::InstanceSet;
|
||||
using ck_tile::test::StringEqWithDiff;
|
||||
|
||||
// InstanceSet::from_factory must return a non-empty set for the 2-D fp16 grouped forward
// convolution interface, and that set must contain one known instance string.
TEST(InstanceSet, FromFactory)
{
    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
        2,                                                   // NDimSpatial
        ck::tensor_operation::device::instance::NHWGC,       // InLayout
        ck::tensor_operation::device::instance::GKYXC,       // WeiLayout
        ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
        ck::tensor_operation::device::instance::NHWGK,       // OutLayout
        ck::half_t,                                          // ADataType
        ck::half_t,                                          // BDataType
        ck::Tuple<>,                                         // DsDataType
        ck::half_t,                                          // EDataType
        ck::tensor_operation::element_wise::PassThrough,     // AElementwiseOperation
        ck::tensor_operation::element_wise::PassThrough,     // BElementwiseOperation
        ck::tensor_operation::element_wise::PassThrough,     // CDEElementwiseOperation
        ck::half_t,                                          // AComputeType
        ck::half_t>;                                         // BComputeType

    const auto instances = InstanceSet::from_factory<DeviceOp>();

    // Compare against an unsigned zero. size() is presumably an unsigned size type (as for
    // the standard containers); a signed literal would make the matcher perform a
    // mixed-sign comparison.
    EXPECT_THAT(instances.instances.size(), testing::Gt(0u));

    // One instance string that the factory is expected to report for this interface.
    const auto* el =
        "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,"
        "fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,OddC,MNKPadding,64,16,16,64,"
        "8,8,16,16,1,1,Seq(8,8,1),Seq(1,0,2),Seq(1,0,2),2,8,8,0,Seq(8,8,1),Seq(1,0,2),Seq(1,0,2),2,"
        "8,8,0,1,1,Seq(1,16,1,4),4,Intrawave,v2,fp16,fp16>";
    EXPECT_THAT(instances.instances, testing::Contains(el));
}
|
||||
|
||||
// InstancesMatch() compares sets by content: insertion order is irrelevant,
// while a differing element or a missing element makes the match fail.
TEST(InstanceMatcher, Basic)
{
    const InstanceSet a{"python", "cobra", "boa"};
    const InstanceSet b{"cobra", "boa", "python"}; // same elements as `a`, different order
    const InstanceSet c{"adder", "boa", "cobra"};  // "adder" instead of "python"
    const InstanceSet d{"boa", "python"};          // strict subset of `a` / `b`

    EXPECT_THAT(a, ck_tile::test::InstancesMatch(b));
    EXPECT_THAT(c, testing::Not(ck_tile::test::InstancesMatch(b)));
    EXPECT_THAT(a, testing::Not(ck_tile::test::InstancesMatch(d)));
    EXPECT_THAT(d, testing::Not(ck_tile::test::InstancesMatch(b)));
}
|
||||
|
||||
// On a mismatch the matcher explanation lists the expected-but-absent
// instances ("Missing") followed by the present-but-not-expected ones
// ("Unexpected"), each section sorted and prefixed with its element count.
TEST(InstanceMatcher, ExplainMatchResult)
{
    const InstanceSet actual{"python", "cobra", "boa"};
    const InstanceSet expected{"adder", "boa", "cobra", "rattlesnake"};

    testing::StringMatchResultListener listener;
    const bool matched =
        testing::ExplainMatchResult(ck_tile::test::InstancesMatch(expected), actual, &listener);
    EXPECT_FALSE(matched);

    const std::string explanation = listener.str();
    EXPECT_THAT(explanation,
                StringEqWithDiff("\n"
                                 " Missing: 2\n"
                                 "- adder\n"
                                 "- rattlesnake\n"
                                 "Unexpected: 1\n"
                                 "- python\n"));
}
|
||||
@@ -1,21 +1,18 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <unistd.h>
|
||||
#include "testing_utils.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include "testing_utils.hpp"
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
namespace ck_tile::test {
|
||||
|
||||
namespace {
|
||||
|
||||
} // namespace
|
||||
|
||||
// Wagner-Fischer Algorithm for Computing Edit Distance and Inline Diff
|
||||
//
|
||||
// OUTPUT FORMAT: [expected|actual] for differences, plain text for matches
|
||||
@@ -216,4 +213,88 @@ void StringEqWithDiffMatcher::DescribeNegationTo(std::ostream* os) const
|
||||
return ::testing::MakeMatcher(new StringEqWithDiffMatcher(expected));
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const InstanceSet& set)
|
||||
{
|
||||
// These sets can grow very large, and so its not very nice or useful to print them
|
||||
// in the event of a mismatch. Just print a brief description here, and use
|
||||
// InstancesMatcher to print a more useful message.
|
||||
return (os << "(set of " << set.instances.size() << " instances)");
|
||||
}
|
||||
|
||||
// Keeps its own copy of the expected set, so the matcher does not depend on
// the lifetime of the argument.
InstanceMatcher::InstanceMatcher(const InstanceSet& expected)
    : expected_(expected)
{
}
|
||||
|
||||
// Factory for the InstanceMatcher: wraps a newly allocated matcher
// implementation into a ::testing::Matcher via ::testing::MakeMatcher.
::testing::Matcher<InstanceSet> InstancesMatch(const InstanceSet& expected)
{
    auto* impl = new InstanceMatcher(expected);
    return ::testing::MakeMatcher(impl);
}
|
||||
|
||||
// Returns true when `actual` holds exactly the same instances as the expected
// set. On a mismatch, and only if the listener asks for an explanation, the
// listener receives a newline followed by up to two sections:
//   " Missing: N"   - instances in the expected set that `actual` lacks
//   "Unexpected: N" - instances in `actual` that the expected set lacks
// Each section lists one "- <instance>" line per element; an empty instance
// name is shown as "(empty string)". A section with no elements is omitted.
bool InstanceMatcher::MatchAndExplain(InstanceSet actual,
                                      ::testing::MatchResultListener* listener) const
{
    if(actual.instances == expected_.instances)
    {
        return true;
    }

    if(!listener->IsInterested())
    {
        return false;
    }

    // Writes one section: the elements of `lhs` that are not in `rhs`.
    // Both inputs are ordered sets, as std::set_difference requires.
    const auto report_only_in = [listener](const char* heading, const auto& lhs, const auto& rhs) {
        std::vector<std::string> difference;
        std::set_difference(
            lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), std::back_inserter(difference));

        if(difference.empty())
        {
            return;
        }

        *listener << heading << difference.size() << "\n";
        for(const auto& name : difference)
        {
            if(name.empty())
            {
                *listener << "- (empty string)\n";
            }
            else
            {
                *listener << "- " << name << "\n";
            }
        }
    };

    *listener << "\n";
    report_only_in(" Missing: ", expected_.instances, actual.instances);
    report_only_in("Unexpected: ", actual.instances, expected_.instances);

    return false;
}
|
||||
|
||||
// Describes the matcher by streaming the expected set, which prints the
// short InstanceSet summary rather than every element.
void InstanceMatcher::DescribeTo(std::ostream* os) const
{
    std::ostream& out = *os;
    out << expected_;
}
|
||||
|
||||
void InstanceMatcher::DescribeNegationTo(std::ostream* os) const
|
||||
{
|
||||
*os << "is not equal to " << expected_;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::test
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck/library/tensor_operation_instance/device_operation_instance_factory.hpp>
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <iosfwd>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
|
||||
namespace ck_tile::test {
|
||||
|
||||
@@ -40,4 +45,68 @@ class StringEqWithDiffMatcher : public ::testing::MatcherInterface<std::string>
|
||||
// Factory function for the StringEqWithDiff matcher
|
||||
::testing::Matcher<std::string> StringEqWithDiff(const std::string& expected);
|
||||
|
||||
using ck::tensor_operation::device::instance::DeviceOperationInstanceFactory;
|
||||
|
||||
// This utility concept checks whether a type is a valid "Device Operation" -
|
||||
// that is, there is a valid specialization of `DeviceOperationInstanceFactory`
|
||||
// for it available.
|
||||
template <typename DeviceOp>
|
||||
concept HasCkFactory = requires {
|
||||
{
|
||||
DeviceOperationInstanceFactory<DeviceOp>::GetInstances()
|
||||
} -> std::convertible_to<std::vector<std::unique_ptr<DeviceOp>>>;
|
||||
};
|
||||
|
||||
// This structure represents a (unique) set of instances, either a statically
|
||||
// defined one (for testing) or one obtained from DeviceOperationInstanceFactory.
|
||||
// The idea is that we use this structure as a utility to compare a set of
|
||||
// instances. Instances are stored in a set so that they can be lexicographically
|
||||
// compared, this helps generating readable error messages which just contain
|
||||
// the differences between sets.
|
||||
struct InstanceSet
|
||||
{
|
||||
explicit InstanceSet() {}
|
||||
|
||||
explicit InstanceSet(std::initializer_list<const char*> items)
|
||||
: instances(items.begin(), items.end())
|
||||
{
|
||||
}
|
||||
|
||||
template <HasCkFactory DeviceOp>
|
||||
static InstanceSet from_factory()
|
||||
{
|
||||
auto set = InstanceSet();
|
||||
|
||||
const auto ops = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
|
||||
for(const auto& op : ops)
|
||||
{
|
||||
set.instances.insert(op->GetInstanceString());
|
||||
}
|
||||
|
||||
return set;
|
||||
}
|
||||
|
||||
std::set<std::string> instances;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const InstanceSet& set);
|
||||
|
||||
// This is a custom Google Test matcher which can be used to compare two sets
|
||||
// of instance names, with utility functions that print a helpful error
|
||||
// message about the difference between the checked sets. Use `InstancesMatch`
|
||||
// to obtain an instance of this type.
|
||||
// GoogleMock matcher implementation that compares an InstanceSet against an
// expected one. Obtain it through InstancesMatch() rather than directly.
struct InstanceMatcher : public ::testing::MatcherInterface<InstanceSet>
{
    // Stores a copy of `expected` for later comparisons.
    explicit InstanceMatcher(const InstanceSet& expected);

    // Returns whether `actual` equals the expected set; on a mismatch the
    // differences are written to `listener` when it is interested.
    bool MatchAndExplain(InstanceSet actual,
                         ::testing::MatchResultListener* listener) const override;

    // Writes a description of this matcher to `os`.
    void DescribeTo(std::ostream* os) const override;

    // Writes a description of the negation of this matcher to `os`.
    void DescribeNegationTo(std::ostream* os) const override;

    // The set every candidate is compared against.
    InstanceSet expected_;
};
|
||||
|
||||
::testing::Matcher<InstanceSet> InstancesMatch(const InstanceSet& expected);
|
||||
|
||||
} // namespace ck_tile::test
|
||||
|
||||
103
experimental/builder/test/utils/ckb_conv_test_common.hpp
Normal file
103
experimental/builder/test/utils/ckb_conv_test_common.hpp
Normal file
@@ -0,0 +1,103 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
// Common test base class
|
||||
// Shared GoogleTest fixture base for the forward-convolution builder tests.
// It adds no state or behaviour of its own.
class FwdConvBuilderTestBase : public ::testing::Test
{
};
|
||||
|
||||
// Common test implementation
|
||||
template <auto FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
BlockGemmPipelineVersion FwdPipelineVersion,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test()
|
||||
{
|
||||
constexpr GridwiseGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr ConvAlgorithm FwdConvAlgorithm{.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.pipeline_version = FwdPipelineVersion,
|
||||
.fwd_specialization = FwdConvSpecialization};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
|
||||
|
||||
// Verify pipeline version is correct
|
||||
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
// Common thread block configurations
|
||||
// Thread block configuration with block_size 256 and a 256 x 256 x 32
// (m x n x k) tile.
constexpr ThreadBlock DefaultThreadBlock{
    .block_size = 256,
    .tile_size  = {.m = 256, .n = 256, .k = 32},
};
|
||||
|
||||
// Thread block configuration with block_size 256 and a 128 x 128 x 32
// (m x n x k) tile.
constexpr ThreadBlock SmallThreadBlock{
    .block_size = 256,
    .tile_size  = {.m = 128, .n = 128, .k = 32},
};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
@@ -12,6 +12,8 @@ namespace element_wise {
|
||||
|
||||
struct Add
|
||||
{
|
||||
static constexpr const char* name = "Add";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
@@ -94,6 +96,8 @@ struct Add
|
||||
|
||||
struct Max
|
||||
{
|
||||
static constexpr const char* name = "Max";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
@@ -105,6 +109,8 @@ struct Max
|
||||
|
||||
struct Min
|
||||
{
|
||||
static constexpr const char* name = "Min";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
@@ -116,6 +122,8 @@ struct Min
|
||||
|
||||
struct Multiply
|
||||
{
|
||||
static constexpr const char* name = "Multiply";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
@@ -208,6 +216,8 @@ struct Multiply
|
||||
|
||||
struct ScaleAdd
|
||||
{
|
||||
static constexpr const char* name = "ScaleAdd";
|
||||
|
||||
__host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
@@ -235,6 +245,8 @@ struct ScaleAdd
|
||||
|
||||
struct Subtract
|
||||
{
|
||||
static constexpr const char* name = "Subtract";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
|
||||
@@ -279,6 +291,8 @@ struct Subtract
|
||||
|
||||
struct Bilinear
|
||||
{
|
||||
static constexpr const char* name = "Bilinear";
|
||||
|
||||
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
@@ -353,6 +367,8 @@ struct Bilinear
|
||||
|
||||
struct AddClamp
|
||||
{
|
||||
static constexpr const char* name = "AddClamp";
|
||||
|
||||
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
@@ -442,6 +458,8 @@ struct AddClamp
|
||||
|
||||
struct AddRelu
|
||||
{
|
||||
static constexpr const char* name = "AddRelu";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
@@ -523,6 +541,8 @@ struct AddRelu
|
||||
|
||||
struct AddHardswish
|
||||
{
|
||||
static constexpr const char* name = "AddHardswish";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
|
||||
@@ -560,6 +580,8 @@ struct AddHardswish
|
||||
// E = FastGelu(C + D)
|
||||
struct AddFastGelu
|
||||
{
|
||||
static constexpr const char* name = "AddFastGelu";
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -625,6 +647,8 @@ struct AddFastGelu
|
||||
// E = MultiplyFastGelu(C + D)
|
||||
struct MultiplyFastGelu
|
||||
{
|
||||
static constexpr const char* name = "MultiplyFastGelu";
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -690,6 +714,8 @@ struct MultiplyFastGelu
|
||||
// E = Silu(C + D)
|
||||
struct AddSilu
|
||||
{
|
||||
static constexpr const char* name = "AddSilu";
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -740,6 +766,8 @@ struct AddSilu
|
||||
|
||||
struct ConvScaleAdd
|
||||
{
|
||||
static constexpr const char* name = "ConvScaleAdd";
|
||||
|
||||
__host__ __device__ ConvScaleAdd(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
|
||||
@@ -13,6 +13,8 @@ namespace element_wise {
|
||||
template <typename... UnaryOpsSet>
|
||||
struct UnaryCombinedOp
|
||||
{
|
||||
static constexpr const char* name = "UnaryCombinedOp";
|
||||
|
||||
__host__ __device__ UnaryCombinedOp() : unary_ops_() {}
|
||||
|
||||
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
|
||||
@@ -33,6 +35,8 @@ struct UnaryCombinedOp
|
||||
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
|
||||
struct BinaryWithUnaryCombinedOp
|
||||
{
|
||||
static constexpr const char* name = "BinaryWithUnaryCombinedOp";
|
||||
|
||||
__host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {}
|
||||
|
||||
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
|
||||
@@ -66,6 +70,8 @@ template <typename BinaryOp0,
|
||||
typename UnaryOp2>
|
||||
struct TrinaryWithUnaryCombinedOp
|
||||
{
|
||||
static constexpr const char* name = "TrinaryWithUnaryCombinedOp";
|
||||
|
||||
__host__ __device__ TrinaryWithUnaryCombinedOp()
|
||||
: binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_()
|
||||
{
|
||||
|
||||
@@ -33,6 +33,8 @@ namespace element_wise {
|
||||
|
||||
struct AddReluAdd
|
||||
{
|
||||
static constexpr const char* name = "AddReluAdd";
|
||||
|
||||
template <typename Y, typename X0, typename X1, typename X2>
|
||||
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
|
||||
|
||||
@@ -102,6 +104,8 @@ struct AddReluAdd
|
||||
|
||||
struct AddHardswishAdd
|
||||
{
|
||||
static constexpr const char* name = "AddHardswishAdd";
|
||||
|
||||
template <typename Y, typename X0, typename X1, typename X2>
|
||||
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
|
||||
|
||||
@@ -134,6 +138,8 @@ struct AddHardswishAdd
|
||||
// E = C + D0 + D1
|
||||
struct AddAdd
|
||||
{
|
||||
static constexpr const char* name = "AddAdd";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
|
||||
{
|
||||
@@ -163,6 +169,8 @@ struct AddAdd
|
||||
// E = (C + D0) x D1
|
||||
struct AddMultiply
|
||||
{
|
||||
static constexpr const char* name = "AddMultiply";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
@@ -199,6 +207,8 @@ struct AddMultiply
|
||||
// E = C x D0 + D1
|
||||
struct MultiplyAdd
|
||||
{
|
||||
static constexpr const char* name = "MultiplyAdd";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
@@ -251,6 +261,8 @@ struct MultiplyAdd
|
||||
|
||||
struct MultiplyMultiply
|
||||
{
|
||||
static constexpr const char* name = "MultiplyMultiply";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
@@ -306,6 +318,8 @@ struct MultiplyMultiply
|
||||
|
||||
struct MultiplyAddFastGelu
|
||||
{
|
||||
static constexpr const char* name = "MultiplyAddFastGelu";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
@@ -327,6 +341,8 @@ struct MultiplyAddFastGelu
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
{
|
||||
static constexpr const char* name = "AddAddFastGelu";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
@@ -398,6 +414,7 @@ struct AddAddFastGelu
|
||||
// E = Relu(alpha1 * C + alpha2 * D0 + D1)
|
||||
struct ScaleAddScaleAddRelu
|
||||
{
|
||||
static constexpr const char* name = "ScaleAddScaleAddRelu";
|
||||
|
||||
ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
|
||||
: alpha1_(alpha1), alpha2_(alpha2)
|
||||
@@ -462,6 +479,8 @@ struct ScaleAddScaleAddRelu
|
||||
|
||||
struct Normalize
|
||||
{
|
||||
static constexpr const char* name = "Normalize";
|
||||
|
||||
// FIXME: is double absolutely necessary?
|
||||
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
@@ -533,6 +552,8 @@ struct Normalize
|
||||
// The data type of mean and variance is used as AccDataType
|
||||
struct NormalizeInInfer
|
||||
{
|
||||
static constexpr const char* name = "NormalizeInInfer";
|
||||
|
||||
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename T4>
|
||||
@@ -565,6 +586,8 @@ struct NormalizeInInfer
|
||||
// used by Conv+Bias+BatchNorm+Clamp inference
|
||||
struct BiasNormalizeInInferClamp
|
||||
{
|
||||
static constexpr const char* name = "BiasNormalizeInInferClamp";
|
||||
|
||||
BiasNormalizeInInferClamp(float floor = 0.f,
|
||||
float ceil = NumericLimits<float>::Max(),
|
||||
float epsilon = 1e-4)
|
||||
@@ -620,6 +643,8 @@ struct UnaryTypeConvert;
|
||||
template <>
|
||||
struct UnaryTypeConvert<float, ck::bhalf_t>
|
||||
{
|
||||
static constexpr const char* name = "UnaryTypeConvert";
|
||||
|
||||
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
|
||||
{
|
||||
y = ck::type_convert<float, ck::bhalf_t>(x);
|
||||
@@ -629,6 +654,8 @@ struct UnaryTypeConvert<float, ck::bhalf_t>
|
||||
template <>
|
||||
struct UnaryTypeConvert<ck::bhalf_t, float>
|
||||
{
|
||||
static constexpr const char* name = "UnaryTypeConvert";
|
||||
|
||||
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
|
||||
{
|
||||
y = ck::type_convert<ck::bhalf_t, float>(x);
|
||||
|
||||
@@ -24,6 +24,8 @@ namespace element_wise {
|
||||
template <typename Activation>
|
||||
struct Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Activation_Mul_Clamp";
|
||||
|
||||
// Convolution + Activation (piecewise linear function)
|
||||
// If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
|
||||
// Z = Activation(Y) = Activation(W @ X)
|
||||
@@ -71,6 +73,8 @@ struct Activation_Mul_Clamp
|
||||
template <typename Activation>
|
||||
struct Mul_Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Mul_Activation_Mul_Clamp";
|
||||
|
||||
// Convolution + Activation (non piecewise linear function)
|
||||
// Z = Activation(Y) = Activation(W @ X)
|
||||
// Sz * Qz = Activation(Sy * Qy)
|
||||
@@ -101,6 +105,8 @@ struct Mul_Activation_Mul_Clamp
|
||||
template <typename Activation>
|
||||
struct Activation_Mul2_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Activation_Mul2_Clamp";
|
||||
|
||||
Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
@@ -131,6 +137,8 @@ struct Activation_Mul2_Clamp
|
||||
template <typename Activation>
|
||||
struct Add_Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Add_Activation_Mul_Clamp";
|
||||
|
||||
// Convolution + bias
|
||||
// Let Bias = B = Sw * Sx * Qb
|
||||
// Where Qb is int32
|
||||
@@ -175,6 +183,8 @@ struct Add_Activation_Mul_Clamp
|
||||
template <typename Activation>
|
||||
struct Add_Activation_Mul2_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Add_Activation_Mul2_Clamp";
|
||||
|
||||
Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
@@ -206,6 +216,8 @@ struct Add_Activation_Mul2_Clamp
|
||||
template <typename Activation>
|
||||
struct Add_Mul_Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Add_Mul_Activation_Mul_Clamp";
|
||||
|
||||
// Convolution + Activation (non piecewise linear function)
|
||||
// Z = Activation(Y) = Activation(W @ X + B)
|
||||
// Sz * Qz = Activation(Sy * Qy)
|
||||
@@ -250,6 +262,8 @@ struct Add_Mul_Activation_Mul_Clamp
|
||||
template <typename Activation>
|
||||
struct Add_Mul2_Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Add_Mul2_Activation_Mul_Clamp";
|
||||
|
||||
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
|
||||
: scale_z_inv_(scale_z_inv), activationOp_(activationOp)
|
||||
{
|
||||
|
||||
@@ -157,6 +157,8 @@ namespace element_wise {
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
static constexpr const char* name = "PassThroughPack8";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -265,6 +267,8 @@ struct PassThroughPack8
|
||||
|
||||
struct DequantPack8
|
||||
{
|
||||
static constexpr const char* name = "DequantPack8";
|
||||
|
||||
template <typename Y, typename X, typename Z>
|
||||
__host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
|
||||
|
||||
@@ -301,6 +305,8 @@ struct DequantPack8
|
||||
|
||||
struct PassThroughPack2
|
||||
{
|
||||
static constexpr const char* name = "PassThroughPack2";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -332,6 +338,8 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
static constexpr const char* name = "PassThrough";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -556,6 +564,8 @@ struct PassThrough
|
||||
|
||||
struct UnaryConvert
|
||||
{
|
||||
static constexpr const char* name = "UnaryConvert";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
@@ -565,6 +575,8 @@ struct UnaryConvert
|
||||
|
||||
struct ConvertBF16RTN
|
||||
{
|
||||
static constexpr const char* name = "ConvertBF16RTN";
|
||||
|
||||
// convert to bf16 using round to nearest (rtn)
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
@@ -582,6 +594,8 @@ struct ConvertBF16RTN
|
||||
|
||||
struct ConvertF8SR
|
||||
{
|
||||
static constexpr const char* name = "ConvertF8SR";
|
||||
|
||||
// convert to fp8 using stochastic rounding (SR)
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
@@ -600,6 +614,8 @@ struct ConvertF8SR
|
||||
|
||||
struct ConvertF8RNE
|
||||
{
|
||||
static constexpr const char* name = "ConvertF8RNE";
|
||||
|
||||
// convert to fp8 using rounding to nearest even
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
@@ -618,6 +634,8 @@ struct ConvertF8RNE
|
||||
|
||||
struct Scale
|
||||
{
|
||||
static constexpr const char* name = "Scale";
|
||||
|
||||
__host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -663,6 +681,8 @@ struct Scale
|
||||
|
||||
struct ScaleAndResetNaNToMinusInfinity
|
||||
{
|
||||
static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity";
|
||||
|
||||
__host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -679,6 +699,8 @@ struct ScaleAndResetNaNToMinusInfinity
|
||||
|
||||
struct UnaryDivide
|
||||
{
|
||||
static constexpr const char* name = "UnaryDivide";
|
||||
|
||||
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
|
||||
|
||||
template <typename T>
|
||||
@@ -723,6 +745,8 @@ struct UnaryDivide
|
||||
|
||||
struct UnarySquare
|
||||
{
|
||||
static constexpr const char* name = "UnarySquare";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -739,6 +763,8 @@ struct UnarySquare
|
||||
|
||||
struct UnaryAbs
|
||||
{
|
||||
static constexpr const char* name = "UnaryAbs";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -769,6 +795,8 @@ struct UnaryAbs
|
||||
|
||||
struct UnarySqrt
|
||||
{
|
||||
static constexpr const char* name = "UnarySqrt";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -781,6 +809,8 @@ struct UnarySqrt
|
||||
|
||||
struct Clamp
|
||||
{
|
||||
static constexpr const char* name = "Clamp";
|
||||
|
||||
Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
@@ -854,6 +884,8 @@ struct Clamp
|
||||
|
||||
struct Relu
|
||||
{
|
||||
static constexpr const char* name = "Relu";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -890,6 +922,8 @@ struct Relu
|
||||
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
|
||||
struct FastGelu
|
||||
{
|
||||
static constexpr const char* name = "FastGelu";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -1005,6 +1039,8 @@ struct FastGelu
|
||||
// y = 0.5*x*(1+erf(x/sqrt(2)))
|
||||
struct Gelu
|
||||
{
|
||||
static constexpr const char* name = "Gelu";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -1023,6 +1059,8 @@ struct Gelu
|
||||
|
||||
struct Sigmoid
|
||||
{
|
||||
static constexpr const char* name = "Sigmoid";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1047,6 +1085,8 @@ struct Sigmoid
|
||||
|
||||
struct Silu
|
||||
{
|
||||
static constexpr const char* name = "SiLU";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1060,6 +1100,8 @@ struct Silu
|
||||
|
||||
struct TanH
|
||||
{
|
||||
static constexpr const char* name = "TanH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1083,6 +1125,8 @@ struct TanH
|
||||
|
||||
struct ACos
|
||||
{
|
||||
static constexpr const char* name = "ACos";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1097,6 +1141,8 @@ struct ACos
|
||||
|
||||
struct Neg
|
||||
{
|
||||
static constexpr const char* name = "Neg";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1111,6 +1157,8 @@ struct Neg
|
||||
|
||||
struct ATan
|
||||
{
|
||||
static constexpr const char* name = "ATan";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1125,6 +1173,8 @@ struct ATan
|
||||
|
||||
struct Sin
|
||||
{
|
||||
static constexpr const char* name = "Sin";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1139,6 +1189,8 @@ struct Sin
|
||||
|
||||
struct ASinH
|
||||
{
|
||||
static constexpr const char* name = "ASinH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1153,6 +1205,8 @@ struct ASinH
|
||||
|
||||
struct Cos
|
||||
{
|
||||
static constexpr const char* name = "Cos";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1167,6 +1221,8 @@ struct Cos
|
||||
|
||||
struct ACosH
|
||||
{
|
||||
static constexpr const char* name = "ACosH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1181,6 +1237,8 @@ struct ACosH
|
||||
|
||||
struct Tan
|
||||
{
|
||||
static constexpr const char* name = "Tan";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1195,6 +1253,8 @@ struct Tan
|
||||
|
||||
struct ATanH
|
||||
{
|
||||
static constexpr const char* name = "ATanH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1209,6 +1269,8 @@ struct ATanH
|
||||
|
||||
struct SinH
|
||||
{
|
||||
static constexpr const char* name = "SinH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1223,6 +1285,8 @@ struct SinH
|
||||
|
||||
struct Ceil
|
||||
{
|
||||
static constexpr const char* name = "Ceil";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1237,6 +1301,8 @@ struct Ceil
|
||||
|
||||
struct Exp
|
||||
{
|
||||
static constexpr const char* name = "Exp";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1251,6 +1317,8 @@ struct Exp
|
||||
|
||||
struct CosH
|
||||
{
|
||||
static constexpr const char* name = "CosH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1265,6 +1333,8 @@ struct CosH
|
||||
|
||||
struct Floor
|
||||
{
|
||||
static constexpr const char* name = "Floor";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1279,6 +1349,8 @@ struct Floor
|
||||
|
||||
struct Log
|
||||
{
|
||||
static constexpr const char* name = "Log";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1293,6 +1365,8 @@ struct Log
|
||||
|
||||
struct ASin
|
||||
{
|
||||
static constexpr const char* name = "ASin";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1307,6 +1381,8 @@ struct ASin
|
||||
|
||||
struct Rcp
|
||||
{
|
||||
static constexpr const char* name = "Rcp";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1321,6 +1397,8 @@ struct Rcp
|
||||
|
||||
struct Swish
|
||||
{
|
||||
static constexpr const char* name = "Swish";
|
||||
|
||||
Swish(float beta = 1.0f) : beta_(beta) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -1350,6 +1428,8 @@ struct Swish
|
||||
|
||||
struct SoftRelu
|
||||
{
|
||||
static constexpr const char* name = "SoftRelu";
|
||||
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1378,6 +1458,8 @@ struct SoftRelu
|
||||
|
||||
struct Power
|
||||
{
|
||||
static constexpr const char* name = "Power";
|
||||
|
||||
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma){};
|
||||
|
||||
@@ -1412,6 +1494,8 @@ struct Power
|
||||
|
||||
struct ClippedRelu
|
||||
{
|
||||
static constexpr const char* name = "ClippedRelu";
|
||||
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1441,6 +1525,8 @@ struct ClippedRelu
|
||||
|
||||
struct LeakyRelu
|
||||
{
|
||||
static constexpr const char* name = "LeakyRelu";
|
||||
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1468,6 +1554,8 @@ struct LeakyRelu
|
||||
|
||||
struct Elu
|
||||
{
|
||||
static constexpr const char* name = "Elu";
|
||||
|
||||
Elu(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1495,6 +1583,8 @@ struct Elu
|
||||
|
||||
struct Logistic
|
||||
{
|
||||
static constexpr const char* name = "Logistic";
|
||||
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1523,6 +1613,8 @@ struct Logistic
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
static constexpr const char* name = "ConvInvscale";
|
||||
|
||||
__host__ __device__ ConvInvscale(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
@@ -1546,6 +1638,8 @@ struct ConvInvscale
|
||||
|
||||
struct ConvScale
|
||||
{
|
||||
static constexpr const char* name = "ConvScale";
|
||||
|
||||
__host__ __device__ ConvScale(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
@@ -1569,6 +1663,8 @@ struct ConvScale
|
||||
|
||||
struct ConvScaleRelu
|
||||
{
|
||||
static constexpr const char* name = "ConvScaleRelu";
|
||||
|
||||
__host__ __device__ ConvScaleRelu(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
|
||||
@@ -16,10 +16,17 @@ __device__ void llvm_amdgcn_s_wait_dscnt(short cnt) __asm("llvm.amdgcn.s.wait.ds
|
||||
__device__ void block_sync_lds()
|
||||
{
|
||||
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#ifdef __gfx12__
|
||||
#if defined(__gfx12__)
|
||||
llvm_amdgcn_s_wait_dscnt(0);
|
||||
asm volatile("s_barrier_signal -1\n\t"
|
||||
"s_barrier_wait -1");
|
||||
#elif defined(__gfx11__)
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
// s_barrier \
|
||||
// " ::);
|
||||
__builtin_amdgcn_s_waitcnt(0xfc07);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
#else
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
|
||||
@@ -46,7 +46,7 @@
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
#include "ck_tile/host/reference/reference_transpose.hpp"
|
||||
#include "ck_tile/host/rotating_buffers.hpp"
|
||||
#include "ck_tile/host/shuffle_utils.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -14,14 +15,18 @@ namespace ck_tile {
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
typename OutDataType,
|
||||
typename Elfunc = ck_tile::element_wise::PassThrough,
|
||||
typename Tuple = ck_tile::tuple<>>
|
||||
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
|
||||
const HostTensor<WeiDataType>& weight,
|
||||
HostTensor<OutDataType>& output,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
std::vector<ck_tile::long_index_t>)
|
||||
std::vector<ck_tile::long_index_t>,
|
||||
Elfunc elfunc = Elfunc{},
|
||||
Tuple ds = {})
|
||||
{
|
||||
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
weight.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
@@ -52,8 +57,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, wo) = v_acc_converted;
|
||||
if constexpr(Tuple::size() > 0)
|
||||
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, wo));
|
||||
else
|
||||
elfunc(v_acc, v_acc);
|
||||
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, wo) = v_acc_out;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
@@ -95,8 +104,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, ho, wo) = v_acc_converted;
|
||||
if constexpr(Tuple::size() > 0)
|
||||
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, ho, wo));
|
||||
else
|
||||
elfunc(v_acc, v_acc);
|
||||
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, ho, wo) = v_acc_out;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
@@ -145,8 +158,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, d_o, ho, wo) = v_acc_converted;
|
||||
if constexpr(Tuple::size() > 0)
|
||||
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, d_o, ho, wo));
|
||||
else
|
||||
elfunc(v_acc, v_acc);
|
||||
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, d_o, ho, wo) = v_acc_out;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
|
||||
0
include/ck_tile/host/shuffle_utils.hpp → include/ck_tile/host/tensor_shuffle_utils.hpp
Normal file → Executable file
0
include/ck_tile/host/shuffle_utils.hpp → include/ck_tile/host/tensor_shuffle_utils.hpp
Normal file → Executable file
@@ -1540,6 +1540,23 @@ struct Logistic
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
/// @brief Element-wise functor that clamps its input between two stored float bounds.
///
/// The bounds default to the lowest and largest finite `float` values, and are kept as
/// `float` members; they are converted to the operand type on every call.
struct Clamp
|
||||
{
|
||||
/// @param lower Lower bound (defaults to `std::numeric_limits<float>::lowest()`).
/// @param upper Upper bound (defaults to `std::numeric_limits<float>::max()`).
CK_TILE_HOST_DEVICE Clamp(float lower = std::numeric_limits<float>::lowest(),
|
||||
float upper = std::numeric_limits<float>::max())
|
||||
: lower_(lower), upper_(upper) {};
|
||||
|
||||
/// @brief Writes `ck_tile::clamp(x, lower, upper)` into `y`.
/// @param y Output value.
/// @param x Input value.
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(T& y, const T& x) const
|
||||
{
|
||||
// Convert the float bounds to T so the clamp is evaluated in the operand type.
T lower = ck_tile::type_convert<T>(lower_);
|
||||
T upper = ck_tile::type_convert<T>(upper_);
|
||||
y = ck_tile::clamp(x, lower, upper);
|
||||
}
|
||||
|
||||
// Bounds as supplied to the constructor.
float lower_, upper_;
|
||||
};
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
static constexpr const char* name = "ConvInvscale";
|
||||
@@ -1629,6 +1646,55 @@ struct Cast
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compose two unary element-wise functions into one.
|
||||
*
|
||||
*
|
||||
* @note The Ds tensor can be used by at most one of the composed functions.
|
||||
* This holds even if compositions are chained:
|
||||
* In `Compose<FA, Compose<FB, FC>>`, only one of `FA`, `FB`, or `FC` can use
|
||||
* the Ds tensor.
|
||||
*
|
||||
* @tparam FuncA The first function to be applied.
|
||||
* @tparam FuncB The second function to be applied.
|
||||
* @tparam FuncADs Whether `FuncA` uses the Ds tensor.
|
||||
* @tparam FuncBDs Whether `FuncB` uses the Ds tensor.
|
||||
*/
|
||||
template <typename FuncA, typename FuncB, bool FuncADs = false, bool FuncBDs = false>
|
||||
struct Compose
|
||||
{
|
||||
// Compile-time guard: the extra `ds...` arguments are forwarded to at most one function.
static_assert(!(FuncADs && FuncBDs), "Only one composed function may use the Ds tensor.");
|
||||
|
||||
/// @brief Stores copies of the two functions; both default to value-initialized instances.
CK_TILE_HOST_DEVICE Compose(FuncA func_a_ = FuncA{}, FuncB func_b_ = FuncB{})
|
||||
: func_a(func_a_), func_b(func_b_)
|
||||
{
|
||||
}
|
||||
|
||||
/// @brief Calls `func_a` on `x` into a temporary, then `func_b` on that temporary into `y`.
/// @tparam AIn  Type of the input `x`.
/// @tparam BOut Type of the output `y`.
/// @tparam AOut Type of the intermediate value; defaults to `AIn`.
/// @param ds Extra arguments forwarded to whichever function has its `*Ds` flag set;
///           ignored when neither flag is set.
template <typename AIn, typename BOut, typename AOut = AIn, typename... ADs>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(BOut& y, const AIn& x, const ADs&... ds) const
|
||||
{
|
||||
// Intermediate result handed from func_a to func_b.
// NOTE(review): tmp is not explicitly initialized; this presumably relies on func_a
// always writing its first argument - confirm for all composed functors.
AOut tmp;
|
||||
if constexpr(FuncADs)
|
||||
{
|
||||
// func_a receives the extra arguments.
func_a(tmp, x, ds...);
|
||||
func_b(y, tmp);
|
||||
}
|
||||
else if constexpr(FuncBDs)
|
||||
{
|
||||
// func_b receives the extra arguments.
func_a(tmp, x);
|
||||
func_b(y, tmp, ds...);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Neither function takes the extra arguments; ds... is dropped.
func_a(tmp, x);
|
||||
func_b(y, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
// The composed functions, applied in the order func_a then func_b.
const FuncA func_a;
|
||||
const FuncB func_b;
|
||||
};
|
||||
|
||||
// support fastconvert of int8 to fp16
|
||||
#if 0
|
||||
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -117,6 +117,10 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
|
||||
CDElementwise elfunc_;
|
||||
|
||||
CK_TILE_DEVICE CShuffleEpilogue(CDElementwise elfunc = CDElementwise{}) : elfunc_(elfunc) {};
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
/**
|
||||
@@ -385,7 +389,7 @@ struct CShuffleEpilogue
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
tile_elementwise_inout_unpack(elfunc_, c_ds_tiles);
|
||||
}
|
||||
|
||||
template <typename OutDramWindow, typename COutTensor>
|
||||
@@ -450,7 +454,7 @@ struct CShuffleEpilogue
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* /*p_smem*/,
|
||||
void* /* p_smem */,
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
|
||||
@@ -185,14 +185,6 @@ struct GemmKernelMultiABD
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// Currently MultiABD kernel doesn't support F8 data type
|
||||
if(ck_tile::get_device_name() == "gfx950" &&
|
||||
(std::is_same<ck_tile::fp8_t, ADataType>::value ||
|
||||
std::is_same<ck_tile::fp8_t, BDataType>::value ||
|
||||
std::is_same<ck_tile::fp8_t, DDataType>::value))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,478 @@
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace reboot {
|
||||
|
||||
/// @brief The Stream K GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel
|
||||
/// arguments object. It contains all necessary information required to build proper kernel
|
||||
/// arguments and launch the kernel on GPU. This structure defines the GEMM problem
|
||||
/// configuration by stating all required information like M,N,K sizes and respective strides.
|
||||
struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
|
||||
{
|
||||
/// @brief Builds the host arguments from single A, B and C pointers and strides.
///
/// The values are forwarded to the `UniversalGemmHostArgs<>` base with an empty Ds
/// pointer list, an empty Ds stride list, and `k_batch` fixed to 1.
/// @param reduction_strategy_ Stored in `reduction_strategy`.
CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
StreamKReductionStrategy reduction_strategy_)
|
||||
: UniversalGemmHostArgs<>({a_ptr_},
|
||||
{b_ptr_},
|
||||
{/*ds_ptr*/},
|
||||
c_ptr_,
|
||||
/*k_batch_ =*/1,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
{stride_A_},
|
||||
{stride_B_},
|
||||
{/*stride_Ds_*/},
|
||||
stride_C_),
|
||||
reduction_strategy{reduction_strategy_}
|
||||
{
|
||||
}
|
||||
|
||||
/// @brief The reduction strategy requested by the caller.
ck_tile::StreamKReductionStrategy reduction_strategy;
|
||||
};
|
||||
|
||||
/// @brief The Stream K GEMM kernel class.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This class is responsible for the Stream-K kernel, making use of UniversalGemm.
|
||||
/// The main kernel functions are the operator() functions. There is one for Persistent
|
||||
/// and one for Non-Persistent data parallel sections of the Stream-K algorithm.
|
||||
///
|
||||
/// Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
|
||||
/// `StreamKGemm()`. `BaseGemm()` computes offsets into the A, B, C tensors, then calls
|
||||
/// `RunGemm()` which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
|
||||
/// main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct StreamKKernel
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
static constexpr bool PersistentDP = UniversalGemmKernel::PersistentKernel;
|
||||
|
||||
using TilePartitioner = TilePartitioner_;
|
||||
using GemmPipeline = GemmPipeline_;
|
||||
using EpiloguePipeline = EpiloguePipeline_;
|
||||
|
||||
static_assert(
|
||||
TilePartitioner::PERSISTENT == PersistentDP,
|
||||
"Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
|
||||
|
||||
/// @brief Specify the layout configurations for A, B, and C
|
||||
using ALayout = typename GemmPipeline::ALayout;
|
||||
using BLayout = typename GemmPipeline::BLayout;
|
||||
using CLayout = typename GemmPipeline::CLayout;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, and C
|
||||
using ADataType = typename GemmPipeline::ADataType;
|
||||
using BDataType = typename GemmPipeline::BDataType;
|
||||
using CDataType = typename EpiloguePipeline::ODataType;
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
|
||||
"ALayout and ADataType must be scalars.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
|
||||
"BLayout and BDataType must be scalars.");
|
||||
|
||||
/// @brief CLayout and CDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
|
||||
"CLayout and CDataType must be scalars.");
|
||||
|
||||
/// @brief Kernel arguments for Stream-K: the universal GEMM kernel arguments plus the
/// reduction strategy, a workspace pointer, and a tile partitioner instance.
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
|
||||
{
|
||||
/// @brief Copies the GEMM description out of `host_args` and builds the tile partitioner.
/// @param host_args Stream-K host arguments.
/// @param grid Grid value passed to the `TilePartitioner` constructor along with M, N, K.
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
|
||||
: UniversalGemmKernelArgs{host_args.as_ptr,
|
||||
host_args.bs_ptr,
|
||||
host_args.ds_ptr,
|
||||
host_args.e_ptr,
|
||||
host_args.M,
|
||||
host_args.N,
|
||||
host_args.K,
|
||||
host_args.stride_As,
|
||||
host_args.stride_Bs,
|
||||
host_args.stride_Ds,
|
||||
host_args.stride_E,
|
||||
host_args.k_batch},
|
||||
reduction_strategy{host_args.reduction_strategy},
|
||||
// The workspace pointer is set to nullptr because we must first
|
||||
// instantiate the TilePartitioner to get the necessary size
|
||||
workspace_ptr{nullptr},
|
||||
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
|
||||
|
||||
{
|
||||
}
|
||||
|
||||
/// @brief The strategy used by workgroups to compute final results in the C tensor.
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
/// @brief A pointer to a buffer in device memory for accumulating partial results via the
|
||||
/// reduction strategy.
|
||||
void* workspace_ptr;
|
||||
/// @brief An instance of the TilePartitioner class for assisting with mapping workgroups to
|
||||
/// the C tensor.
|
||||
TilePartitioner tile_partitioner;
|
||||
};
|
||||
|
||||
using KernelArgs = StreamKKernelArgs;
|
||||
using Kernel = StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
|
||||
/// @brief Builds a descriptive name string for this kernel instantiation.
/// @return A '_'-joined string starting with "streamk" and the A/B precision string,
/// followed by 'x'-joined groups: block tile sizes (M, N, K), warp tile sizes, vector
/// sizes (A, B, C), and padding flags (M, N, K).
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
using WarpTile = typename P_::BlockGemmShape::WarpTile;
|
||||
|
||||
return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/// @brief Compute the grid size for the Stream K kernel using the tile_partitioner.
|
||||
/// @param tile_partitioner The tile partitioner whose `grid_size()` is returned.
/// @return The grid size.
|
||||
CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
|
||||
{
|
||||
return tile_partitioner.grid_size();
|
||||
}
|
||||
|
||||
/// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
/// @param s Stream configuration forwarded to `UniversalGemmKernel::MaxOccupancyGridSize`.
/// @return The maximum occupancy grid size.
|
||||
/// @note This function queries the maximum occupancy of the kernel using
|
||||
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
// Pure delegation to the underlying universal GEMM kernel.
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
|
||||
/// @brief Get the workgroup (block) dimensions for launching this kernel.
/// @return The value of `UniversalGemmKernel::BlockSize()`.
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
/// @brief Constructs kernel arguments for the Stream-K kernel.
|
||||
/// @param host_args Stream-K host arguments.
|
||||
/// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
|
||||
/// The caller may select their own to assist with test reproducibility, etc.
|
||||
/// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
|
||||
/// select their own to assist with test reproducibility, etc.
|
||||
/// @return The kernel arguments for Stream-K.
|
||||
/// @note The returned arguments have a null `workspace_ptr`; see `SetWorkSpacePointer()`.
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
|
||||
int num_cu = NumCU(),
|
||||
int occupancy = Occupancy())
|
||||
{
|
||||
// The grid value handed to the tile partitioner is CUs times blocks per CU.
const index_t grid = num_cu * occupancy;
|
||||
|
||||
return StreamKKernelArgs{host_args, grid};
|
||||
}
|
||||
|
||||
/// @brief Runs the GEMM pipeline and then the epilogue for one block tile.
/// @tparam UseDefaultScheduler When false, only the warp with id 0 runs the epilogue.
/// @param as_ptr A tensor pointers.
/// @param bs_ptr B tensor pointers.
/// @param ds_ptr D tensor pointers.
/// @param c_ptr Output C tensor pointer.
/// @param smem_ptr_0 Pointer to LDS, passed to both the GEMM pipeline and the epilogue.
/// @param kargs Universal GEMM kernel arguments used to build the tensor views.
/// @param num_loop Number of K-loop iterations for the GEMM pipeline.
/// @param block_idx_m Block M coordinate used to place the tile windows.
/// @param block_idx_n Block N coordinate used to place the tile windows.
/// @param k_size K extent used when constructing the tensor views.
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const typename UniversalGemmKernel::KernelArgs& kargs,
|
||||
const index_t num_loop,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const index_t k_size)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
// Tuple slots I0, I1, I2 hold the A, B and D windows; I3 (used below) holds C.
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
|
||||
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
|
||||
// case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
|
||||
// tail_num.
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Only the first A and B windows are used: this kernel takes a single A and a single B.
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Checks whether the given kernel arguments can be run by this kernel.
/// @param kargs Stream-K kernel arguments.
/// @return false when the `Reduction` strategy is requested; otherwise the result of
/// `UniversalGemmKernel::IsSupportedArgument(kargs)`.
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
|
||||
{
|
||||
if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// The error is only reported when CK_TILE_LOGGING is enabled in the environment.
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
/// @brief Computes the buffer size needed to store accumulation results for Stream K.
|
||||
/// @param kargs Stream-K kernel arguments; the size is obtained from their tile partitioner
/// using `sizeof(CDataType)`.
/// @return The buffer size needed.
|
||||
CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
|
||||
{
|
||||
return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
|
||||
}
|
||||
|
||||
/// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
|
||||
/// @param kargs Stream-K kernel arguments to update.
/// @param workspace_ptr The pointer to store; it is not validated here.
/// @note Assumes that the given workspace_ptr points to allocated device memory.
|
||||
CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
|
||||
{
|
||||
kargs.workspace_ptr = workspace_ptr;
|
||||
}
|
||||
|
||||
/// @brief Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
|
||||
/// @param kargs Stream-K kernel arguments.
|
||||
/// @param tile_idx The 1D tile index in the C tensor for this workgroup.
|
||||
/// @param num_loop The number of iterations (at the macro tile level) in the K dimension this
|
||||
/// workgroup will perform in the C tile.
|
||||
/// @param i_k_a The K offset in the A tensor.
|
||||
/// @param i_k_b The K offset in the B tensor.
|
||||
/// @param k_size The portion of the K dimension this workgroup processes in the assigned
|
||||
/// `tile_idx`.
|
||||
/// @param smem_ptr_0 Pointer to LDS.
|
||||
CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
|
||||
index_t tile_idx,
|
||||
index_t num_loop,
|
||||
index_t i_k_a,
|
||||
index_t i_k_b,
|
||||
index_t k_size,
|
||||
void* smem_ptr_0) const
|
||||
{
|
||||
// Map the 1D tile index to macro-tile coordinates, then scale by the block tile sizes
// to get the M and N values passed to RunGemm.
const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
|
||||
index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
|
||||
|
||||
// A and B are advanced by their K offsets (in elements); C is used from its base pointer.
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Run the GEMM pipeline and Epilogue.
|
||||
RunGemm(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
|
||||
}
|
||||
|
||||
/// @brief Runs the main Stream-K algorithm.
|
||||
/// @param kargs Stream-K kernel arguments.
|
||||
/// @param cta_idx The current Stream-K workgroup's index.
|
||||
/// @param smem_ptr_0 Pointer to LDS.
|
||||
/// @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
/// non-persistent data-parallel (DP) section is used, then a Stream-K workgroup's `cta_idx`
|
||||
/// should be something like `blockIdx.x` minus number of DP workgroups.
|
||||
CK_TILE_DEVICE void
|
||||
StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
|
||||
{
|
||||
// Iteration range [iter_start, iter_end) assigned to this workgroup by the partitioner.
index_t iter_start, iter_end;
|
||||
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
|
||||
|
||||
while(iter_start < iter_end)
|
||||
{
|
||||
// Get the 1D tile index in the C tensor that this workgroup will work in for this
|
||||
// iteration of the loop.
|
||||
index_t tile_idx =
|
||||
amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
|
||||
|
||||
// Get the start and end boundaries for the current tile.
|
||||
index_t tile_iter_start, tile_iter_end;
|
||||
kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
|
||||
|
||||
// Get the start and end iteration within the current tile for the workgroup.
|
||||
index_t local_iter_start = amd_wave_read_first_lane(
|
||||
kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
|
||||
index_t local_iter_end =
|
||||
amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
|
||||
tile_iter_start, iter_end, tile_iter_end));
|
||||
|
||||
// Get the iteration length.
|
||||
index_t num_loop_sk = local_iter_end - local_iter_start;
|
||||
|
||||
// Determine the total size along the K dimension the workgroup is using in this
|
||||
// iteration (used to construct tensor views).
|
||||
index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
|
||||
|
||||
// Get the K offsets for the A and B tensors
|
||||
auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
|
||||
local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
|
||||
|
||||
if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE(review): with a non-atomic strategy this branch performs no GEMM work, so
// the loop only advances iter_start.
// TODO: Apply reduction logic.
|
||||
}
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
// Continue from the end of the tile just processed, and synchronize before the next
// iteration uses smem_ptr_0 again.
iter_start = tile_iter_end;
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Entry point for the Stream-K Kernel with non-persistent DP.
|
||||
///
|
||||
/// @par Overview
|
||||
/// For the Non-Persistent kernel, each data parallel workgroup will
|
||||
/// compute the results for their assigned macro-tile by calling `BaseGemm()`.
|
||||
/// The Stream-K workgroups will do their assigned work by calling
|
||||
/// `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop.
|
||||
/// @tparam U Defaults to `PersistentDP`; this overload is enabled only when it is false.
/// @param kargs Stream-K kernel arguments (taken by value).
template <bool U = PersistentDP>
|
||||
CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
// Allocate LDS
|
||||
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
|
||||
|
||||
index_t block_idx = ck_tile::get_block_1d_id();
|
||||
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
|
||||
index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
|
||||
// NOTE(review): get_dp_ctas() is queried a second time here although dp_ctas already
// holds it; presumably `block_idx < dp_ctas` is equivalent - confirm return type.
bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
|
||||
|
||||
// Check if at the data parallel section
|
||||
if(is_dp_ctas)
|
||||
{
|
||||
// DP workgroup: one full tile, no K offset, whole K extent.
BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Stream-K
|
||||
// Re-base the block index so the first Stream-K workgroup has index zero.
StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Entry point for the Stream-K Kernel with persistent DP.
|
||||
///
|
||||
/// @par Overview
|
||||
/// For the Persistent kernel, each workgroup will first compute their
|
||||
/// assigned data-parallel tiles. Each data parallel tile will be computed
|
||||
/// by calling `BaseGemm()`. Then the workgroups will proceed with the
|
||||
/// Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()`
|
||||
/// in the Stream-K loop.
|
||||
template <bool U = PersistentDP>
|
||||
CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
// Allocate LDS
|
||||
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
|
||||
|
||||
index_t block_idx = ck_tile::get_block_1d_id();
|
||||
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
|
||||
|
||||
// Data-parallel section
|
||||
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
|
||||
tile_idx += kargs.tile_partitioner.get_grid())
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
|
||||
}
|
||||
|
||||
// Stream-K section
|
||||
StreamKGemm(kargs, block_idx, smem_ptr_0);
|
||||
}
|
||||
|
||||
private:
|
||||
/// @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
/// the starting macro tile index in the K dimension for the workgroup.
|
||||
/// @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
/// of A and B.
|
||||
/// @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
/// major.
|
||||
template <typename ALayout, typename BLayout>
|
||||
CK_TILE_DEVICE static tuple<index_t, index_t>
|
||||
GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
|
||||
{
|
||||
index_t stride_offset_a;
|
||||
index_t stride_offset_b;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
stride_offset_a = stride_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_a = 1;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_offset_b = stride_b;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_b = 1;
|
||||
}
|
||||
|
||||
index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
|
||||
|
||||
return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
|
||||
}
|
||||
|
||||
/// @brief Returns the `multiProcessorCount` reported by HIP for the current device.
/// @note Both HIP calls go through `hip_check_error`.
CK_TILE_HOST static int NumCU()
{
    hipDevice_t current_device;
    hip_check_error(hipGetDevice(&current_device));

    hipDeviceProp_t properties;
    hip_check_error(hipGetDeviceProperties(&properties, current_device));

    return properties.multiProcessorCount;
}
|
||||
|
||||
/// @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
/// @return The occupancy
|
||||
/// @note This function queries the maximum occupancy of the kernel using
|
||||
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
CK_TILE_HOST static int Occupancy()
|
||||
{
|
||||
int occupancy;
|
||||
|
||||
// Since occupancy of 1 is valid for stream k, we set min_num_block_per_cu to 1
|
||||
constexpr int min_block_per_cu = 1;
|
||||
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
|
||||
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
};
|
||||
} // namespace reboot
|
||||
|
||||
/// @brief The Stream K GEMM kernel host arguments.
|
||||
///
|
||||
|
||||
@@ -186,6 +186,11 @@ struct StreamKTilePartitionerBase
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns an estimate of the number of workgroups writing to the same macro tile in C.
|
||||
*/
|
||||
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept;
|
||||
|
||||
protected:
|
||||
index_t num_tiles_;
|
||||
index_t grid_;
|
||||
@@ -246,6 +251,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = true;
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent
|
||||
* case, no extra workgroups are allocated for the data parallel section, making the grid
|
||||
@@ -292,6 +298,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, fals
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = false;
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent
|
||||
* case, extra workgroups are allocated for the data parallel section, making the grid
|
||||
|
||||
@@ -214,6 +214,27 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() c
|
||||
return n_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
|
||||
const noexcept
|
||||
{
|
||||
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
|
||||
// writing final results to a given macro tile in C.
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
|
||||
if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
|
||||
// Estimate the number of workgroups per macro tile.
|
||||
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
|
||||
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType,
|
||||
bool Persistent>
|
||||
|
||||
@@ -309,6 +309,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as corresponding "
|
||||
"C block tensor data type!");
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
|
||||
// hot loop:
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -352,7 +353,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
else
|
||||
{
|
||||
// Many N warps/iters share the same scale, index from full
|
||||
// [NQPerBlock=1, BQPerBlock] matrix
|
||||
// [KQPerBlock, NQPerBlock=1] matrix
|
||||
static_assert(Traits::NQPerBlock == 1);
|
||||
return kQScale;
|
||||
}
|
||||
@@ -366,11 +367,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
|
||||
|
||||
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -687,8 +687,8 @@ struct QuantGemmKernel
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
|
||||
make_tuple(1, kargs.stride_BQ),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
@@ -908,9 +908,9 @@ struct QuantGemmKernel
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -375,30 +375,48 @@ struct QuantGroupedGemmKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
|
||||
c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
else
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -55,8 +55,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
NPerBlockBQ,
|
||||
KPerBlockBQ,
|
||||
NPerBlockBQ,
|
||||
Problem::QuantGroupSize::kN,
|
||||
VecLoadSize>;
|
||||
|
||||
|
||||
@@ -259,8 +259,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
|
||||
static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Bq block window has incorrect lengths for defined BqLayout!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
@@ -318,7 +318,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
is_bq_col_major ? make_array(0, KPerBlockBQ) : make_array(KPerBlockBQ, 0);
|
||||
is_bq_col_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ);
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
@@ -363,6 +363,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
constexpr index_t tail_count =
|
||||
((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2;
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
@@ -408,7 +410,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
} while(i < (num_loop - tail_count));
|
||||
}
|
||||
// tail
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
@@ -475,6 +477,49 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
|
||||
///
|
||||
/// This operator is used by grouped GEMM kernels where pipeline parameters
|
||||
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
|
||||
/// at runtime, not on the host side during compilation. This is necessary
|
||||
/// because different GEMM problems in the group may have different K dimensions,
|
||||
/// requiring different pipeline configurations that cannot be determined at
|
||||
/// compile time.
|
||||
///
|
||||
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
|
||||
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
|
||||
/// @param bq_dram_block_window_tmp Block window for BQ (quantization scale) tensor in DRAM
|
||||
/// @param num_loop Number of main loop iterations (calculated on device)
|
||||
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
|
||||
/// @param tail_number Type of tail handling required (calculated on device)
|
||||
/// @param p_smem Pointer to shared memory
|
||||
/// @return Accumulated result tile in registers
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
|
||||
constexpr bool hot_loop = has_hot_loop_.value;
|
||||
constexpr auto tail_num = tail_number_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
bq_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -192,51 +192,51 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
{
|
||||
if constexpr(YPerQ == 1)
|
||||
{
|
||||
// YPerQ == 1 implementation - each row of B has independent scale
|
||||
constexpr index_t X = XPerTile;
|
||||
constexpr index_t XR = 2;
|
||||
constexpr index_t Y0 = NIterPerWarp;
|
||||
constexpr index_t Y1 = NWarps;
|
||||
constexpr index_t Y2 = WarpGemm::kN;
|
||||
// each row of B has independent scale
|
||||
constexpr index_t Y = XPerTile;
|
||||
constexpr index_t YR = 1;
|
||||
constexpr index_t X0 = NIterPerWarp;
|
||||
constexpr index_t X1 = NWarps;
|
||||
constexpr index_t X2 = WarpGemm::kN;
|
||||
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
tile_distribution_encoding<sequence<MWarps, YR>,
|
||||
tuple<sequence<Y>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<0, 2>, sequence<0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else if constexpr(YPerTile >= NIterPerWarp * NWarps)
|
||||
else if constexpr(XPerTile >= NIterPerWarp * NWarps)
|
||||
{
|
||||
// small NQ block size case: split NQ axis by iters and Nwarps
|
||||
constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp * NWarps);
|
||||
constexpr index_t XR = get_warp_size() / NQPerIter;
|
||||
static_assert(YPerTile == NQPerIter * NWarps * NIterPerWarp);
|
||||
constexpr index_t NQPerIter = integer_divide_ceil(XPerTile, NIterPerWarp * NWarps);
|
||||
constexpr index_t YR = get_warp_size() / NQPerIter;
|
||||
static_assert(XPerTile == NQPerIter * NWarps * NIterPerWarp);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarps, XR>,
|
||||
tuple<sequence<NIterPerWarp, NWarps, NQPerIter>, sequence<XPerTile>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
sequence<MWarps, YR>,
|
||||
tuple<sequence<XPerTile>, sequence<NIterPerWarp, NWarps, NQPerIter>>,
|
||||
tuple<sequence<0, 2>, sequence<0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else if constexpr(YPerTile >= NIterPerWarp)
|
||||
else if constexpr(XPerTile >= NIterPerWarp)
|
||||
{
|
||||
// now all NWarps have the same scale -> replicate
|
||||
constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp);
|
||||
constexpr index_t XR = get_warp_size() / NQPerIter;
|
||||
constexpr index_t YR = get_warp_size() / NQPerIter;
|
||||
static_assert(YPerTile == NQPerIter * NIterPerWarp);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarps, NWarps, XR>,
|
||||
tuple<sequence<NIterPerWarp, NQPerIter>, sequence<XPerTile>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>,
|
||||
sequence<MWarps, NWarps, YR>,
|
||||
tuple<sequence<XPerTile>, sequence<NIterPerWarp, NQPerIter>>,
|
||||
tuple<sequence<0, 2>, sequence<0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else
|
||||
@@ -245,18 +245,13 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
// threads
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
|
||||
tuple<sequence<YPerTile>, sequence<XPerTile>>,
|
||||
tuple<sequence<XPerTile>, sequence<YPerTile>>,
|
||||
tuple<sequence<0, 0>, sequence<0>>,
|
||||
tuple<sequence<0, 1>, sequence<2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_fully_replicated()
|
||||
{
|
||||
return YPerQ > 1 && YPerTile < NIterPerWarp * NWarps;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GroupSizes>
|
||||
|
||||
@@ -237,7 +237,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
// BQ DRAM window for load
|
||||
auto bq_copy_dram_window =
|
||||
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<KPerBlockBQ>{}),
|
||||
make_tuple(number<KPerBlockBQ>{}, number<kNPerBlock>{}),
|
||||
bq_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeBQDramTileDistribution<Problem>());
|
||||
|
||||
@@ -270,7 +270,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
BQBlockTile bq_block_tile, bq_block_tile_2;
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
// move BQ to tile 1
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
|
||||
|
||||
// Prefill A0
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
@@ -319,7 +319,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
bq_block_tile_2 = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
@@ -361,7 +361,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
@@ -28,6 +30,7 @@ struct GroupedConvFwdKernelArgs
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
true>; // Split N enabled
|
||||
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
template <
|
||||
@@ -38,7 +41,8 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
|
||||
: elfunc(args.elfunc)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -121,7 +125,8 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
|
||||
: elfunc(args.elfunc)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -213,7 +218,8 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
|
||||
: elfunc(args.elfunc)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -335,6 +341,7 @@ struct GroupedConvFwdKernelArgs
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
std::array<const void*, NumDTensor> ds_ptr;
|
||||
const CDElementwise elfunc;
|
||||
void* out_ptr;
|
||||
|
||||
AGridDescMK a_grid_desc_m_k;
|
||||
@@ -423,6 +430,8 @@ struct GroupedConvolutionForwardKernel
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using CDElementwise = typename EpiloguePipeline::CDElementwise;
|
||||
|
||||
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;
|
||||
|
||||
// TODO: Enable this
|
||||
@@ -458,7 +467,7 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)
|
||||
MakeKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& hostArgs)
|
||||
{
|
||||
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
|
||||
}
|
||||
@@ -636,7 +645,7 @@ struct GroupedConvolutionForwardKernel
|
||||
"Not supported!");
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
@@ -765,8 +774,9 @@ struct GroupedConvolutionForwardKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{kargs.elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -14,7 +15,7 @@ namespace ck_tile {
|
||||
/// This structure is passed to Grouped Convolution Kernels when creating kernel
|
||||
/// arguments object. It contains all necessary information required to
|
||||
/// build proper kernel argument and launch kernel on GPU.
|
||||
template <typename InPtr, typename WeiPtr, typename OutPtr>
|
||||
template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
|
||||
struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
CK_TILE_HOST GroupedConvHostArgs() = delete;
|
||||
@@ -23,13 +24,15 @@ struct GroupedConvHostArgs : public conv::ConvParam
|
||||
WeiPtr wei_ptr_,
|
||||
const std::vector<const void*> ds_ptr_,
|
||||
OutPtr out_ptr_,
|
||||
index_t k_batch_)
|
||||
index_t k_batch_,
|
||||
CDElementwise elfunc_ = CDElementwise{})
|
||||
: conv::ConvParam(conv_param),
|
||||
in_ptr(in_ptr_),
|
||||
wei_ptr(wei_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
out_ptr(out_ptr_),
|
||||
k_batch(k_batch_)
|
||||
k_batch(k_batch_),
|
||||
elfunc(elfunc_)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -38,11 +41,17 @@ struct GroupedConvHostArgs : public conv::ConvParam
|
||||
const std::vector<const void*> ds_ptr;
|
||||
OutPtr out_ptr;
|
||||
index_t k_batch;
|
||||
const CDElementwise elfunc;
|
||||
};
|
||||
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
|
||||
using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs<const void*, void*, const void*>;
|
||||
using GroupedConvBwdDataHostArgs = GroupedConvHostArgs<void*, const void*, const void*>;
|
||||
using PassThrough = ck_tile::element_wise::PassThrough;
|
||||
|
||||
template <typename CDElementwise = PassThrough>
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*, CDElementwise>;
|
||||
using GroupedConvBwdWeightHostArgs =
|
||||
GroupedConvHostArgs<const void*, void*, const void*, PassThrough>;
|
||||
using GroupedConvBwdDataHostArgs =
|
||||
GroupedConvHostArgs<void*, const void*, const void*, PassThrough>;
|
||||
|
||||
template <index_t NDimSpatial_,
|
||||
ConvolutionSpecialization ConvSpecialization_,
|
||||
@@ -50,9 +59,10 @@ template <index_t NDimSpatial_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1>
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
typename CDElementwise_ = PassThrough>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -70,6 +80,7 @@ struct GroupedConvTraits
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using CDElementwise = CDElementwise_;
|
||||
using GroupedConvImplicitGemmTraitsFwd =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
|
||||
@@ -300,7 +300,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
}
|
||||
|
||||
// layout NGCHW/GKYXC/NGKHW
|
||||
// layout NGCHW/GKCYX/NGKHW
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
|
||||
is_same_v<WeiLayout, GKCYX> && is_same_v<OutLayout, NGKHW>)
|
||||
{
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dynamic.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -161,7 +161,7 @@ template <ck::index_t NumDimSpatial,
|
||||
typename DDataTypes,
|
||||
typename OutDataType,
|
||||
typename AComputeType,
|
||||
typename BComputeType = AComputeType>
|
||||
typename BComputeType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
|
||||
@@ -69,7 +69,7 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
PassThrough,
|
||||
Scale>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
#!/usr/bin/env bash
|
||||
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ set(REGRESSION_TESTS
|
||||
test_ck_tile_fmha_fwd_bf16
|
||||
test_ck_tile_fmha_fwd_fp16
|
||||
test_ck_tile_fmha_fwd_fp8
|
||||
test_ck_tile_streamk_reboot_extended
|
||||
)
|
||||
|
||||
function(add_test_executable TEST_NAME)
|
||||
|
||||
@@ -33,6 +33,14 @@ struct elementwise_op_traits<ck_tile::element_wise::Relu>
|
||||
static constexpr int num_inputs = 1;
|
||||
};
|
||||
|
||||
using NegRelu =
|
||||
ck_tile::element_wise::Compose<ck_tile::element_wise::Relu, ck_tile::element_wise::Neg>;
|
||||
template <>
|
||||
struct elementwise_op_traits<NegRelu>
|
||||
{
|
||||
static constexpr int num_inputs = 1;
|
||||
};
|
||||
|
||||
template <std::size_t D, typename F>
|
||||
auto make_uniform_array_with_factory(F&& factory)
|
||||
{
|
||||
@@ -194,7 +202,11 @@ using TestConfig_F16_Add = std::tuple<ck_tile::half_t,
|
||||
Shape1_BlockTile,
|
||||
Shape1_WarpTile>;
|
||||
|
||||
using TestTypes = ::testing::Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add>;
|
||||
using TestConfig_F32_Neg_Relu =
|
||||
std::tuple<float, float, float, NegRelu, Shape1_BlockWarps, Shape1_BlockTile, Shape1_WarpTile>;
|
||||
|
||||
using TestTypes = ::testing::
|
||||
Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add, TestConfig_F32_Neg_Relu>;
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileElementwise, TestTypes);
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "test_gemm_quant_base.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/shuffle_utils.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
|
||||
@@ -20,20 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue enabled
|
||||
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>
|
||||
|
||||
// Currently MultiABD kernel doesn't support F8 data type
|
||||
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -20,19 +20,17 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue disabled
|
||||
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
|
||||
|
||||
// Currently MultiABD kernel doesn't support F8 data type
|
||||
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -1,5 +1,95 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
|
||||
@@ -13,40 +13,9 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
struct AddScale
|
||||
{
|
||||
template <typename E, typename A0, typename A1>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const A0& a0, const A1& a1) const
|
||||
{
|
||||
a = scale * (ck_tile::type_convert<float>(a0) + ck_tile::type_convert<float>(a1));
|
||||
}
|
||||
|
||||
float scale = 1.0;
|
||||
};
|
||||
|
||||
struct MultiplyMultiply
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
|
||||
{
|
||||
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
|
||||
ck_tile::type_convert<float>(d1);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
struct ElementWiseAddAdd
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
|
||||
{
|
||||
const float x0_f = ck_tile::type_convert<float>(c) + ck_tile::type_convert<float>(d0) +
|
||||
ck_tile::type_convert<float>(d1);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
using AddScale = ck_tile::element_wise::AddScale;
|
||||
using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd;
|
||||
using MultiplyMultiply = ck_tile::element_wise::MultiDMultiply;
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
|
||||
-mllvm
|
||||
-enable-noalias-to-md-conversion=0
|
||||
)
|
||||
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
|
||||
# Currently test_ck_tile_streamk is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
|
||||
@@ -6,23 +20,33 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
)
|
||||
# TODO: enable extended tests after tolerances for atomic reductions are addressed.
|
||||
# add_gtest_executable(test_ck_tile_streamk_extended
|
||||
@@ -117,6 +141,19 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# )
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reboot_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
|
||||
test_gemm_streamk_reboot_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reboot_extended
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
|
||||
test_gemm_streamk_reboot_util.cpp)
|
||||
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF8_CCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF8_CRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user